├── .github
└── workflows
│ └── python-app.yml
├── .gitignore
├── LICENSE
├── README.md
├── assets
├── fastspeech2.png
├── model.txt
├── tensorboard1.png
├── tensorboard1_1.png
├── tensorboard2.png
└── tensorboard2_1.png
├── compute_statistics.py
├── configs
└── default.yaml
├── core
├── __init__.py
├── attention.py
├── duration_modeling
│ ├── __init__.py
│ ├── duration_predictor.py
│ └── length_regulator.py
├── embedding.py
├── encoder.py
├── modules.py
├── optimizer.py
└── variance_predictor.py
├── dataset
├── __init__.py
├── audio_processing.py
├── dataloader.py
├── ljspeech.py
└── texts
│ ├── __init__.py
│ ├── cleaners.py
│ ├── cmudict.py
│ ├── dict_.py
│ ├── numbers.py
│ └── symbols.py
├── demo_fastspeech2.ipynb
├── evaluation.py
├── export_torchscript.py
├── fastspeech.py
├── filelists
├── train_filelist.txt
└── valid_filelist.txt
├── inference.py
├── nvidia_preprocessing.py
├── requirements.txt
├── sample
├── generated_mel_58k.npy
├── sample2_58k.wav
├── sample_102k_melgan.wav
├── sample_102k_waveglow.wav
├── sample_58k.wav
├── sample_74k_melgan.wav
└── sample_74k_waveglow.wav
├── tests
├── __init__.py
└── test_fastspeech2.py
├── train_fastspeech.py
└── utils
├── __init__.py
├── display.py
├── fastspeech2_script.py
├── hparams.py
├── plot.py
├── stft.py
└── util.py
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python application
5 |
6 | on:
7 | push:
8 | branches: [ master ]
9 | pull_request:
10 | branches: [ master ]
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 |
17 | steps:
18 | - uses: actions/checkout@v2
19 | - name: Set up Python 3.6
20 | uses: actions/setup-python@v2
21 | with:
22 | python-version: 3.6
23 | - name: Install dependencies
24 | run: |
25 | sudo apt install libsndfile1
26 | python -m pip install --upgrade pip
27 | pip install flake8 pytest
28 | pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30 | - name: Lint with flake8
31 | run: |
32 | # stop the build if there are Python syntax errors or undefined names
33 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
34 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
35 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
36 | - name: Test with pytest
37 | run: |
38 | pytest
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | idea/*
3 | /.idea
4 | /data
5 | /output
6 | /logs
7 | /__pycache__
8 | /core/__pycache__
9 | /core/duration_modeling/__pycache__
10 | /core/energy_predictor/__pycache__
11 | /core/pitch_predictor/__pycache__
12 | /dataset/__pycache__
13 | /dataset/texts/__pycache__
14 | /utils/__pycache__
15 | /checkpoints
16 | /trace_loss.txt
17 | /unused_code.txt
18 | /test.py
19 | /rest_tts.py
20 | /preprocess.py
21 | /trace_loss_nvidia.txt
22 | /conf
23 | /etc
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fastspeech 2
2 | UnOfficial PyTorch implementation of [**FastSpeech 2: Fast and High-Quality End-to-End Text to Speech**](https://arxiv.org/abs/2006.04558). This repo uses the FastSpeech implementation of [Espnet](https://github.com/espnet/espnet) as a base. In this implementation I tried to replicate the exact paper details but still some modification required for better model, this repo open for any suggestion and improvement. This repo uses Nvidia's tacotron 2 preprocessing for audio pre-processing and [MelGAN](https://github.com/seungwonpark/melgan) as vocoder.
3 |
4 |
5 | 
6 |
7 | ## Demo : [](https://colab.research.google.com/github/rishikksh20/FastSpeech2/blob/master/demo_fastspeech2.ipynb)
8 |
9 | ## Requirements :
10 | All code written in `Python 3.6.2` .
11 | * Install Pytorch
12 | > Before installing pytorch please check your Cuda version by running following command :
13 | `nvcc --version`
14 | ```
15 | pip install torch torchvision
16 | ```
17 | In this repo I have used Pytorch 1.6.0 for `torch.bucketize` feature which is not present in previous versions of PyTorch.
18 |
19 |
20 | * Installing other requirements :
21 | ```
22 | pip install -r requirements.txt
23 | ```
24 |
25 | * To use Tensorboard install `tensorboard version 1.14.0` seperatly with supported `tensorflow (1.14.0)`
26 |
27 |
28 |
29 | ## For Preprocessing :
30 |
31 | `filelists` folder contains MFA (Motreal Force aligner) processed LJSpeech dataset files so you don't need to align text with audio (for extract duration) for LJSpeech dataset.
32 | For other dataset follow instruction [here](https://github.com/ivanvovk/DurIAN#6-how-to-align-your-own-data). For other pre-processing run following command :
33 | ```
34 | python .\nvidia_preprocessing.py -d path_of_wavs
35 | ```
36 | For finding the min and max of F0 and Energy
37 | ```buildoutcfg
38 | python .\compute_statistics.py
39 | ```
40 | Update the following in `hparams.py` by min and max of F0 and Energy
41 | ```
42 | p_min = Min F0/pitch
43 | p_max = Max F0
44 | e_min = Min energy
45 | e_max = Max energy
46 | ```
47 |
48 | ## For training
49 | ```
50 | python train_fastspeech.py --outdir etc -c configs/default.yaml -n "name"
51 | ```
52 |
53 | ## For inference
54 | [](https://colab.research.google.com/github/rishikksh20/FastSpeech2/blob/master/demo_fastspeech2.ipynb)
55 | Currently only phonemes based Synthesis supported.
56 | ```
57 | python .\inference.py -c .\configs\default.yaml -p .\checkpoints\first_1\ts_version2_fastspeech_fe9a2c7_7k_steps.pyt --out output --text "ModuleList can be indexed like a regular Python list but modules it contains are properly registered."
58 | ```
59 | ## For TorchScript Export
60 | ```commandline
61 | python export_torchscript.py -c configs/default.yaml -n fastspeech_scrip --outdir etc
62 | ```
63 | ## Checkpoint and samples:
64 | * Checkpoint find [here](https://drive.google.com/drive/folders/1Fh7zr8zoTydNpD6hTNBPKUGN_s93Bqrs?usp=sharing)
65 | * For samples check `sample` folder.
66 |
67 | ## Tensorboard
68 |
69 | **Training :**
70 | 
71 | **Validation :**
72 | 
73 | ## Note
74 | * Coding of this repo is roughly done just to re-produce the paper and experimentation purpose. Needed a code cleanup and opyimization for better use.
75 | * Currently this repo produces good quality audio but still it is in WIP, many improvement needed.
76 | * Loss curve for F0 is quite high.
77 | * I am using raw F0 and energy for train a model, but we can also use normalize F0 and energy for stable training.
78 | * Using `Postnet` for better audio quality.
79 | * For more complete and end to end Voice cloning or Text to Speech (TTS) toolbox ⚡ please visit [Deepsync Technologies](https://deepsync.co/).
80 |
81 | ## References
82 | - [FastSpeech 2: Fast and High-Quality End-to-End Text to Speech](https://arxiv.org/abs/2006.04558)
83 | - [FastSpeech: Fast, Robust and Controllable Text to Speech](https://arxiv.org/abs/1905.09263)
84 | - [ESPnet](https://github.com/espnet/espnet)
85 | - [NVIDIA's WaveGlow implementation](https://github.com/NVIDIA/waveglow)
86 | - [MelGAN](https://github.com/seungwonpark/melgan)
87 | - [DurIAN](https://github.com/ivanvovk/DurIAN)
88 | - [FastSpeech2 Tensorflow Implementation](https://github.com/TensorSpeech/TensorflowTTS)
89 | - [Other PyTorch FastSpeech 2 Implementation](https://github.com/ming024/FastSpeech2)
90 | - [WaveRNN](https://github.com/fatchord/WaveRNN)
91 |
--------------------------------------------------------------------------------
/assets/fastspeech2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/assets/fastspeech2.png
--------------------------------------------------------------------------------
/assets/model.txt:
--------------------------------------------------------------------------------
1 | Batch Size : 16
2 | Trainable Parameters: 25.637M
3 | FeedForwardTransformer(
4 | (encoder): Encoder(
5 | (embed): Sequential(
6 | (0): Embedding(56, 256, padding_idx=0)
7 | (1): ScaledPositionalEncoding(
8 | (dropout): Dropout(p=0.2, inplace=False)
9 | )
10 | )
11 | (encoders): MultiSequential(
12 | (0): EncoderLayer(
13 | (self_attn): MultiHeadedAttention(
14 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
15 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
16 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
17 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
18 | (dropout): Dropout(p=0.2, inplace=False)
19 | )
20 | (feed_forward): MultiLayeredConv1d(
21 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
22 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
23 | (dropout): Dropout(p=0.2, inplace=False)
24 | )
25 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
26 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
27 | (dropout): Dropout(p=0.2, inplace=False)
28 | )
29 | (1): EncoderLayer(
30 | (self_attn): MultiHeadedAttention(
31 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
32 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
33 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
34 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
35 | (dropout): Dropout(p=0.2, inplace=False)
36 | )
37 | (feed_forward): MultiLayeredConv1d(
38 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
39 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
40 | (dropout): Dropout(p=0.2, inplace=False)
41 | )
42 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
43 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
44 | (dropout): Dropout(p=0.2, inplace=False)
45 | )
46 | (2): EncoderLayer(
47 | (self_attn): MultiHeadedAttention(
48 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
49 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
50 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
51 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
52 | (dropout): Dropout(p=0.2, inplace=False)
53 | )
54 | (feed_forward): MultiLayeredConv1d(
55 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
56 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
57 | (dropout): Dropout(p=0.2, inplace=False)
58 | )
59 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
60 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
61 | (dropout): Dropout(p=0.2, inplace=False)
62 | )
63 | (3): EncoderLayer(
64 | (self_attn): MultiHeadedAttention(
65 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
66 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
67 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
68 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
69 | (dropout): Dropout(p=0.2, inplace=False)
70 | )
71 | (feed_forward): MultiLayeredConv1d(
72 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
73 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
74 | (dropout): Dropout(p=0.2, inplace=False)
75 | )
76 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
77 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
78 | (dropout): Dropout(p=0.2, inplace=False)
79 | )
80 | )
81 | )
82 | (duration_predictor): DurationPredictor(
83 | (conv): ModuleList(
84 | (0): Sequential(
85 | (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
86 | (1): ReLU()
87 | (2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
88 | (3): Dropout(p=0.5, inplace=False)
89 | )
90 | (1): Sequential(
91 | (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
92 | (1): ReLU()
93 | (2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
94 | (3): Dropout(p=0.5, inplace=False)
95 | )
96 | )
97 | (linear): Linear(in_features=256, out_features=1, bias=True)
98 | )
99 | (energy_predictor): EnergyPredictor(
100 | (predictor): VariancePredictor(
101 | (conv): ModuleList(
102 | (0): Sequential(
103 | (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
104 | (1): ReLU()
105 | (2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
106 | (3): Dropout(p=0.5, inplace=False)
107 | )
108 | (1): Sequential(
109 | (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
110 | (1): ReLU()
111 | (2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
112 | (3): Dropout(p=0.5, inplace=False)
113 | )
114 | )
115 | (linear): Linear(in_features=256, out_features=1, bias=True)
116 | )
117 | )
118 | (energy_embed): Linear(in_features=256, out_features=256, bias=True)
119 | (pitch_predictor): PitchPredictor(
120 | (predictor): VariancePredictor(
121 | (conv): ModuleList(
122 | (0): Sequential(
123 | (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
124 | (1): ReLU()
125 | (2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
126 | (3): Dropout(p=0.5, inplace=False)
127 | )
128 | (1): Sequential(
129 | (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
130 | (1): ReLU()
131 | (2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
132 | (3): Dropout(p=0.5, inplace=False)
133 | )
134 | )
135 | (linear): Linear(in_features=256, out_features=1, bias=True)
136 | )
137 | )
138 | (pitch_embed): Linear(in_features=256, out_features=256, bias=True)
139 | (length_regulator): LengthRegulator()
140 | (decoder): Encoder(
141 | (embed): Sequential(
142 | (0): ScaledPositionalEncoding(
143 | (dropout): Dropout(p=0.2, inplace=False)
144 | )
145 | )
146 | (encoders): MultiSequential(
147 | (0): EncoderLayer(
148 | (self_attn): MultiHeadedAttention(
149 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
150 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
151 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
152 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
153 | (dropout): Dropout(p=0.2, inplace=False)
154 | )
155 | (feed_forward): MultiLayeredConv1d(
156 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
157 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
158 | (dropout): Dropout(p=0.2, inplace=False)
159 | )
160 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
161 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
162 | (dropout): Dropout(p=0.2, inplace=False)
163 | )
164 | (1): EncoderLayer(
165 | (self_attn): MultiHeadedAttention(
166 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
167 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
168 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
169 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
170 | (dropout): Dropout(p=0.2, inplace=False)
171 | )
172 | (feed_forward): MultiLayeredConv1d(
173 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
174 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
175 | (dropout): Dropout(p=0.2, inplace=False)
176 | )
177 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
178 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
179 | (dropout): Dropout(p=0.2, inplace=False)
180 | )
181 | (2): EncoderLayer(
182 | (self_attn): MultiHeadedAttention(
183 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
184 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
185 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
186 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
187 | (dropout): Dropout(p=0.2, inplace=False)
188 | )
189 | (feed_forward): MultiLayeredConv1d(
190 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
191 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
192 | (dropout): Dropout(p=0.2, inplace=False)
193 | )
194 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
195 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
196 | (dropout): Dropout(p=0.2, inplace=False)
197 | )
198 | (3): EncoderLayer(
199 | (self_attn): MultiHeadedAttention(
200 | (linear_q): Linear(in_features=256, out_features=256, bias=True)
201 | (linear_k): Linear(in_features=256, out_features=256, bias=True)
202 | (linear_v): Linear(in_features=256, out_features=256, bias=True)
203 | (linear_out): Linear(in_features=256, out_features=256, bias=True)
204 | (dropout): Dropout(p=0.2, inplace=False)
205 | )
206 | (feed_forward): MultiLayeredConv1d(
207 | (w_1): Conv1d(256, 1024, kernel_size=(9,), stride=(1,), padding=(4,))
208 | (w_2): Conv1d(1024, 256, kernel_size=(1,), stride=(1,))
209 | (dropout): Dropout(p=0.2, inplace=False)
210 | )
211 | (norm1): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
212 | (norm2): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
213 | (dropout): Dropout(p=0.2, inplace=False)
214 | )
215 | )
216 | )
217 | (postnet): Postnet(
218 | (postnet): ModuleList(
219 | (0): Sequential(
220 | (0): Conv1d(80, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
221 | (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
222 | (2): Tanh()
223 | (3): Dropout(p=0.5, inplace=False)
224 | )
225 | (1): Sequential(
226 | (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
227 | (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
228 | (2): Tanh()
229 | (3): Dropout(p=0.5, inplace=False)
230 | )
231 | (2): Sequential(
232 | (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
233 | (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
234 | (2): Tanh()
235 | (3): Dropout(p=0.5, inplace=False)
236 | )
237 | (3): Sequential(
238 | (0): Conv1d(256, 256, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
239 | (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
240 | (2): Tanh()
241 | (3): Dropout(p=0.5, inplace=False)
242 | )
243 | (4): Sequential(
244 | (0): Conv1d(256, 80, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)
245 | (1): BatchNorm1d(80, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
246 | (2): Dropout(p=0.5, inplace=False)
247 | )
248 | )
249 | )
250 | (feat_out): Linear(in_features=256, out_features=80, bias=True)
251 | (duration_criterion): DurationPredictorLoss(
252 | (criterion): MSELoss()
253 | )
254 | (energy_criterion): EnergyPredictorLoss(
255 | (criterion): MSELoss()
256 | )
257 | (pitch_criterion): PitchPredictorLoss(
258 | (criterion): MSELoss()
259 | )
260 | (criterion): L1Loss()
261 | )
--------------------------------------------------------------------------------
/assets/tensorboard1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/assets/tensorboard1.png
--------------------------------------------------------------------------------
/assets/tensorboard1_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/assets/tensorboard1_1.png
--------------------------------------------------------------------------------
/assets/tensorboard2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/assets/tensorboard2.png
--------------------------------------------------------------------------------
/assets/tensorboard2_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/assets/tensorboard2_1.png
--------------------------------------------------------------------------------
/compute_statistics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from utils.util import get_files
4 | from tqdm import tqdm
5 | from utils.util import remove_outlier
6 | from utils.hparams import HParam
7 |
8 | if __name__ == "__main__":
9 |
10 | hp = HParam("./configs/default.yaml")
11 |
12 | min_e = []
13 | min_p = []
14 | max_e = []
15 | max_p = []
16 | nz_min_p = []
17 | nz_min_e = []
18 |
19 | energy_path = os.path.join(hp.data.data_dir, "energy")
20 | pitch_path = os.path.join(hp.data.data_dir, "pitch")
21 | mel_path = os.path.join(hp.data.data_dir, "mels")
22 | energy_files = get_files(energy_path, extension=".npy")
23 | pitch_files = get_files(pitch_path, extension=".npy")
24 | mel_files = get_files(mel_path, extension=".npy")
25 |
26 | assert len(energy_files) == len(pitch_files) == len(mel_files)
27 |
28 | energy_vecs = []
29 | for f in tqdm(energy_files):
30 | e = np.load(f)
31 | e = remove_outlier(e)
32 | energy_vecs.append(e)
33 | min_e.append(e.min())
34 | nz_min_e.append(e[e > 0].min())
35 | max_e.append(e.max())
36 |
37 | nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in energy_vecs])
38 | e_mean, e_std = np.mean(nonzeros), np.std(nonzeros)
39 | print("Non zero Min Energy : {}".format(min(nz_min_e)))
40 | print("Max Energy : {}".format(max(max_e)))
41 | print("Energy mean : {}".format(e_mean))
42 | print("Energy std: {}".format(e_std))
43 |
44 | pitch_vecs = []
45 | bad_pitch = []
46 | for f in tqdm(pitch_files):
47 | # print(f)
48 | p = np.load(f)
49 | p = remove_outlier(p)
50 | pitch_vecs.append(p)
51 | # print(len(p), "#########", p)
52 | try:
53 | min_p.append(p.min())
54 | nz_min_p.append(p[p > 0].min())
55 | max_p.append(p.max())
56 | except ValueError:
57 | bad_pitch.append(f)
58 |
59 | nonzeros = np.concatenate([v[np.where(v != 0.0)[0]] for v in pitch_vecs])
60 | f0_mean, f0_std = np.mean(nonzeros), np.std(nonzeros)
61 | print("Min Pitch : {}".format(min(min_p)))
62 | print("Non zero Min Pitch : {}".format(min(nz_min_p)))
63 | print("Max Pitch : {}".format(max(max_p)))
64 | print("Pitch mean : {}".format(f0_mean))
65 | print("Pitch std: {}".format(f0_std))
66 |
67 | np.save(
68 | os.path.join(hp.data.data_dir, "e_mean.npy"),
69 | e_mean.astype(np.float32),
70 | allow_pickle=False,
71 | )
72 | np.save(
73 | os.path.join(hp.data.data_dir, "e_std.npy"),
74 | e_std.astype(np.float32),
75 | allow_pickle=False,
76 | )
77 | np.save(
78 | os.path.join(hp.data.data_dir, "f0_mean.npy"),
79 | f0_mean.astype(np.float32),
80 | allow_pickle=False,
81 | )
82 | np.save(
83 | os.path.join(hp.data.data_dir, "f0_std.npy"),
84 | f0_std.astype(np.float32),
85 | allow_pickle=False,
86 | )
87 | print("The len of bad Pitch Vectors is ", len(bad_pitch))
88 | # print(bad_pitch)
89 | with open("bad_file.txt", "a") as f:
90 | for i in bad_pitch:
91 | c = i.split("/")[3].split(".")[0]
92 | f.write(c)
93 | f.write("\n")
94 |
95 | # print("Min Energy : {}".format(min(min_e)))
96 |
97 | # print("Min Pitch : {}".format(min(min_p)))
98 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | data_dir: 'H:\Deepsync\backup\fastspeech\data\'
3 | wav_dir: 'H:\Deepsync\backup\deepsync\LJSpeech-1.1\wavs\'
4 | # Compute statistics
5 | e_mean: 21.578571319580078
6 | e_std: 18.916799545288086
7 | e_min: 0.01786651276051998
8 | e_max: 130.5338592529297
9 |
10 | f0_mean: 206.5135564772342
11 | f0_std: 53.633228905750336
12 | p_min: 71.0
13 | p_max: 676.2260946528305 # 799.8901977539062
14 | train_filelist: "./filelists/train_filelist.txt"
15 | valid_filelist: "./filelists/valid_filelist.txt"
16 | tts_cleaner_names: ['english_cleaners']
17 |
18 | # feature extraction related
19 | audio:
20 | sample_rate: 22050 # sampling frequency
21 | fmax: 8000.0 # maximum frequency
22 | fmin: 0.0 # minimum frequency
23 | n_mels: 80 # number of mel basis
24 | n_fft: 1024 # number of fft points
25 | hop_length: 256 # number of shift points
26 | win_length: 1024 # window length
27 | num_mels : 80
28 | min_level_db : -100
29 | ref_level_db : 20
30 | bits : 9 # bit depth of signal
31 | mu_law : True # Recommended to suppress noise if using raw bits in hp.voc_mode below
32 | peak_norm : False # Normalise to the peak of each wav file
33 |
34 |
35 |
36 |
37 | # network architecture related
38 | model:
39 | embed_dim: 0
40 | eprenet_conv_layers: 0 # one more linear layer w/o non_linear will be added for 0_centor
41 | eprenet_conv_filts: 0
42 | eprenet_conv_chans: 0
43 | dprenet_layers: 2 # one more linear layer w/o non_linear will be added for 0_centor
44 | dprenet_units: 256 # 384
45 | adim: 256
46 | aheads: 2
47 | elayers: 4
48 | eunits: 1024
49 | ddim: 384
50 | dlayers: 4
51 | dunits: 1024
52 | positionwise_layer_type : "conv1d" # linear
53 | positionwise_conv_kernel_size : 9 # 1
54 | postnet_layers: 5
55 | postnet_filts: 5
56 | postnet_chans: 256
57 | use_masking: True
58 | use_weighted_masking: False
59 | bce_pos_weight: 5.0
60 | use_batch_norm: True
61 | use_scaled_pos_enc: True
62 | encoder_normalize_before: False
63 | decoder_normalize_before: False
64 | encoder_concat_after: False
65 | decoder_concat_after: False
66 | reduction_factor: 1
67 | loss_type : "L1"
68 | # minibatch related
69 | batch_sort_key: input # shuffle or input or output
70 | batch_bins: 2549760 # 12 * (870 * 80 + 180 * 35)
71 | # batch_size * (max_out * dim_out + max_in * dim_in)
72 | # resuling in 11 ~ 66 samples (avg 15 samples) in batch (809 batches per epochs) for ljspeech
73 |
74 | # training related
75 | transformer_init: 'pytorch' # choices:["pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"]
76 | transformer_warmup_steps: 4000
77 | transformer_lr: 1.0
78 | initial_encoder_alpha: 1.0
79 | initial_decoder_alpha: 1.0
80 | eprenet_dropout_rate: 0.0
81 | dprenet_dropout_rate: 0.5
82 | postnet_dropout_rate: 0.5
83 | transformer_enc_dropout_rate: 0.1
84 | transformer_enc_positional_dropout_rate: 0.1
85 | transformer_enc_attn_dropout_rate: 0.1
86 | transformer_dec_dropout_rate: 0.1
87 | transformer_dec_positional_dropout_rate: 0.1
88 | transformer_dec_attn_dropout_rate: 0.1
89 | transformer_enc_dec_attn_dropout_rate: 0.1
90 | use_guided_attn_loss: True
91 | num_heads_applied_guided_attn: 2
92 | num_layers_applied_guided_attn: 2
93 | modules_applied_guided_attn: ["encoder_decoder"]
94 | guided_attn_loss_sigma: 0.4
95 | guided_attn_loss_lambda: 1.0
96 |
97 | ### FastSpeech
98 | duration_predictor_layers : 2
99 | duration_predictor_chans : 256
100 | duration_predictor_kernel_size : 3
101 | transfer_encoder_from_teacher : True
102 | duration_predictor_dropout_rate : 0.5
103 | teacher_model : ""
104 | transferred_encoder_module : "all" # choices:["all", "embed"]
105 |
106 | attn_plot : False
107 |
108 |
109 | train:
110 | # optimization related
111 | eos: False #True
112 | opt: 'noam'
113 | accum_grad: 4
114 | grad_clip: 1.0
115 | weight_decay: 0.001
116 | patience: 0
117 | epochs: 1000 # 1,000 epochs * 809 batches / 5 accum_grad : 161,800 iters
118 | save_interval_epoch: 10
119 | GTA : False
120 | # other
121 | ngpu: 1 # number of gpus ("0" uses cpu, otherwise use gpu)
122 | nj: 4 # number of parallel jobs
123 | dumpdir: '' # directory to dump full features
124 | verbose: 0 # verbose option (if set > 0, get more log)
125 | N: 0 # number of minibatches to be used (mainly for debugging). "0" uses all minibatches.
126 | seed: 1 # random seed number
127 | resume: "" # the snapshot path to resume (if set empty, no effect)
128 | use_phonemes: True
129 | batch_size : 16
130 | # other
131 | melgan_vocoder : True
132 | save_interval : 1000
133 | chkpt_dir : './checkpoints'
134 | log_dir : './logs'
135 | summary_interval : 200
136 | validation_step : 500
137 | tts_max_mel_len : 870 # if you have a couple of extremely long spectrograms you might want to use this
138 | tts_bin_lengths : True # bins the spectrogram lengths before sampling in data loader - speeds up training
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/core/__init__.py
--------------------------------------------------------------------------------
/core/attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 |
7 |
8 | class MultiHeadedAttention(nn.Module):
9 | """Multi-Head Attention layer
10 |
11 | :param int n_head: the number of head s
12 | :param int n_feat: the number of features
13 | :param float dropout_rate: dropout rate
14 | """
15 |
16 | def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
17 | super(MultiHeadedAttention, self).__init__()
18 | assert n_feat % n_head == 0
19 | # We assume d_v always equals d_k
20 | self.d_k = n_feat // n_head
21 | self.h = n_head
22 | self.linear_q = nn.Linear(n_feat, n_feat)
23 | self.linear_k = nn.Linear(n_feat, n_feat)
24 | self.linear_v = nn.Linear(n_feat, n_feat)
25 | self.linear_out = nn.Linear(n_feat, n_feat)
26 | # self.attn: Optional[torch.Tensor] = None # torch.empty(0)
27 | # self.register_buffer("attn", torch.empty(0))
28 | self.dropout = nn.Dropout(p=dropout_rate)
29 |
30 | def forward(
31 | self,
32 | query: torch.Tensor,
33 | key: torch.Tensor,
34 | value: torch.Tensor,
35 | mask: Optional[torch.Tensor] = None,
36 | ) -> torch.Tensor:
37 | """Compute 'Scaled Dot Product Attention'
38 |
39 | :param torch.Tensor query: (batch, time1, size)
40 | :param torch.Tensor key: (batch, time2, size)
41 | :param torch.Tensor value: (batch, time2, size)
42 | :param torch.Tensor mask: (batch, time1, time2)
43 | :param torch.nn.Dropout dropout:
44 | :return torch.Tensor: attentined and transformed `value` (batch, time1, d_model)
45 | weighted by the query dot key attention (batch, head, time1, time2)
46 | """
47 | n_batch = query.size(0)
48 | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
49 | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
50 | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
51 | q = q.transpose(1, 2) # (batch, head, time1, d_k)
52 | k = k.transpose(1, 2) # (batch, head, time2, d_k)
53 | v = v.transpose(1, 2) # (batch, head, time2, d_k)
54 |
55 | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(
56 | self.d_k
57 | ) # (batch, head, time1, time2)
58 | if mask is not None:
59 | mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
60 | # min_value: float = float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
61 | mask = mask.to(device=scores.device)
62 | scores = scores.masked_fill_(mask, -np.inf)
63 | attn = torch.softmax(scores, dim=-1).masked_fill(
64 | mask, 0.0
65 | ) # (batch, head, time1, time2)
66 | else:
67 | attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
68 |
69 | p_attn = self.dropout(attn)
70 | x = torch.matmul(p_attn, v) # (batch, head, time1, d_k)
71 | x = (
72 | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
73 | ) # (batch, time1, d_model)
74 | return self.linear_out(x) # (batch, time1, d_model)
75 |
--------------------------------------------------------------------------------
/core/duration_modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/core/duration_modeling/__init__.py
--------------------------------------------------------------------------------
/core/duration_modeling/duration_predictor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Tomoki Hayashi
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Duration predictor related loss."""
8 |
9 | import torch
10 | from typing import Optional
11 | from core.modules import LayerNorm
12 |
13 |
14 | class DurationPredictor(torch.nn.Module):
15 | """Duration predictor module.
16 |
17 | This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
18 | The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder.
19 |
20 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
21 | https://arxiv.org/pdf/1905.09263.pdf
22 |
23 | Note:
24 | The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`,
25 | the outputs are calculated in log domain but in `inference`, those are calculated in linear domain.
26 |
27 | """
28 |
29 | def __init__(
30 | self, idim, n_layers=2, n_chans=256, kernel_size=3, dropout_rate=0.1, offset=1.0
31 | ):
32 | """Initilize duration predictor module.
33 |
34 | Args:
35 | idim (int): Input dimension.
36 | n_layers (int, optional): Number of convolutional layers.
37 | n_chans (int, optional): Number of channels of convolutional layers.
38 | kernel_size (int, optional): Kernel size of convolutional layers.
39 | dropout_rate (float, optional): Dropout rate.
40 | offset (float, optional): Offset value to avoid nan in log domain.
41 |
42 | """
43 | super(DurationPredictor, self).__init__()
44 | self.offset = offset
45 | self.conv = torch.nn.ModuleList()
46 | for idx in range(n_layers):
47 | in_chans = idim if idx == 0 else n_chans
48 | self.conv += [
49 | torch.nn.Sequential(
50 | torch.nn.Conv1d(
51 | in_chans,
52 | n_chans,
53 | kernel_size,
54 | stride=1,
55 | padding=(kernel_size - 1) // 2,
56 | ),
57 | torch.nn.ReLU(),
58 | LayerNorm(n_chans),
59 | torch.nn.Dropout(dropout_rate),
60 | )
61 | ]
62 | self.linear = torch.nn.Linear(n_chans, 1)
63 |
64 | def _forward(
65 | self,
66 | xs: torch.Tensor,
67 | x_masks: Optional[torch.Tensor] = None,
68 | is_inference: bool = False,
69 | ):
70 | xs = xs.transpose(1, -1) # (B, idim, Tmax)
71 | for f in self.conv:
72 | xs = f(xs) # (B, C, Tmax)
73 |
74 | # NOTE: calculate in log domain
75 | xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax)
76 |
77 | if is_inference:
78 | # NOTE: calculate in linear domain
79 | xs = torch.clamp(
80 | torch.round(xs.exp() - self.offset), min=0
81 | ).long() # avoid negative value
82 |
83 | if x_masks is not None:
84 | xs = xs.masked_fill(x_masks, 0.0)
85 |
86 | return xs
87 |
88 | def forward(self, xs: torch.Tensor, x_masks: Optional[torch.Tensor] = None):
89 | """Calculate forward propagation.
90 |
91 | Args:
92 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
93 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
94 |
95 | Returns:
96 | Tensor: Batch of predicted durations in log domain (B, Tmax).
97 |
98 | """
99 | return self._forward(xs, x_masks, False)
100 |
101 | def inference(self, xs, x_masks: Optional[torch.Tensor] = None):
102 | """Inference duration.
103 |
104 | Args:
105 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
106 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
107 |
108 | Returns:
109 | LongTensor: Batch of predicted durations in linear domain (B, Tmax).
110 |
111 | """
112 | return self._forward(xs, x_masks, True)
113 |
114 |
115 | class DurationPredictorLoss(torch.nn.Module):
116 | """Loss function module for duration predictor.
117 |
118 | The loss value is Calculated in log domain to make it Gaussian.
119 |
120 | """
121 |
122 | def __init__(self, offset=1.0):
123 | """Initilize duration predictor loss module.
124 |
125 | Args:
126 | offset (float, optional): Offset value to avoid nan in log domain.
127 |
128 | """
129 | super(DurationPredictorLoss, self).__init__()
130 | self.criterion = torch.nn.MSELoss()
131 | self.offset = offset
132 |
133 | def forward(self, outputs, targets):
134 | """Calculate forward propagation.
135 |
136 | Args:
137 | outputs (Tensor): Batch of prediction durations in log domain (B, T)
138 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
139 |
140 | Returns:
141 | Tensor: Mean squared error loss value.
142 |
143 | Note:
144 | `outputs` is in log domain but `targets` is in linear domain.
145 |
146 | """
147 | # NOTE: outputs is in log domain while targets in linear
148 | targets = torch.log(targets.float() + self.offset)
149 | loss = self.criterion(outputs, targets)
150 |
151 | return loss
152 |
--------------------------------------------------------------------------------
/core/duration_modeling/length_regulator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Tomoki Hayashi
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """Length regulator related loss."""
8 |
9 | import logging
10 |
11 | import torch
12 |
13 | from utils.util import pad_2d_tensor, pad_list
14 |
15 |
16 | class LengthRegulator(torch.nn.Module):
17 | """Length regulator module for feed-forward Transformer.
18 |
19 | This is a module of length regulator described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
20 | The length regulator expands char or phoneme-level embedding features to frame-level by repeating each
21 | feature based on the corresponding predicted durations.
22 |
23 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
24 | https://arxiv.org/pdf/1905.09263.pdf
25 |
26 | """
27 |
28 | def __init__(self, pad_value: float = 0.0):
29 | """Initilize length regulator module.
30 |
31 | Args:
32 | pad_value (float, optional): Value used for padding.
33 |
34 | """
35 | super(LengthRegulator, self).__init__()
36 | self.pad_value = pad_value
37 |
38 | def forward(
39 | self,
40 | xs: torch.Tensor,
41 | ds: torch.Tensor,
42 | ilens: torch.Tensor,
43 | alpha: float = 1.0,
44 | ) -> torch.Tensor:
45 | """Calculate forward propagation.
46 |
47 | Args:
48 | xs (Tensor): Batch of sequences of char or phoneme embeddings (B, Tmax, D).
49 | ds (LongTensor): Batch of durations of each frame (B, T).
50 | ilens (LongTensor): Batch of input lengths (B,).
51 | alpha (float, optional): Alpha value to control speed of speech.
52 |
53 | Returns:
54 | Tensor: replicated input tensor based on durations (B, T*, D).
55 |
56 | """
57 | assert alpha > 0
58 | if alpha != 1.0:
59 | ds = torch.round(ds.float() * alpha).long()
60 | xs = [x[:ilen] for x, ilen in zip(xs, ilens)]
61 | ds = [d[:ilen] for d, ilen in zip(ds, ilens)]
62 |
63 | xs = [self._repeat_one_sequence(x, d) for x, d in zip(xs, ds)]
64 |
65 | return pad_2d_tensor(xs, 0.0)
66 |
67 | def _repeat_one_sequence(self, x: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
68 | """Repeat each frame according to duration.
69 |
70 | Examples:
71 | >>> x = torch.tensor([[1], [2], [3]])
72 | tensor([[1],
73 | [2],
74 | [3]])
75 | >>> d = torch.tensor([1, 2, 3])
76 | tensor([1, 2, 3])
77 | >>> self._repeat_one_sequence(x, d)
78 | tensor([[1],
79 | [2],
80 | [2],
81 | [3],
82 | [3],
83 | [3]])
84 |
85 | """
86 | if d.sum() == 0:
87 | # logging.warn("all of the predicted durations are 0. fill 0 with 1.")
88 | d = d.fill_(1)
89 | # return torch.cat([x_.repeat(int(d_), 1) for x_, d_ in zip(x, d) if d_ != 0], dim=0) for torchscript
90 | out = []
91 | for x_, d_ in zip(x, d):
92 | if d_ != 0:
93 | out.append(x_.repeat(int(d_), 1))
94 |
95 | return torch.cat(out, dim=0)
96 |
--------------------------------------------------------------------------------
/core/embedding.py:
--------------------------------------------------------------------------------
1 | """Positonal Encoding Module."""
2 | import math
3 |
4 | import torch
5 |
6 |
7 | def _pre_hook(
8 | state_dict,
9 | prefix,
10 | local_metadata,
11 | strict,
12 | missing_keys,
13 | unexpected_keys,
14 | error_msgs,
15 | ):
16 | """Perform pre-hook in load_state_dict for backward compatibility.
17 |
18 | Note:
19 | We saved self.pe until v.0.5.2 but we have omitted it later.
20 | Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
21 |
22 | """
23 | k = prefix + "pe"
24 | if k in state_dict:
25 | state_dict.pop(k)
26 |
27 |
28 | class PositionalEncoding(torch.nn.Module):
29 | """Positional encoding."""
30 |
31 | def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
32 | """Initialize class.
33 |
34 | :param int d_model: embedding dim
35 | :param float dropout_rate: dropout rate
36 | :param int max_len: maximum input length
37 |
38 | """
39 | super(PositionalEncoding, self).__init__()
40 | self.d_model = d_model
41 | self.xscale = math.sqrt(self.d_model)
42 | self.dropout = torch.nn.Dropout(p=dropout_rate)
43 | # self.pe = None
44 | self.register_buffer("pe", None)
45 | self.extend_pe(torch.tensor(0.0).expand(1, max_len))
46 | # self._register_load_state_dict_pre_hook(_pre_hook)
47 |
48 | def extend_pe(self, x: torch.Tensor):
49 | """Reset the positional encodings."""
50 | if self.pe is not None:
51 | if self.pe.size(1) >= x.size(1):
52 | if (
53 | self.pe.dtype != x.dtype
54 | ): # or self.pe.device != x.device: comment because of torchscript
55 | self.pe = self.pe.to(dtype=x.dtype, device=x.device)
56 | return
57 | pe = torch.zeros(x.size(1), self.d_model)
58 | position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
59 | div_term = torch.exp(
60 | torch.arange(0, self.d_model, 2, dtype=torch.float32)
61 | * -(math.log(10000.0) / self.d_model)
62 | )
63 | pe[:, 0::2] = torch.sin(position * div_term)
64 | pe[:, 1::2] = torch.cos(position * div_term)
65 | pe = pe.unsqueeze(0)
66 | self.pe = pe.to(device=x.device, dtype=x.dtype)
67 |
68 | def forward(self, x: torch.Tensor) -> torch.Tensor:
69 | """Add positional encoding.
70 |
71 | Args:
72 | x (torch.Tensor): Input. Its shape is (batch, time, ...)
73 |
74 | Returns:
75 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
76 |
77 | """
78 | self.extend_pe(x)
79 | x = x * self.xscale + self.pe[:, : x.size(1)]
80 | return self.dropout(x)
81 |
82 |
83 | class ScaledPositionalEncoding(PositionalEncoding):
84 | """Scaled positional encoding module.
85 |
86 | See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
87 |
88 | """
89 |
90 | def __init__(self, d_model, dropout_rate, max_len=5000):
91 | """Initialize class.
92 |
93 | :param int d_model: embedding dim
94 | :param float dropout_rate: dropout rate
95 | :param int max_len: maximum input length
96 |
97 | """
98 | super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
99 | self.alpha = torch.nn.Parameter(torch.tensor(1.0))
100 |
101 | def reset_parameters(self):
102 | """Reset parameters."""
103 | self.alpha.data = torch.tensor(1.0)
104 |
105 | def forward(self, x):
106 | """Add positional encoding.
107 |
108 | Args:
109 | x (torch.Tensor): Input. Its shape is (batch, time, ...)
110 |
111 | Returns:
112 | torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
113 |
114 | """
115 | device = x.device
116 | self.extend_pe(x)
117 | # print("Devices x :", x.device)
118 | self.alpha = self.alpha.to(device=device)
119 | x = x + self.alpha * self.pe[:, : x.size(1)].to(device=device)
120 | return self.dropout(x)
121 |
--------------------------------------------------------------------------------
/core/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from core.attention import MultiHeadedAttention
4 | from core.embedding import PositionalEncoding
5 | from core.modules import MultiLayeredConv1d
6 | from core.modules import PositionwiseFeedForward
7 | from core.modules import Conv2dSubsampling
8 | from typing import Tuple, Optional
9 |
10 |
11 | class EncoderLayer(nn.Module):
12 | """Encoder layer module
13 |
14 | :param int size: input dim
15 | :param espnet.nets.pytorch_backend.core.attention.MultiHeadedAttention self_attn: self attention module
16 | :param espnet.nets.pytorch_backend.core.positionwise_feed_forward.PositionwiseFeedForward feed_forward:
17 | feed forward module
18 | :param float dropout_rate: dropout rate
19 | :param bool normalize_before: whether to use layer_norm before the first block
20 | :param bool concat_after: whether to concat attention layer's input and output
21 | if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
22 | if False, no additional linear will be applied. i.e. x -> x + att(x)
23 | """
24 |
25 | def __init__(
26 | self,
27 | size,
28 | self_attn,
29 | feed_forward,
30 | dropout_rate,
31 | normalize_before=True,
32 | concat_after=False,
33 | ):
34 | super(EncoderLayer, self).__init__()
35 | self.self_attn = self_attn
36 | self.feed_forward = feed_forward
37 | self.norm1 = torch.nn.LayerNorm(size)
38 | self.norm2 = torch.nn.LayerNorm(size)
39 | self.dropout = nn.Dropout(dropout_rate)
40 | self.size = size
41 | self.normalize_before = normalize_before
42 | self.concat_after = concat_after
43 | # if self.concat_after:
44 | self.concat_linear = nn.Linear(size + size, size)
45 |
46 | def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
47 | """Compute encoded features
48 |
49 | :param torch.Tensor x: encoded source features (batch, max_time_in, size)
50 | :param torch.Tensor mask: mask for x (batch, max_time_in)
51 | :rtype: Tuple[torch.Tensor, torch.Tensor]
52 | """
53 | residual = x
54 | if self.normalize_before:
55 | x = self.norm1(x)
56 | if self.concat_after:
57 | x_concat = torch.cat((x, self.self_attn(x, x, x, mask)), dim=-1)
58 | x = residual + self.concat_linear(x_concat)
59 | else:
60 | x = residual + self.dropout(self.self_attn(x, x, x, mask))
61 | if not self.normalize_before:
62 | x = self.norm1(x)
63 |
64 | residual = x
65 | if self.normalize_before:
66 | x = self.norm2(x)
67 | x = residual + self.dropout(self.feed_forward(x))
68 | if not self.normalize_before:
69 | x = self.norm2(x)
70 |
71 | return x, mask
72 |
73 |
74 | class Encoder(torch.nn.Module):
75 | """Transformer encoder module
76 |
77 | :param int idim: input dim
78 | :param int attention_dim: dimention of attention
79 | :param int attention_heads: the number of heads of multi head attention
80 | :param int linear_units: the number of units of position-wise feed forward
81 | :param int num_blocks: the number of decoder blocks
82 | :param float dropout_rate: dropout rate
83 | :param float attention_dropout_rate: dropout rate in attention
84 | :param float positional_dropout_rate: dropout rate after adding positional encoding
85 | :param str or torch.nn.Module input_layer: input layer type
86 | :param class pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
87 | :param bool normalize_before: whether to use layer_norm before the first block
88 | :param bool concat_after: whether to concat attention layer's input and output
89 | if True, additional linear will be applied. i.e. x -> x + linear(concat(x, att(x)))
90 | if False, no additional linear will be applied. i.e. x -> x + att(x)
91 | :param str positionwise_layer_type: linear of conv1d
92 | :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
93 | :param int padding_idx: padding_idx for input_layer=embed
94 | """
95 |
96 | def __init__(
97 | self,
98 | idim: int,
99 | attention_dim: int = 256,
100 | attention_heads: int = 2,
101 | linear_units: int = 2048,
102 | num_blocks: int = 4,
103 | dropout_rate: float = 0.1,
104 | positional_dropout_rate: float = 0.1,
105 | attention_dropout_rate: float = 0.0,
106 | input_layer: str = "conv2d",
107 | pos_enc_class: torch.nn.Module = PositionalEncoding,
108 | normalize_before: bool = True,
109 | concat_after: bool = False,
110 | positionwise_layer_type: str = "linear",
111 | positionwise_conv_kernel_size: int = 1,
112 | padding_idx: int = -1,
113 | ):
114 |
115 | super(Encoder, self).__init__()
116 | # if self.normalize_before:
117 | self.after_norm = torch.nn.LayerNorm(attention_dim)
118 | if input_layer == "linear":
119 | self.embed = torch.nn.Sequential(
120 | torch.nn.Linear(idim, attention_dim),
121 | torch.nn.LayerNorm(attention_dim),
122 | torch.nn.Dropout(dropout_rate),
123 | torch.nn.ReLU(),
124 | pos_enc_class(attention_dim, positional_dropout_rate),
125 | )
126 | elif input_layer == "conv2d":
127 | self.embed = Conv2dSubsampling(idim, attention_dim, dropout_rate)
128 | elif input_layer == "embed":
129 | self.embed = torch.nn.Sequential(
130 | torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
131 | pos_enc_class(attention_dim, positional_dropout_rate),
132 | )
133 | elif isinstance(input_layer, torch.nn.Module):
134 | self.embed = torch.nn.Sequential(
135 | input_layer,
136 | pos_enc_class(attention_dim, positional_dropout_rate),
137 | )
138 | elif input_layer is None:
139 | self.embed = torch.nn.Sequential(
140 | pos_enc_class(attention_dim, positional_dropout_rate)
141 | )
142 | else:
143 | raise ValueError("unknown input_layer: " + input_layer)
144 | self.normalize_before = normalize_before
145 | if positionwise_layer_type == "linear":
146 | positionwise_layer = PositionwiseFeedForward
147 | positionwise_layer_args = (attention_dim, linear_units, dropout_rate)
148 | elif positionwise_layer_type == "conv1d":
149 | positionwise_layer = MultiLayeredConv1d
150 | positionwise_layer_args = (
151 | attention_dim,
152 | linear_units,
153 | positionwise_conv_kernel_size,
154 | dropout_rate,
155 | )
156 | else:
157 | raise NotImplementedError("Support only linear or conv1d.")
158 | # self.encoders = repeat(
159 | # 4,
160 | # lambda: EncoderLayer(
161 | # attention_dim,
162 | # MultiHeadedAttention(attention_heads, attention_dim, attention_dropout_rate),
163 | # positionwise_layer(*positionwise_layer_args),
164 | # dropout_rate,
165 | # normalize_before,
166 | # concat_after
167 | # )
168 | # )
169 | self.encoders_ = nn.ModuleList(
170 | [
171 | EncoderLayer(
172 | attention_dim,
173 | MultiHeadedAttention(
174 | attention_heads, attention_dim, attention_dropout_rate
175 | ),
176 | positionwise_layer(*positionwise_layer_args),
177 | dropout_rate,
178 | normalize_before,
179 | concat_after,
180 | )
181 | for _ in range(num_blocks)
182 | ]
183 | )
184 |
185 | def forward(self, xs: torch.Tensor, masks: Optional[torch.Tensor] = None):
186 | """Embed positions in tensor
187 |
188 | :param torch.Tensor xs: input tensor
189 | :param torch.Tensor masks: input mask
190 | :return: position embedded tensor and mask
191 | :rtype Tuple[torch.Tensor, torch.Tensor]:
192 | """
193 | # if isinstance(self.embed, Conv2dSubsampling):
194 | # xs, masks = self.embed(xs, masks)
195 | # else:
196 | xs = self.embed(xs)
197 |
198 | # xs, masks = self.encoders_(xs, masks)
199 | for encoder in self.encoders_:
200 | xs, masks = encoder(xs, masks)
201 | if self.normalize_before:
202 | xs = self.after_norm(xs)
203 |
204 | return xs, masks
205 |
--------------------------------------------------------------------------------
/core/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple
3 | from core.embedding import PositionalEncoding
4 |
5 |
6 | class Conv(torch.nn.Module):
7 | """
8 | Convolution Module
9 | """
10 |
11 | def __init__(
12 | self,
13 | in_channels,
14 | out_channels,
15 | kernel_size=1,
16 | stride=1,
17 | padding=0,
18 | dilation=1,
19 | bias=True,
20 | ):
21 | """
22 | :param in_channels: dimension of input
23 | :param out_channels: dimension of output
24 | :param kernel_size: size of kernel
25 | :param stride: size of stride
26 | :param padding: size of padding
27 | :param dilation: dilation rate
28 | :param bias: boolean. if True, bias is included.
29 | :param w_init: str. weight inits with xavier initialization.
30 | """
31 | super(Conv, self).__init__()
32 |
33 | self.conv = torch.nn.Conv1d(
34 | in_channels,
35 | out_channels,
36 | kernel_size=kernel_size,
37 | stride=stride,
38 | padding=padding,
39 | dilation=dilation,
40 | bias=bias,
41 | )
42 |
43 | def forward(self, x):
44 | x = x.contiguous().transpose(1, 2)
45 | x = self.conv(x)
46 | x = x.contiguous().transpose(1, 2)
47 |
48 | return x
49 |
50 |
51 | def initialize(model, init_type="pytorch"):
52 | """Initialize Transformer module
53 |
54 | :param torch.nn.Module model: core instance
55 | :param str init_type: initialization type
56 | """
57 | if init_type == "pytorch":
58 | return
59 |
60 | # weight init
61 | for p in model.parameters():
62 | if p.dim() > 1:
63 | if init_type == "xavier_uniform":
64 | torch.nn.init.xavier_uniform_(p.data)
65 | elif init_type == "xavier_normal":
66 | torch.nn.init.xavier_normal_(p.data)
67 | elif init_type == "kaiming_uniform":
68 | torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
69 | elif init_type == "kaiming_normal":
70 | torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
71 | else:
72 | raise ValueError("Unknown initialization: " + init_type)
73 | # bias init
74 | for p in model.parameters():
75 | if p.dim() == 1:
76 | p.data.zero_()
77 |
78 | # reset some loss with default init
79 | for m in model.modules():
80 | if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm)):
81 | m.reset_parameters()
82 |
83 |
84 | class MultiSequential(torch.nn.Sequential):
85 | """Multi-input multi-output torch.nn.Sequential"""
86 |
87 | def forward(self, *args):
88 | for m in self:
89 | args = m(*args)
90 | return args
91 |
92 |
93 | def repeat(N, fn):
94 | """repeat module N times
95 |
96 | :param int N: repeat time
97 | :param function fn: function to generate module
98 | :return: repeated loss
99 | :rtype: MultiSequential
100 | """
101 | return MultiSequential(*[fn() for _ in range(N)])
102 |
103 |
104 | # def layer_norm(x: torch.Tensor, dim):
105 | # if dim == -1:
106 | # return torch.nn.LayerNorm(x)
107 | # else:
108 | # out = torch.nn.LayerNorm(x.transpose(1, -1))
109 | # return out.transpose(1, -1)
110 |
111 |
112 | class LayerNorm(torch.nn.Module):
113 | def __init__(self, nout: int):
114 | super(LayerNorm, self).__init__()
115 | self.layer_norm = torch.nn.LayerNorm(nout, eps=1e-12)
116 |
117 | def forward(self, x: torch.Tensor) -> torch.Tensor:
118 | x = self.layer_norm(x.transpose(1, -1))
119 | x = x.transpose(1, -1)
120 | return x
121 |
122 |
123 | # class LayerNorm(torch.nn.LayerNorm):
124 | # """Layer normalization module
125 | #
126 | # :param int nout: output dim size
127 | # :param int dim: dimension to be normalized
128 | # """
129 | #
130 | # def __init__(self, nout: int, dim: int=-1):
131 | # super(LayerNorm, self).__init__(nout, eps=1e-12)
132 | # self.dim = dim
133 | #
134 | # def forward(self, x: torch.Tensor) -> torch.Tensor:
135 | # """Apply layer normalization
136 | #
137 | # :param torch.Tensor x: input tensor
138 | # :return: layer normalized tensor
139 | # :rtype torch.Tensor
140 | # """
141 | # if self.dim == -1:
142 | # return super(LayerNorm, self).forward(x)
143 | # return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
144 |
145 |
146 | class Conv2dSubsampling(torch.nn.Module):
147 | """Convolutional 2D subsampling (to 1/4 length)
148 |
149 | :param int idim: input dim
150 | :param int odim: output dim
151 | :param flaot dropout_rate: dropout rate
152 | """
153 |
154 | def __init__(self, idim: int, odim: int, dropout_rate: float):
155 | super(Conv2dSubsampling, self).__init__()
156 | self.conv = torch.nn.Sequential(
157 | torch.nn.Conv2d(1, odim, 3, 2),
158 | torch.nn.ReLU(),
159 | torch.nn.Conv2d(odim, odim, 3, 2),
160 | torch.nn.ReLU(),
161 | )
162 | self.out = torch.nn.Sequential(
163 | torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
164 | PositionalEncoding(odim, dropout_rate),
165 | )
166 |
167 | def forward(
168 | self, x: torch.Tensor, x_mask: torch.Tensor
169 | ) -> Tuple[torch.Tensor, torch.Tensor]:
170 | """Subsample x
171 |
172 | :param torch.Tensor x: input tensor
173 | :param torch.Tensor x_mask: input mask
174 | :return: subsampled x and mask
175 | :rtype Tuple[torch.Tensor, torch.Tensor]
176 | """
177 | x = x.unsqueeze(1) # (b, c, t, f)
178 | x = self.conv(x)
179 | b, c, t, f = x.size()
180 | x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
181 | if x_mask is None:
182 | return x, None
183 | return x, x_mask[:, :, :-2:2][:, :, :-2:2]
184 |
185 |
186 | class PositionwiseFeedForward(torch.nn.Module):
187 | """Positionwise feed forward
188 |
189 | :param int idim: input dimenstion
190 | :param int hidden_units: number of hidden units
191 | :param float dropout_rate: dropout rate
192 | """
193 |
194 | def __init__(self, idim: int, hidden_units: int, dropout_rate: float):
195 | super(PositionwiseFeedForward, self).__init__()
196 | self.w_1 = torch.nn.Linear(idim, hidden_units)
197 | self.w_2 = torch.nn.Linear(hidden_units, idim)
198 | self.dropout = torch.nn.Dropout(dropout_rate)
199 |
200 | def forward(self, x: torch.Tensor) -> torch.Tensor:
201 | return self.w_2(self.dropout(torch.relu(self.w_1(x))))
202 |
203 |
204 | class MultiLayeredConv1d(torch.nn.Module):
205 | """Multi-layered conv1d for Transformer block.
206 |
207 | This is a module of multi-leyered conv1d designed to replace positionwise feed-forward network
208 | in Transforner block, which is introduced in `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
209 |
210 | Args:
211 | in_chans (int): Number of input channels.
212 | hidden_chans (int): Number of hidden channels.
213 | kernel_size (int): Kernel size of conv1d.
214 | dropout_rate (float): Dropout rate.
215 |
216 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
217 | https://arxiv.org/pdf/1905.09263.pdf
218 |
219 | """
220 |
221 | def __init__(
222 | self, in_chans: int, hidden_chans: int, kernel_size: int, dropout_rate: float
223 | ):
224 | super(MultiLayeredConv1d, self).__init__()
225 | self.w_1 = torch.nn.Conv1d(
226 | in_chans,
227 | hidden_chans,
228 | kernel_size,
229 | stride=1,
230 | padding=(kernel_size - 1) // 2,
231 | )
232 | self.w_2 = torch.nn.Conv1d(
233 | hidden_chans, in_chans, 1, stride=1, padding=(1 - 1) // 2
234 | )
235 | self.dropout = torch.nn.Dropout(dropout_rate)
236 |
237 | def forward(self, x: torch.Tensor) -> torch.Tensor:
238 | """Calculate forward propagation.
239 |
240 | Args:
241 | x (Tensor): Batch of input tensors (B, *, in_chans).
242 |
243 | Returns:
244 | Tensor: Batch of output tensors (B, *, hidden_chans)
245 |
246 | """
247 | x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
248 | return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
249 |
250 |
251 | class Postnet(torch.nn.Module):
252 | """Postnet module for Spectrogram prediction network.
253 | This is a module of Postnet in Spectrogram prediction network,
254 | which described in `Natural TTS Synthesis by
255 | Conditioning WaveNet on Mel Spectrogram Predictions`_.
256 | The Postnet predicts refines the predicted
257 | Mel-filterbank of the decoder,
258 | which helps to compensate the detail sturcture of spectrogram.
259 | .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
260 | https://arxiv.org/abs/1712.05884
261 | """
262 |
263 | def __init__(
264 | self,
265 | idim: int,
266 | odim: int,
267 | n_layers: int = 5,
268 | n_chans: int = 512,
269 | n_filts: int = 5,
270 | dropout_rate: float = 0.5,
271 | use_batch_norm: bool = True,
272 | ):
273 | """Initialize postnet module.
274 | Args:
275 | idim (int): Dimension of the inputs.
276 | odim (int): Dimension of the outputs.
277 | n_layers (int, optional): The number of layers.
278 | n_filts (int, optional): The number of filter size.
279 | n_units (int, optional): The number of filter channels.
280 | use_batch_norm (bool, optional): Whether to use batch normalization..
281 | dropout_rate (float, optional): Dropout rate..
282 | """
283 | super(Postnet, self).__init__()
284 | self.postnet = torch.nn.ModuleList()
285 | for layer in range(n_layers - 1):
286 | ichans = odim if layer == 0 else n_chans
287 | ochans = odim if layer == n_layers - 1 else n_chans
288 | if use_batch_norm:
289 | self.postnet += [
290 | torch.nn.Sequential(
291 | torch.nn.Conv1d(
292 | ichans,
293 | ochans,
294 | n_filts,
295 | stride=1,
296 | padding=(n_filts - 1) // 2,
297 | bias=False,
298 | ),
299 | torch.nn.BatchNorm1d(ochans),
300 | torch.nn.Tanh(),
301 | torch.nn.Dropout(dropout_rate),
302 | )
303 | ]
304 | else:
305 | self.postnet += [
306 | torch.nn.Sequential(
307 | torch.nn.Conv1d(
308 | ichans,
309 | ochans,
310 | n_filts,
311 | stride=1,
312 | padding=(n_filts - 1) // 2,
313 | bias=False,
314 | ),
315 | torch.nn.Tanh(),
316 | torch.nn.Dropout(dropout_rate),
317 | )
318 | ]
319 | ichans = n_chans if n_layers != 1 else odim
320 | if use_batch_norm:
321 | self.postnet += [
322 | torch.nn.Sequential(
323 | torch.nn.Conv1d(
324 | ichans,
325 | odim,
326 | n_filts,
327 | stride=1,
328 | padding=(n_filts - 1) // 2,
329 | bias=False,
330 | ),
331 | torch.nn.BatchNorm1d(odim),
332 | torch.nn.Dropout(dropout_rate),
333 | )
334 | ]
335 | else:
336 | self.postnet += [
337 | torch.nn.Sequential(
338 | torch.nn.Conv1d(
339 | ichans,
340 | odim,
341 | n_filts,
342 | stride=1,
343 | padding=(n_filts - 1) // 2,
344 | bias=False,
345 | ),
346 | torch.nn.Dropout(dropout_rate),
347 | )
348 | ]
349 |
350 | def forward(self, xs):
351 | """Calculate forward propagation.
352 | Args:
353 | xs (Tensor): Batch of the sequences of padded input tensors (B, idim, Tmax).
354 | Returns:
355 | Tensor: Batch of padded output tensor. (B, odim, Tmax).
356 | """
357 | for postnet in self.postnet:
358 | xs = postnet(xs)
359 | return xs
360 |
--------------------------------------------------------------------------------
/core/optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class NoamOpt(object):
5 | "Optim wrapper that implements rate."
6 |
7 | def __init__(self, model_size, factor, warmup, optimizer):
8 | self.optimizer = optimizer
9 | self._step = 0
10 | self.warmup = warmup
11 | self.factor = factor
12 | self.model_size = model_size
13 | self._rate = 0
14 |
15 | @property
16 | def param_groups(self):
17 | return self.optimizer.param_groups
18 |
19 | def step(self):
20 | "Update parameters and rate"
21 | self._step += 1
22 | rate = self.rate()
23 | for p in self.optimizer.param_groups:
24 | p["lr"] = rate
25 | self._rate = rate
26 | self.optimizer.step()
27 |
28 | def rate(self, step=None):
29 | "Implement `lrate` above"
30 | if step is None:
31 | step = self._step
32 | return (
33 | self.factor
34 | * self.model_size ** (-0.5)
35 | * min(step ** (-0.5), step * self.warmup ** (-1.5))
36 | )
37 |
38 | def zero_grad(self):
39 | self.optimizer.zero_grad()
40 |
41 | def state_dict(self):
42 | return {
43 | "_step": self._step,
44 | "warmup": self.warmup,
45 | "factor": self.factor,
46 | "model_size": self.model_size,
47 | "_rate": self._rate,
48 | "optimizer": self.optimizer.state_dict(),
49 | }
50 |
51 | def load_state_dict(self, state_dict):
52 | for key, value in state_dict.items():
53 | if key == "optimizer":
54 | self.optimizer.load_state_dict(state_dict["optimizer"])
55 | else:
56 | setattr(self, key, value)
57 |
58 |
59 | def get_std_opt(model, d_model, warmup, factor):
60 | base = torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)
61 | return NoamOpt(d_model, factor, warmup, base)
62 |
--------------------------------------------------------------------------------
/core/variance_predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from typing import Optional
4 | from core.modules import LayerNorm
5 |
6 |
7 | class VariancePredictor(torch.nn.Module):
8 | def __init__(
9 | self,
10 | idim: int,
11 | n_layers: int = 2,
12 | n_chans: int = 256,
13 | out: int = 1,
14 | kernel_size: int = 3,
15 | dropout_rate: float = 0.5,
16 | offset: float = 1.0,
17 | ):
18 | super(VariancePredictor, self).__init__()
19 | self.offset = offset
20 | self.conv = torch.nn.ModuleList()
21 | for idx in range(n_layers):
22 | in_chans = idim if idx == 0 else n_chans
23 | self.conv += [
24 | torch.nn.Sequential(
25 | torch.nn.Conv1d(
26 | in_chans,
27 | n_chans,
28 | kernel_size,
29 | stride=1,
30 | padding=(kernel_size - 1) // 2,
31 | ),
32 | torch.nn.ReLU(),
33 | LayerNorm(n_chans),
34 | torch.nn.Dropout(dropout_rate),
35 | )
36 | ]
37 | self.linear = torch.nn.Linear(n_chans, out)
38 |
39 | def _forward(
40 | self,
41 | xs: torch.Tensor,
42 | is_inference: bool = False,
43 | is_log_output: bool = False,
44 | alpha: float = 1.0,
45 | ) -> torch.Tensor:
46 | xs = xs.transpose(1, -1) # (B, idim, Tmax)
47 | for f in self.conv:
48 | xs = f(xs) # (B, C, Tmax)
49 |
50 | # NOTE: calculate in log domain
51 | xs = self.linear(xs.transpose(1, -1)).squeeze(-1) # (B, Tmax)
52 |
53 | if is_inference and is_log_output:
54 | # # NOTE: calculate in linear domain
55 | xs = torch.clamp(
56 | torch.round(xs.exp() - self.offset), min=0
57 | ).long() # avoid negative value
58 | xs = xs * alpha
59 |
60 | return xs
61 |
62 | def forward(
63 | self, xs: torch.Tensor, x_masks: Optional[torch.Tensor] = None
64 | ) -> torch.Tensor:
65 | """Calculate forward propagation.
66 |
67 | Args:
68 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
69 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
70 |
71 | Returns:
72 | Tensor: Batch of predicted durations in log domain (B, Tmax).
73 |
74 | """
75 | xs = self._forward(xs)
76 | if x_masks is not None:
77 | xs = xs.masked_fill(x_masks, 0.0)
78 | return xs
79 |
80 | def inference(
81 | self, xs: torch.Tensor, is_log_output: bool = False, alpha: float = 1.0
82 | ) -> torch.Tensor:
83 | """Inference duration.
84 |
85 | Args:
86 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
87 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
88 |
89 | Returns:
90 | LongTensor: Batch of predicted durations in linear domain (B, Tmax).
91 |
92 | """
93 | return self._forward(
94 | xs, is_inference=True, is_log_output=is_log_output, alpha=alpha
95 | )
96 |
97 |
98 | class EnergyPredictor(torch.nn.Module):
99 | def __init__(
100 | self,
101 | idim,
102 | n_layers=2,
103 | n_chans=256,
104 | kernel_size=3,
105 | dropout_rate=0.1,
106 | offset=1.0,
107 | min=0,
108 | max=0,
109 | n_bins=256,
110 | ):
111 | """Initilize Energy predictor module.
112 |
113 | Args:
114 | idim (int): Input dimension.
115 | n_layers (int, optional): Number of convolutional layers.
116 | n_chans (int, optional): Number of channels of convolutional layers.
117 | kernel_size (int, optional): Kernel size of convolutional layers.
118 | dropout_rate (float, optional): Dropout rate.
119 | offset (float, optional): Offset value to avoid nan in log domain.
120 |
121 | """
122 | super(EnergyPredictor, self).__init__()
123 | # self.bins = torch.linspace(min, max, n_bins - 1).cuda()
124 | self.register_buffer("energy_bins", torch.linspace(min, max, n_bins - 1))
125 | self.predictor = VariancePredictor(idim)
126 |
127 | def forward(self, xs: torch.Tensor, x_masks: torch.Tensor):
128 | """Calculate forward propagation.
129 |
130 | Args:
131 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
132 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
133 |
134 | Returns:
135 | Tensor: Batch of predicted durations in log domain (B, Tmax).
136 |
137 | """
138 | return self.predictor(xs, x_masks)
139 |
140 | def inference(self, xs: torch.Tensor, alpha: float = 1.0):
141 | """Inference duration.
142 |
143 | Args:
144 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
145 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
146 |
147 | Returns:
148 | LongTensor: Batch of predicted durations in linear domain (B, Tmax).
149 |
150 | """
151 | out = self.predictor.inference(xs, False, alpha=alpha)
152 | return self.to_one_hot(out) # Need to do One hot code
153 |
154 | def to_one_hot(self, x):
155 | # e = de_norm_mean_std(e, hp.e_mean, hp.e_std)
156 | # For pytorch > = 1.6.0
157 |
158 | quantize = torch.bucketize(x, self.energy_bins).to(device=x.device) # .cuda()
159 | return F.one_hot(quantize.long(), 256).float()
160 |
161 |
162 | class PitchPredictor(torch.nn.Module):
163 | def __init__(
164 | self,
165 | idim,
166 | n_layers=2,
167 | n_chans=384,
168 | kernel_size=3,
169 | dropout_rate=0.1,
170 | offset=1.0,
171 | min=0,
172 | max=0,
173 | n_bins=256,
174 | ):
175 | """Initilize pitch predictor module.
176 |
177 | Args:
178 | idim (int): Input dimension.
179 | n_layers (int, optional): Number of convolutional layers.
180 | n_chans (int, optional): Number of channels of convolutional layers.
181 | kernel_size (int, optional): Kernel size of convolutional layers.
182 | dropout_rate (float, optional): Dropout rate.
183 | offset (float, optional): Offset value to avoid nan in log domain.
184 |
185 | """
186 | super(PitchPredictor, self).__init__()
187 | # self.bins = torch.exp(torch.linspace(torch.log(torch.tensor(min)), torch.log(torch.tensor(max)), n_bins - 1)).cuda()
188 | self.register_buffer(
189 | "pitch_bins",
190 | torch.exp(
191 | torch.linspace(
192 | torch.log(torch.tensor(min)),
193 | torch.log(torch.tensor(max)),
194 | n_bins - 1,
195 | )
196 | ),
197 | )
198 | self.predictor = VariancePredictor(idim)
199 |
200 | def forward(self, xs: torch.Tensor, x_masks: torch.Tensor):
201 | """Calculate forward propagation.
202 |
203 | Args:
204 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
205 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
206 |
207 | Returns:
208 | Tensor: Batch of predicted durations in log domain (B, Tmax).
209 |
210 | """
211 | return self.predictor(xs, x_masks)
212 |
213 | def inference(self, xs: torch.Tensor, alpha: float = 1.0):
214 | """Inference duration.
215 |
216 | Args:
217 | xs (Tensor): Batch of input sequences (B, Tmax, idim).
218 | x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax).
219 |
220 | Returns:
221 | LongTensor: Batch of predicted durations in linear domain (B, Tmax).
222 |
223 | """
224 | out = self.predictor.inference(xs, False, alpha=alpha)
225 | return self.to_one_hot(out)
226 |
227 | def to_one_hot(self, x: torch.Tensor):
228 | # e = de_norm_mean_std(e, hp.e_mean, hp.e_std)
229 | # For pytorch > = 1.6.0
230 |
231 | quantize = torch.bucketize(x, self.pitch_bins).to(device=x.device) # .cuda()
232 | return F.one_hot(quantize.long(), 256).float()
233 |
234 |
235 | class PitchPredictorLoss(torch.nn.Module):
236 | """Loss function module for duration predictor.
237 |
238 | The loss value is Calculated in log domain to make it Gaussian.
239 |
240 | """
241 |
242 | def __init__(self, offset=1.0):
243 | """Initilize duration predictor loss module.
244 |
245 | Args:
246 | offset (float, optional): Offset value to avoid nan in log domain.
247 |
248 | """
249 | super(PitchPredictorLoss, self).__init__()
250 | self.criterion = torch.nn.MSELoss()
251 | self.offset = offset
252 |
253 | def forward(self, outputs, targets):
254 | """Calculate forward propagation.
255 |
256 | Args:
257 | outputs (Tensor): Batch of prediction durations in log domain (B, T)
258 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
259 |
260 | Returns:
261 | Tensor: Mean squared error loss value.
262 |
263 | Note:
264 | `outputs` is in log domain but `targets` is in linear domain.
265 |
266 | """
267 | # NOTE: We convert the output in log domain low error value
268 | # print("Output :", outputs[0])
269 | # print("Before Output :", targets[0])
270 | # targets = torch.log(targets.float() + self.offset)
271 | # print("Before Output :", targets[0])
272 | # outputs = torch.log(outputs.float() + self.offset)
273 | loss = self.criterion(outputs, targets)
274 | # print(loss)
275 | return loss
276 |
277 |
278 | class EnergyPredictorLoss(torch.nn.Module):
279 | """Loss function module for duration predictor.
280 |
281 | The loss value is Calculated in log domain to make it Gaussian.
282 |
283 | """
284 |
285 | def __init__(self, offset=1.0):
286 | """Initilize duration predictor loss module.
287 |
288 | Args:
289 | offset (float, optional): Offset value to avoid nan in log domain.
290 |
291 | """
292 | super(EnergyPredictorLoss, self).__init__()
293 | self.criterion = torch.nn.MSELoss()
294 | self.offset = offset
295 |
296 | def forward(self, outputs, targets):
297 | """Calculate forward propagation.
298 |
299 | Args:
300 | outputs (Tensor): Batch of prediction durations in log domain (B, T)
301 | targets (LongTensor): Batch of groundtruth durations in linear domain (B, T)
302 |
303 | Returns:
304 | Tensor: Mean squared error loss value.
305 |
306 | Note:
307 | `outputs` is in log domain but `targets` is in linear domain.
308 |
309 | """
310 | # NOTE: outputs is in log domain while targets in linear
311 | # targets = torch.log(targets.float() + self.offset)
312 | loss = self.criterion(outputs, targets)
313 |
314 | return loss
315 |
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/audio_processing.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import librosa
4 | from scipy.signal import lfilter
5 | import pyworld as pw
6 | import torch
7 | from scipy.signal import get_window
8 | import librosa.util as librosa_util
9 |
10 |
11 | def label_2_float(x, bits):
12 | return 2 * x / (2 ** bits - 1.0) - 1.0
13 |
14 |
15 | def float_2_label(x, bits):
16 | assert abs(x).max() <= 1.0
17 | x = (x + 1.0) * (2 ** bits - 1) / 2
18 | return x.clip(0, 2 ** bits - 1)
19 |
20 |
21 | def load_wav(path, hp):
22 | return librosa.load(path, sr=hp.audio.sample_rate)[0]
23 |
24 |
25 | def save_wav(x, path, hp):
26 | librosa.output.write_wav(path, x.astype(np.float32), sr=hp.audio.sample_rate)
27 |
28 |
29 | def split_signal(x):
30 | unsigned = x + 2 ** 15
31 | coarse = unsigned // 256
32 | fine = unsigned % 256
33 | return coarse, fine
34 |
35 |
36 | def combine_signal(coarse, fine):
37 | return coarse * 256 + fine - 2 ** 15
38 |
39 |
40 | def encode_16bits(x):
41 | return np.clip(x * 2 ** 15, -(2 ** 15), 2 ** 15 - 1).astype(np.int16)
42 |
43 |
44 | mel_basis = None
45 |
46 |
47 | def energy(y):
48 | # Extract energy
49 | S = librosa.magphase(stft(y))[0]
50 | e = np.sqrt(np.sum(S ** 2, axis=0)) # np.linalg.norm(S, axis=0)
51 | return e.squeeze() # (Number of frames) => (654,)
52 |
53 |
54 | def pitch(y, hp):
55 | # Extract Pitch/f0 from raw waveform using PyWORLD
56 | y = y.astype(np.float64)
57 | """
58 | f0_floor : float
59 | Lower F0 limit in Hz.
60 | Default: 71.0
61 | f0_ceil : float
62 | Upper F0 limit in Hz.
63 | Default: 800.0
64 | """
65 | f0, timeaxis = pw.dio(
66 | y,
67 | hp.audio.sample_rate,
68 | frame_period=hp.audio.hop_length / hp.audio.sample_rate * 1000,
69 | ) # For hop size 256 frame period is 11.6 ms
70 | return f0 # (Number of Frames) = (654,)
71 |
72 |
73 | def linear_to_mel(spectrogram, hp):
74 | global mel_basis
75 | if mel_basis is None:
76 | mel_basis = build_mel_basis(hp)
77 | return np.dot(mel_basis, spectrogram)
78 |
79 |
80 | def build_mel_basis(hp):
81 | return librosa.filters.mel(
82 | hp.audio.sample_rate,
83 | hp.audio.n_fft,
84 | n_mels=hp.audio.num_mels,
85 | fmin=hp.audio.fmin,
86 | )
87 |
88 |
89 | def normalize(S, hp):
90 | return np.clip((S - hp.audio.min_level_db) / -hp.audio.min_level_db, 0, 1)
91 |
92 |
93 | def denormalize(S, hp):
94 | return (np.clip(S, 0, 1) * -hp.audio.min_level_db) + hp.audio.min_level_db
95 |
96 |
97 | def amp_to_db(x):
98 | return 20 * np.log10(np.maximum(1e-5, x))
99 |
100 |
101 | def db_to_amp(x):
102 | return np.power(10.0, x * 0.05)
103 |
104 |
105 | def spectrogram(y, hp):
106 | D = stft(y, hp)
107 | S = amp_to_db(np.abs(D)) - hp.audio.ref_level_db
108 | return normalize(S, hp)
109 |
110 |
111 | def melspectrogram(y, hp):
112 | D = stft(y, hp)
113 | S = amp_to_db(linear_to_mel(np.abs(D), hp))
114 | return normalize(S, hp)
115 |
116 |
117 | def stft(y, hp):
118 | return librosa.stft(
119 | y=y,
120 | n_fft=hp.audio.n_fft,
121 | hop_length=hp.audio.hop_length,
122 | win_length=hp.audio.win_length,
123 | )
124 |
125 |
126 | def pre_emphasis(x, hp):
127 | return lfilter([1, -hp.audio.preemphasis], [1], x)
128 |
129 |
130 | def de_emphasis(x, hp):
131 | return lfilter([1], [1, -hp.audio.preemphasis], x)
132 |
133 |
134 | def encode_mu_law(x, mu):
135 | mu = mu - 1
136 | fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
137 | return np.floor((fx + 1) / 2 * mu + 0.5)
138 |
139 |
140 | def decode_mu_law(y, mu, from_labels=True):
141 | # TODO : get rid of log2 - makes no sense
142 | if from_labels:
143 | y = label_2_float(y, math.log2(mu))
144 | mu = mu - 1
145 | x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
146 | return x
147 |
148 |
149 | def reconstruct_waveform(mel, hp, n_iter=32):
150 | """Uses Griffin-Lim phase reconstruction to convert from a normalized
151 | mel spectrogram back into a waveform."""
152 | denormalized = denormalize(mel)
153 | amp_mel = db_to_amp(denormalized)
154 | S = librosa.feature.inverse.mel_to_stft(
155 | amp_mel,
156 | power=1,
157 | sr=hp.audio.sample_rate,
158 | n_fft=hp.audio.n_fft,
159 | fmin=hp.audio.fmin,
160 | )
161 | wav = librosa.core.griffinlim(
162 | S, n_iter=n_iter, hop_length=hp.audio.hop_length, win_length=hp.audio.win_length
163 | )
164 | return wav
165 |
166 |
167 | def quantize_input(input, min, max, num_bins=256):
168 | bins = np.linspace(min, max, num=num_bins)
169 | quantize = np.digitize(input, bins)
170 | return quantize
171 |
172 |
173 | def window_sumsquare(
174 | window,
175 | n_frames,
176 | hop_length=200,
177 | win_length=800,
178 | n_fft=800,
179 | dtype=np.float32,
180 | norm=None,
181 | ):
182 | """
183 | # from librosa 0.6
184 | Compute the sum-square envelope of a window function at a given hop length.
185 | This is used to estimate modulation effects induced by windowing
186 | observations in short-time fourier transforms.
187 | Parameters
188 | ----------
189 | window : string, tuple, number, callable, or list-like
190 | Window specification, as in `get_window`
191 | n_frames : int > 0
192 | The number of analysis frames
193 | hop_length : int > 0
194 | The number of samples to advance between frames
195 | win_length : [optional]
196 | The length of the window function. By default, this matches `n_fft`.
197 | n_fft : int > 0
198 | The length of each analysis frame.
199 | dtype : np.dtype
200 | The data type of the output
201 | Returns
202 | -------
203 | wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
204 | The sum-squared envelope of the window function
205 | """
206 | if win_length is None:
207 | win_length = n_fft
208 |
209 | n = n_fft + hop_length * (n_frames - 1)
210 | x = np.zeros(n, dtype=dtype)
211 |
212 | # Compute the squared window at the desired length
213 | win_sq = get_window(window, win_length, fftbins=True)
214 | win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
215 | win_sq = librosa_util.pad_center(win_sq, n_fft)
216 |
217 | # Fill the envelope
218 | for i in range(n_frames):
219 | sample = i * hop_length
220 | x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
221 | return x
222 |
223 |
224 | def griffin_lim(magnitudes, stft_fn, n_iters=30):
225 | """
226 | PARAMS
227 | ------
228 | magnitudes: spectrogram magnitudes
229 | stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
230 | """
231 |
232 | angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
233 | angles = angles.astype(np.float32)
234 | angles = torch.autograd.Variable(torch.from_numpy(angles).cuda())
235 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
236 |
237 | for i in range(n_iters):
238 | _, angles = stft_fn.transform(signal)
239 | signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
240 | return signal
241 |
242 |
243 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
244 | """
245 | PARAMS
246 | ------
247 | C: compression factor
248 | """
249 | return torch.log(torch.clamp(x, min=clip_val) * C)
250 |
251 |
252 | def dynamic_range_decompression(x, C=1):
253 | """
254 | PARAMS
255 | ------
256 | C: compression factor used to compress
257 | """
258 | return torch.exp(x) / C
259 |
--------------------------------------------------------------------------------
/dataset/dataloader.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.utils.data import Dataset, DataLoader
4 | from torch.utils.data.sampler import Sampler
5 | from dataset.texts import phonemes_to_sequence
6 | import numpy as np
7 | from dataset.texts import text_to_sequence
8 | from utils.util import pad_list, str_to_int_list, remove_outlier
9 |
10 |
11 | def get_tts_dataset(path, batch_size, hp, valid=False):
12 |
13 | if valid:
14 | file_ = hp.data.valid_filelist
15 | pin_mem = False
16 | num_workers = 0
17 | shuffle = False
18 | else:
19 | file_ = hp.data.train_filelist
20 | pin_mem = True
21 | num_workers = 4
22 | shuffle = True
23 | train_dataset = TTSDataset(
24 | path, file_, hp.train.use_phonemes, hp.data.tts_cleaner_names, hp.train.eos
25 | )
26 |
27 | train_set = DataLoader(
28 | train_dataset,
29 | collate_fn=collate_tts,
30 | batch_size=batch_size,
31 | num_workers=num_workers,
32 | shuffle=shuffle,
33 | pin_memory=pin_mem,
34 | )
35 | return train_set
36 |
37 |
38 | class TTSDataset(Dataset):
39 | def __init__(self, path, file_, use_phonemes, tts_cleaner_names, eos):
40 | self.path = path
41 | with open("{}".format(file_), encoding="utf-8") as f:
42 | self._metadata = [line.strip().split("|") for line in f]
43 | self.use_phonemes = use_phonemes
44 | self.tts_cleaner_names = tts_cleaner_names
45 | self.eos = eos
46 |
47 | def __getitem__(self, index):
48 | id = self._metadata[index][4].split(".")[0]
49 | x_ = self._metadata[index][3].split()
50 | if self.use_phonemes:
51 | x = phonemes_to_sequence(x_)
52 | else:
53 | x = text_to_sequence(x_, self.tts_cleaner_names, self.eos)
54 | mel = np.load(f"{self.path}mels/{id}.npy")
55 | durations = str_to_int_list(self._metadata[index][2])
56 | e = remove_outlier(
57 | np.load(f"{self.path}energy/{id}.npy")
58 | ) # self._norm_mean_std(np.load(f'{self.path}energy/{id}.npy'), self.e_mean, self.e_std, True)
59 | p = remove_outlier(
60 | np.load(f"{self.path}pitch/{id}.npy")
61 | ) # self._norm_mean_std(np.load(f'{self.path}pitch/{id}.npy'), self.f0_mean, self.f0_std, True)
62 | mel_len = mel.shape[1]
63 | durations = durations[: len(x)]
64 | durations[-1] = durations[-1] + (mel.shape[1] - sum(durations))
65 | assert mel.shape[1] == sum(durations)
66 | return (
67 | np.array(x),
68 | mel.T,
69 | id,
70 | mel_len,
71 | np.array(durations),
72 | e,
73 | p,
74 | ) # Mel [T, num_mel]
75 |
76 | def __len__(self):
77 | return len(self._metadata)
78 |
79 | def _norm_mean_std(self, x, mean, std, is_remove_outlier=False):
80 | if is_remove_outlier:
81 | x = remove_outlier(x)
82 | zero_idxs = np.where(x == 0.0)[0]
83 | x = (x - mean) / std
84 | x[zero_idxs] = 0.0
85 | return x
86 |
87 |
88 | def pad1d(x, max_len):
89 | return np.pad(x, (0, max_len - len(x)), mode="constant")
90 |
91 |
92 | def pad2d(x, max_len):
93 | return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant")
94 |
95 |
96 | def collate_tts(batch):
97 |
98 | ilens = torch.from_numpy(np.array([x[0].shape[0] for x in batch])).long()
99 | olens = torch.from_numpy(np.array([y[1].shape[0] for y in batch])).long()
100 | ids = [x[2] for x in batch]
101 |
102 | # perform padding and conversion to tensor
103 | inputs = pad_list([torch.from_numpy(x[0]).long() for x in batch], 0)
104 | mels = pad_list([torch.from_numpy(y[1]).float() for y in batch], 0)
105 |
106 | durations = pad_list([torch.from_numpy(x[4]).long() for x in batch], 0)
107 | energys = pad_list([torch.from_numpy(y[5]).float() for y in batch], 0)
108 | pitches = pad_list([torch.from_numpy(y[6]).float() for y in batch], 0)
109 |
110 | # make labels for stop prediction
111 | labels = mels.new_zeros(mels.size(0), mels.size(1))
112 | for i, l in enumerate(olens):
113 | labels[i, l - 1 :] = 1.0
114 |
115 | # scale spectrograms to -4 <--> 4
116 | # mels = (mels * 8.) - 4
117 |
118 | return inputs, ilens, mels, labels, olens, ids, durations, energys, pitches
119 |
120 |
121 | class BinnedLengthSampler(Sampler):
122 | def __init__(self, lengths, batch_size, bin_size):
123 | _, self.idx = torch.sort(torch.tensor(lengths).long())
124 | self.batch_size = batch_size
125 | self.bin_size = bin_size
126 | assert self.bin_size % self.batch_size == 0
127 |
128 | def __iter__(self):
129 | # Need to change to numpy since there's a bug in random.shuffle(tensor)
130 | # TODO : Post an issue on pytorch repo
131 | idx = self.idx.numpy()
132 | bins = []
133 |
134 | for i in range(len(idx) // self.bin_size):
135 | this_bin = idx[i * self.bin_size : (i + 1) * self.bin_size]
136 | random.shuffle(this_bin)
137 | bins += [this_bin]
138 |
139 | random.shuffle(bins)
140 | binned_idx = np.stack(bins).reshape(-1)
141 |
142 | if len(binned_idx) < len(idx):
143 | last_bin = idx[len(binned_idx) :]
144 | random.shuffle(last_bin)
145 | binned_idx = np.concatenate([binned_idx, last_bin])
146 |
147 | return iter(torch.tensor(binned_idx).long())
148 |
149 | def __len__(self):
150 | return len(self.idx)
151 |
--------------------------------------------------------------------------------
/dataset/ljspeech.py:
--------------------------------------------------------------------------------
1 | from utils.util import get_files
2 |
3 |
4 | def ljspeech(path, hp):
5 |
6 | csv_file = get_files(path, extension=".csv")
7 |
8 | assert len(csv_file) == 1
9 |
10 | wavs = []
11 | # texts = []
12 | # encode = []
13 |
14 | with open(csv_file[0], encoding="utf-8") as f_:
15 | # if 'phoneme_cleaners' in hp.tts_cleaner_names:
16 | # print("Cleaner : {} Language Code : {}\n".format(hp.tts_cleaner_names[0],hp.phoneme_language))
17 | # for line in f :
18 | # split = line.split('|')
19 | # text_dict[split[0]] = text2phone(split[-1].strip(),hp.phoneme_language)
20 | # else:
21 | print("Cleaner : {} \n".format(hp.tts_cleaner_names))
22 | for line in f_:
23 | sub = {}
24 | split = line.split("|")
25 | t = split[-1].strip().upper()
26 | # t = t.replace('"', '')
27 | # t = t.replace('-', ' ')
28 | # t = t.replace(';','')
29 | # t = t.replace('(', '')
30 | # t = t.replace(')', '')
31 | # t = t.replace(':', '')
32 | # t = re.sub('[^A-Za-z0-9.!?,\' ]+', '', t)
33 | if len(t) > 0:
34 | wavs.append(split[0].strip())
35 | # texts.append(t)
36 | # encode.append(text_to_sequence(t, hp.tts_cleaner_names))
37 | # with open(os.path.join(data_dir, 'train.txt'), 'w', encoding='utf-8') as f:
38 | # for w, t, e in zip(wavs, texts, encode):
39 | # f.write('{}|{}|{}'.format(w,e,t) + '\n')
40 |
41 | return wavs # , texts, encode
42 |
43 |
44 | if __name__ == "__main__":
45 | ljspeech("metadata.csv", ["english_cleaners"])
46 |
--------------------------------------------------------------------------------
/dataset/texts/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | import re
3 | from dataset.texts import cleaners
4 | from dataset.texts.symbols import (
5 | symbols,
6 | _eos,
7 | phonemes_symbols,
8 | PAD,
9 | EOS,
10 | _PHONEME_SEP,
11 | )
12 | from dataset.texts.dict_ import symbols_
13 | import nltk
14 | from g2p_en import G2p
15 |
16 | # Mappings from symbol to numeric ID and vice versa:
17 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
18 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
19 |
20 | # Regular expression matching text enclosed in curly braces:
21 | _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
22 |
23 | symbols_inv = {v: k for k, v in symbols_.items()}
24 |
25 | valid_symbols = [
26 | "AA",
27 | "AA1",
28 | "AE",
29 | "AE0",
30 | "AE1",
31 | "AH",
32 | "AH0",
33 | "AH1",
34 | "AO",
35 | "AO1",
36 | "AW",
37 | "AW0",
38 | "AW1",
39 | "AY",
40 | "AY0",
41 | "AY1",
42 | "B",
43 | "CH",
44 | "D",
45 | "DH",
46 | "EH",
47 | "EH0",
48 | "EH1",
49 | "ER",
50 | "EY",
51 | "EY0",
52 | "EY1",
53 | "F",
54 | "G",
55 | "HH",
56 | "IH",
57 | "IH0",
58 | "IH1",
59 | "IY",
60 | "IY0",
61 | "IY1",
62 | "JH",
63 | "K",
64 | "L",
65 | "M",
66 | "N",
67 | "NG",
68 | "OW",
69 | "OW0",
70 | "OW1",
71 | "OY",
72 | "OY0",
73 | "OY1",
74 | "P",
75 | "R",
76 | "S",
77 | "SH",
78 | "T",
79 | "TH",
80 | "UH",
81 | "UH0",
82 | "UH1",
83 | "UW",
84 | "UW0",
85 | "UW1",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | "pau",
92 | "sil",
93 | "spn"
94 | ]
95 |
96 |
97 | def pad_with_eos_bos(_sequence):
98 | return _sequence + [_symbol_to_id[_eos]]
99 |
100 |
101 | def text_to_sequence(text, cleaner_names, eos):
102 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
103 | The text can optionally have ARPAbet sequences enclosed in curly braces embedded
104 | in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
105 | Args:
106 | text: string to convert to a sequence
107 | cleaner_names: names of the cleaner functions to run the text through
108 | Returns:
109 | List of integers corresponding to the symbols in the text
110 | """
111 | sequence = []
112 | if eos:
113 | text = text + "~"
114 | try:
115 | sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
116 | except KeyError:
117 | print("text : ", text)
118 | exit(0)
119 |
120 | return sequence
121 |
122 |
123 | def sequence_to_text(sequence):
124 | """Converts a sequence of IDs back to a string"""
125 | result = ""
126 | for symbol_id in sequence:
127 | if symbol_id in symbols_inv:
128 | s = symbols_inv[symbol_id]
129 | # Enclose ARPAbet back in curly braces:
130 | if len(s) > 1 and s[0] == "@":
131 | s = "{%s}" % s[1:]
132 | result += s
133 | return result.replace("}{", " ")
134 |
135 |
136 | def _clean_text(text, cleaner_names):
137 | for name in cleaner_names:
138 | cleaner = getattr(cleaners, name)
139 | if not cleaner:
140 | raise Exception("Unknown cleaner: %s" % name)
141 | text = cleaner(text)
142 | return text
143 |
144 |
145 | def _symbols_to_sequence(symbols):
146 | return [symbols_[s.upper()] for s in symbols]
147 |
148 |
149 | def _arpabet_to_sequence(text):
150 | return _symbols_to_sequence(["@" + s for s in text.split()])
151 |
152 |
153 | def _should_keep_symbol(s):
154 | return s in _symbol_to_id and s != "_" and s != "~"
155 |
156 |
157 | # For phonemes
158 | _phoneme_to_id = {s: i for i, s in enumerate(valid_symbols)}
159 | _id_to_phoneme = {i: s for i, s in enumerate(valid_symbols)}
160 |
161 |
162 | def _should_keep_token(token, token_dict):
163 | return (
164 | token in token_dict
165 | and token != PAD
166 | and token != EOS
167 | and token != _phoneme_to_id[PAD]
168 | and token != _phoneme_to_id[EOS]
169 | )
170 |
171 |
172 | def phonemes_to_sequence(phonemes):
173 | string = phonemes.split() if isinstance(phonemes, str) else phonemes
174 | # string.append(EOS)
175 | sequence = list(map(convert_phoneme_CMU, string))
176 | sequence = [_phoneme_to_id[s] for s in sequence]
177 | # if _should_keep_token(s, _phoneme_to_id)]
178 | return sequence
179 |
180 |
181 | def sequence_to_phonemes(sequence, use_eos=False):
182 | string = [_id_to_phoneme[idx] for idx in sequence]
183 | # if _should_keep_token(idx, _id_to_phoneme)]
184 | string = _PHONEME_SEP.join(string)
185 | if use_eos:
186 | string = string.replace(EOS, "")
187 | return string
188 |
189 |
190 | def convert_phoneme_CMU(phoneme):
191 | REMAPPING = {
192 | 'AA0': 'AA1',
193 | 'AA2': 'AA1',
194 | 'AE2': 'AE1',
195 | 'AH2': 'AH1',
196 | 'AO0': 'AO1',
197 | 'AO2': 'AO1',
198 | 'AW2': 'AW1',
199 | 'AY2': 'AY1',
200 | 'EH2': 'EH1',
201 | 'ER0': 'EH1',
202 | 'ER1': 'EH1',
203 | 'ER2': 'EH1',
204 | 'EY2': 'EY1',
205 | 'IH2': 'IH1',
206 | 'IY2': 'IY1',
207 | 'OW2': 'OW1',
208 | 'OY2': 'OY1',
209 | 'UH2': 'UH1',
210 | 'UW2': 'UW1',
211 | }
212 | return REMAPPING.get(phoneme, phoneme)
213 |
214 |
215 | def text_to_phonemes(text, custom_words={}):
216 | """
217 | Convert text into ARPAbet.
218 | For known words use CMUDict; for the rest try 'espeak' (to IPA) followed by 'listener'.
219 | :param text: str, input text.
220 | :param custom_words:
221 | dict {str: list of str}, optional
222 | Pronounciations (a list of ARPAbet phonemes) you'd like to override.
223 | Example: {'word': ['W', 'EU1', 'R', 'D']}
224 | :return: list of str, phonemes
225 | """
226 | g2p = G2p()
227 |
228 | """def convert_phoneme_CMU(phoneme):
229 | REMAPPING = {
230 | 'AA0': 'AA1',
231 | 'AA2': 'AA1',
232 | 'AE2': 'AE1',
233 | 'AH2': 'AH1',
234 | 'AO0': 'AO1',
235 | 'AO2': 'AO1',
236 | 'AW2': 'AW1',
237 | 'AY2': 'AY1',
238 | 'EH2': 'EH1',
239 | 'ER0': 'EH1',
240 | 'ER1': 'EH1',
241 | 'ER2': 'EH1',
242 | 'EY2': 'EY1',
243 | 'IH2': 'IH1',
244 | 'IY2': 'IY1',
245 | 'OW2': 'OW1',
246 | 'OY2': 'OY1',
247 | 'UH2': 'UH1',
248 | 'UW2': 'UW1',
249 | }
250 | return REMAPPING.get(phoneme, phoneme)
251 | """
252 |
253 | def convert_phoneme_listener(phoneme):
254 | VOWELS = ['A', 'E', 'I', 'O', 'U']
255 | if phoneme[0] in VOWELS:
256 | phoneme += '1'
257 | return phoneme # convert_phoneme_CMU(phoneme)
258 |
259 | try:
260 | known_words = nltk.corpus.cmudict.dict()
261 | except LookupError:
262 | nltk.download("cmudict")
263 | known_words = nltk.corpus.cmudict.dict()
264 |
265 | for word, phonemes in custom_words.items():
266 | known_words[word.lower()] = [phonemes]
267 |
268 | words = nltk.tokenize.WordPunctTokenizer().tokenize(text.lower())
269 |
270 | phonemes = []
271 | PUNCTUATION = "!?.,-:;\"'()"
272 | for word in words:
273 | if all(c in PUNCTUATION for c in word):
274 | pronounciation = ["pau"]
275 | elif word in known_words:
276 | pronounciation = known_words[word][0]
277 | pronounciation = list(
278 | pronounciation
279 | ) # map(convert_phoneme_CMU, pronounciation))
280 | else:
281 | pronounciation = g2p(word)
282 | pronounciation = list(
283 | pronounciation
284 | ) # (map(convert_phoneme_CMU, pronounciation))
285 |
286 | phonemes += pronounciation
287 |
288 | return phonemes
289 |
--------------------------------------------------------------------------------
/dataset/texts/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | """
14 |
15 |
16 | # Regular expression matching whitespace:
17 | import re
18 | from unidecode import unidecode
19 | from .numbers import normalize_numbers
20 |
21 | _whitespace_re = re.compile(r"\s+")
22 | punctuations = """+-!()[]{};:'"\<>/?@#^&*_~"""
23 |
24 | # List of (regular expression, replacement) pairs for abbreviations:
25 | _abbreviations = [
26 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
27 | for x in [
28 | ("mrs", "misess"),
29 | ("mr", "mister"),
30 | ("dr", "doctor"),
31 | ("st", "saint"),
32 | ("co", "company"),
33 | ("jr", "junior"),
34 | ("maj", "major"),
35 | ("gen", "general"),
36 | ("drs", "doctors"),
37 | ("rev", "reverend"),
38 | ("lt", "lieutenant"),
39 | ("hon", "honorable"),
40 | ("sgt", "sergeant"),
41 | ("capt", "captain"),
42 | ("esq", "esquire"),
43 | ("ltd", "limited"),
44 | ("col", "colonel"),
45 | ("ft", "fort"),
46 | ]
47 | ]
48 |
49 |
50 | def expand_abbreviations(text):
51 | for regex, replacement in _abbreviations:
52 | text = re.sub(regex, replacement, text)
53 | return text
54 |
55 |
56 | def expand_numbers(text):
57 | return normalize_numbers(text)
58 |
59 |
60 | def lowercase(text):
61 | return text.lower()
62 |
63 |
64 | def collapse_whitespace(text):
65 | return re.sub(_whitespace_re, " ", text)
66 |
67 |
68 | def convert_to_ascii(text):
69 | return unidecode(text)
70 |
71 |
72 | def basic_cleaners(text):
73 | """Basic pipeline that lowercases and collapses whitespace without transliteration."""
74 | text = lowercase(text)
75 | text = collapse_whitespace(text)
76 | return text
77 |
78 |
79 | def transliteration_cleaners(text):
80 | """Pipeline for non-English text that transliterates to ASCII."""
81 | text = convert_to_ascii(text)
82 | text = lowercase(text)
83 | text = collapse_whitespace(text)
84 | return text
85 |
86 |
87 | def english_cleaners(text):
88 | """Pipeline for English text, including number and abbreviation expansion."""
89 | text = convert_to_ascii(text)
90 | text = lowercase(text)
91 | text = expand_numbers(text)
92 | text = expand_abbreviations(text)
93 | text = collapse_whitespace(text)
94 | return text
95 |
96 |
97 | def punctuation_removers(text):
98 | no_punct = ""
99 | for char in text:
100 | if char not in punctuations:
101 | no_punct = no_punct + char
102 | return no_punct
103 |
--------------------------------------------------------------------------------
/dataset/texts/cmudict.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 |
6 | valid_symbols = [
7 | "AA",
8 | "AA0",
9 | "AA1",
10 | "AA2",
11 | "AE",
12 | "AE0",
13 | "AE1",
14 | "AE2",
15 | "AH",
16 | "AH0",
17 | "AH1",
18 | "AH2",
19 | "AO",
20 | "AO0",
21 | "AO1",
22 | "AO2",
23 | "AW",
24 | "AW0",
25 | "AW1",
26 | "AW2",
27 | "AY",
28 | "AY0",
29 | "AY1",
30 | "AY2",
31 | "B",
32 | "CH",
33 | "D",
34 | "DH",
35 | "EH",
36 | "EH0",
37 | "EH1",
38 | "EH2",
39 | "ER",
40 | "ER0",
41 | "ER1",
42 | "ER2",
43 | "EY",
44 | "EY0",
45 | "EY1",
46 | "EY2",
47 | "F",
48 | "G",
49 | "HH",
50 | "IH",
51 | "IH0",
52 | "IH1",
53 | "IH2",
54 | "IY",
55 | "IY0",
56 | "IY1",
57 | "IY2",
58 | "JH",
59 | "K",
60 | "L",
61 | "M",
62 | "N",
63 | "NG",
64 | "OW",
65 | "OW0",
66 | "OW1",
67 | "OW2",
68 | "OY",
69 | "OY0",
70 | "OY1",
71 | "OY2",
72 | "P",
73 | "R",
74 | "S",
75 | "SH",
76 | "T",
77 | "TH",
78 | "UH",
79 | "UH0",
80 | "UH1",
81 | "UH2",
82 | "UW",
83 | "UW0",
84 | "UW1",
85 | "UW2",
86 | "V",
87 | "W",
88 | "Y",
89 | "Z",
90 | "ZH",
91 | ]
92 |
93 | _valid_symbol_set = set(valid_symbols)
94 |
95 |
96 | class CMUDict:
97 | """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
98 |
99 | def __init__(self, file_or_path, keep_ambiguous=True):
100 | if isinstance(file_or_path, str):
101 | with open(file_or_path, encoding="latin-1") as f:
102 | entries = _parse_cmudict(f)
103 | else:
104 | entries = _parse_cmudict(file_or_path)
105 | if not keep_ambiguous:
106 | entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
107 | self._entries = entries
108 |
109 | def __len__(self):
110 | return len(self._entries)
111 |
112 | def lookup(self, word):
113 | """Returns list of ARPAbet pronunciations of the given word."""
114 | return self._entries.get(word.upper())
115 |
116 |
117 | _alt_re = re.compile(r"\([0-9]+\)")
118 |
119 |
120 | def _parse_cmudict(file):
121 | cmudict = {}
122 | for line in file:
123 | if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
124 | parts = line.split(" ")
125 | word = re.sub(_alt_re, "", parts[0])
126 | pronunciation = _get_pronunciation(parts[1])
127 | if pronunciation:
128 | if word in cmudict:
129 | cmudict[word].append(pronunciation)
130 | else:
131 | cmudict[word] = [pronunciation]
132 | return cmudict
133 |
134 |
135 | def _get_pronunciation(s):
136 | parts = s.strip().split(" ")
137 | for part in parts:
138 | if part not in _valid_symbol_set:
139 | return None
140 | return " ".join(parts)
141 |
--------------------------------------------------------------------------------
/dataset/texts/dict_.py:
--------------------------------------------------------------------------------
1 | symbols_ = {
2 | "": 1,
3 | "!": 2,
4 | "'": 3,
5 | ",": 4,
6 | ".": 5,
7 | " ": 6,
8 | "?": 7,
9 | "A": 8,
10 | "B": 9,
11 | "C": 10,
12 | "D": 11,
13 | "E": 12,
14 | "F": 13,
15 | "G": 14,
16 | "H": 15,
17 | "I": 16,
18 | "J": 17,
19 | "K": 18,
20 | "L": 19,
21 | "M": 20,
22 | "N": 21,
23 | "O": 22,
24 | "P": 23,
25 | "Q": 24,
26 | "R": 25,
27 | "S": 26,
28 | "T": 27,
29 | "U": 28,
30 | "V": 29,
31 | "W": 30,
32 | "X": 31,
33 | "Y": 32,
34 | "Z": 33,
35 | "~": 34,
36 | }
37 |
--------------------------------------------------------------------------------
/dataset/texts/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import inflect
4 | import re
5 |
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13 | _number_re = re.compile(r"[0-9]+")
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(",", "")
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace(".", " point ")
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split(".")
27 | if len(parts) > 2:
28 | return match + " dollars" # Unexpected format
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = "dollar" if dollars == 1 else "dollars"
33 | cent_unit = "cent" if cents == 1 else "cents"
34 | return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
35 | elif dollars:
36 | dollar_unit = "dollar" if dollars == 1 else "dollars"
37 | return "%s %s" % (dollars, dollar_unit)
38 | elif cents:
39 | cent_unit = "cent" if cents == 1 else "cents"
40 | return "%s %s" % (cents, cent_unit)
41 | else:
42 | return "zero dollars"
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return "two thousand"
54 | elif num > 2000 and num < 2010:
55 | return "two thousand " + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + " hundred"
58 | else:
59 | return _inflect.number_to_words(
60 | num, andword="", zero="oh", group=2
61 | ).replace(", ", " ")
62 | else:
63 | return _inflect.number_to_words(num, andword="")
64 |
65 |
66 | def normalize_numbers(text):
67 | text = re.sub(_comma_number_re, _remove_commas, text)
68 | text = re.sub(_pounds_re, r"\1 pounds", text)
69 | text = re.sub(_dollars_re, _expand_dollars, text)
70 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
71 | text = re.sub(_ordinal_re, _expand_ordinal, text)
72 | text = re.sub(_number_re, _expand_number, text)
73 | return text
74 |
--------------------------------------------------------------------------------
/dataset/texts/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | """
4 | Defines the set of symbols used in text input to the model.
5 |
6 | The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. """
7 |
8 | from dataset.texts import cmudict
9 |
10 | _pad = "_"
11 | _eos = "~"
12 | _bos = "^"
13 | _punctuation = "!'(),.:;? "
14 | _special = "-"
15 | _letters = "abcdefghijklmnopqrstuvwxyz"
16 |
17 | # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
18 | # _arpabet = ['@' + s for s in cmudict.valid_symbols]
19 |
20 | # Export all symbols:
21 | symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + [_eos]
22 |
23 | # For Phonemes
24 |
25 | PAD = "#"
26 | EOS = "~"
27 | PHONEME_CODES = "AA1 AE0 AE1 AH0 AH1 AO0 AO1 AW0 AW1 AY0 AY1 B CH D DH EH0 EH1 EU0 EU1 EY0 EY1 F G HH IH0 IH1 IY0 IY1 JH K L M N NG OW0 OW1 OY0 OY1 P R S SH T TH UH0 UH1 UW0 UW1 V W Y Z ZH pau".split()
28 | _PHONEME_SEP = " "
29 |
30 | phonemes_symbols = [PAD, EOS] + PHONEME_CODES # PAD should be first to have zero id
31 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | import configargparse
2 | from dataset import dataloader as loader
3 | from fastspeech import FeedForwardTransformer
4 | import sys
5 | import torch
6 | from dataset.texts import valid_symbols
7 | import os
8 | from utils.hparams import HParam, load_hparam_str
9 | import numpy as np
10 |
11 |
12 | def evaluate(hp, validloader, model):
13 | energy_diff = list()
14 | pitch_diff = list()
15 | dur_diff = list()
16 |
17 | l1 = torch.nn.L1Loss()
18 | model.eval()
19 | for valid in validloader:
20 | x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
21 |
22 | with torch.no_grad():
23 | ilens = torch.tensor([x_[-1].shape[0]], dtype=torch.long, device=x_.device)
24 | _, after_outs, d_outs, e_outs, p_outs = model._forward(x_.cuda(), ilens.cuda(), out_length_.cuda(), dur_.cuda(), es=e_.cuda(), ps=p_.cuda(), is_inference=False) # [T, num_mel]
25 |
26 | # e_orig = model.energy_predictor.to_one_hot(e_).squeeze()
27 | # p_orig = model.pitch_predictor.to_one_hot(p_).squeeze()
28 |
29 | #print(d_outs)
30 |
31 | dur_diff.append(l1(d_outs, dur_.cuda()).item()) #.numpy()
32 | energy_diff.append(l1(e_outs, e_.cuda()).item()) #.numpy()
33 | pitch_diff.append(l1(p_outs, p_.cuda()).item()) #.numpy()
34 |
35 |
36 | '''_, target = read_wav_np( hp.data.wav_dir + f"{ids_[-1]}.wav", sample_rate=hp.audio.sample_rate)
37 | target_pitch = np.load(hp.data.data_dir + f"pitch/{ids_[-1]}.wav" )
38 | target_energy = np.load(hp.data.data_dir + f"energy/{ids_[-1]}.wav" )
39 | '''
40 | model.train()
41 | return np.mean(pitch_diff), np.mean(energy_diff), np.mean(dur_diff)
42 |
43 |
44 | def get_parser():
45 | """Get parser of training arguments."""
46 | parser = configargparse.ArgumentParser(
47 | description="Train a new text-to-speech (TTS) model on one CPU, one or multiple GPUs",
48 | config_file_parser_class=configargparse.YAMLConfigFileParser,
49 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
50 | )
51 | parser.add_argument(
52 | "-c", "--config", type=str, required=True, help="yaml file for configuration"
53 | )
54 | parser.add_argument(
55 | "-p",
56 | "--checkpoint_path",
57 | type=str,
58 | default=None,
59 | help="path of checkpoint pt to evaluate",
60 | )
61 |
62 | parser.add_argument("--outdir", type=str, required=True, help="Output directory")
63 |
64 | return parser
65 |
66 | def main(cmd_args):
67 | """Run training."""
68 | parser = get_parser()
69 | args, _ = parser.parse_known_args(cmd_args)
70 | args = parser.parse_args(cmd_args)
71 |
72 | if os.path.exists(args.checkpoint_path):
73 | checkpoint = torch.load(args.checkpoint_path)
74 | else:
75 | print("Checkpoint not exixts")
76 | return None
77 |
78 | if args.config is not None:
79 | hp = HParam(args.config)
80 | else:
81 | hp = load_hparam_str(checkpoint["hp_str"])
82 |
83 | validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)
84 | print("Checkpoint : ", args.checkpoint_path)
85 |
86 |
87 |
88 | idim = len(valid_symbols)
89 | odim = hp.audio.num_mels
90 | model = FeedForwardTransformer(
91 | idim, odim, hp
92 | )
93 | # os.makedirs(args.out, exist_ok=True)
94 | checkpoint = torch.load(args.checkpoint_path)
95 | model.load_state_dict(checkpoint["model"])
96 |
97 | evaluate(hp, validloader, model)
98 |
99 |
100 | if __name__ == "__main__":
101 | main(sys.argv[1:])
102 |
--------------------------------------------------------------------------------
/export_torchscript.py:
--------------------------------------------------------------------------------
1 | from utils.hparams import HParam
2 | from dataset.texts import valid_symbols
3 | import utils.fastspeech2_script as fs2
4 | import configargparse
5 | import torch
6 | import sys
7 |
8 |
9 | def get_parser():
10 |
11 | parser = configargparse.ArgumentParser(
12 | description="Train a new text-to-speech (TTS) model on one CPU, one or multiple GPUs",
13 | config_file_parser_class=configargparse.YAMLConfigFileParser,
14 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
15 | )
16 |
17 | parser.add_argument(
18 | "-c", "--config", type=str, required=True, help="yaml file for configuration"
19 | )
20 | parser.add_argument(
21 | "-n",
22 | "--name",
23 | type=str,
24 | required=True,
25 | help="name of the model for logging, saving checkpoint",
26 | )
27 | parser.add_argument("--outdir", type=str, required=True, help="Output directory")
28 | parser.add_argument(
29 | "-t", "--trace", action="store_true", help="For JIT Trace Module"
30 | )
31 |
32 | return parser
33 |
34 |
35 | def main(cmd_args):
36 |
37 | parser = get_parser()
38 | args, _ = parser.parse_known_args(cmd_args)
39 |
40 | args = parser.parse_args(cmd_args)
41 |
42 | hp = HParam(args.config)
43 |
44 | idim = len(valid_symbols)
45 | odim = hp.audio.num_mels
46 | model = fs2.FeedForwardTransformer(idim, odim, hp)
47 | my_script_module = torch.jit.script(model)
48 | print("Scripting")
49 | my_script_module.save("{}/{}.pt".format(args.outdir, args.name))
50 | print("Script done")
51 | if args.trace:
52 | print("Tracing")
53 | model.eval()
54 | with torch.no_grad():
55 | my_trace_module = torch.jit.trace(
56 | model, torch.ones(50).to(dtype=torch.int64)
57 | )
58 | my_trace_module.save("{}/trace_{}.pt".format(args.outdir, args.name))
59 | print("Trace Done")
60 |
61 |
62 | if __name__ == "__main__":
63 | main(sys.argv[1:])
64 |
--------------------------------------------------------------------------------
/fastspeech.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Tomoki Hayashi
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """FastSpeech related loss."""
8 |
9 | import logging
10 |
11 | import torch
12 | from core.duration_modeling.duration_predictor import DurationPredictor
13 | from core.duration_modeling.duration_predictor import DurationPredictorLoss
14 | from core.variance_predictor import EnergyPredictor, EnergyPredictorLoss
15 | from core.variance_predictor import PitchPredictor, PitchPredictorLoss
16 | from core.duration_modeling.length_regulator import LengthRegulator
17 | from utils.util import make_non_pad_mask
18 | from utils.util import make_pad_mask
19 | from core.embedding import PositionalEncoding
20 | from core.embedding import ScaledPositionalEncoding
21 | from core.encoder import Encoder
22 | from core.modules import initialize
23 | from core.modules import Postnet
24 | from typeguard import check_argument_types
25 | from typing import Dict, Tuple, Sequence
26 |
27 |
28 | class FeedForwardTransformer(torch.nn.Module):
29 | """Feed Forward Transformer for TTS a.k.a. FastSpeech.
30 | This is a module of FastSpeech, feed-forward Transformer with duration predictor described in
31 | `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive
32 | processing during inference, resulting in fast decoding compared with auto-regressive Transformer.
33 | .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
34 | https://arxiv.org/pdf/1905.09263.pdf
35 | """
36 |
37 | def __init__(self, idim: int, odim: int, hp: Dict):
38 | """Initialize feed-forward Transformer module.
39 | Args:
40 | idim (int): Dimension of the inputs.
41 | odim (int): Dimension of the outputs.
42 | """
43 | # initialize base classes
44 | assert check_argument_types()
45 | torch.nn.Module.__init__(self)
46 |
47 | # fill missing arguments
48 |
49 | # store hyperparameters
50 | self.idim = idim
51 | self.odim = odim
52 |
53 | self.use_scaled_pos_enc = hp.model.use_scaled_pos_enc
54 | self.use_masking = hp.model.use_masking
55 |
56 | # use idx 0 as padding idx
57 | padding_idx = 0
58 |
59 | # get positional encoding class
60 | pos_enc_class = (
61 | ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
62 | )
63 |
64 | # define encoder
65 | encoder_input_layer = torch.nn.Embedding(
66 | num_embeddings=idim, embedding_dim=hp.model.adim, padding_idx=padding_idx
67 | )
68 | self.encoder = Encoder(
69 | idim=idim,
70 | attention_dim=hp.model.adim,
71 | attention_heads=hp.model.aheads,
72 | linear_units=hp.model.eunits,
73 | num_blocks=hp.model.elayers,
74 | input_layer=encoder_input_layer,
75 | dropout_rate=0.2,
76 | positional_dropout_rate=0.2,
77 | attention_dropout_rate=0.2,
78 | pos_enc_class=pos_enc_class,
79 | normalize_before=hp.model.encoder_normalize_before,
80 | concat_after=hp.model.encoder_concat_after,
81 | positionwise_layer_type=hp.model.positionwise_layer_type,
82 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
83 | )
84 |
85 | self.duration_predictor = DurationPredictor(
86 | idim=hp.model.adim,
87 | n_layers=hp.model.duration_predictor_layers,
88 | n_chans=hp.model.duration_predictor_chans,
89 | kernel_size=hp.model.duration_predictor_kernel_size,
90 | dropout_rate=hp.model.duration_predictor_dropout_rate,
91 | )
92 |
93 | self.energy_predictor = EnergyPredictor(
94 | idim=hp.model.adim,
95 | n_layers=hp.model.duration_predictor_layers,
96 | n_chans=hp.model.duration_predictor_chans,
97 | kernel_size=hp.model.duration_predictor_kernel_size,
98 | dropout_rate=hp.model.duration_predictor_dropout_rate,
99 | min=hp.data.e_min,
100 | max=hp.data.e_max,
101 | )
102 | self.energy_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)
103 |
104 | self.pitch_predictor = PitchPredictor(
105 | idim=hp.model.adim,
106 | n_layers=hp.model.duration_predictor_layers,
107 | n_chans=hp.model.duration_predictor_chans,
108 | kernel_size=hp.model.duration_predictor_kernel_size,
109 | dropout_rate=hp.model.duration_predictor_dropout_rate,
110 | min=hp.data.p_min,
111 | max=hp.data.p_max,
112 | )
113 | self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)
114 |
115 | # define length regulator
116 | self.length_regulator = LengthRegulator()
117 |
118 | # define decoder
119 | # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
120 | self.decoder = Encoder(
121 | idim=hp.model.adim,
122 | attention_dim=hp.model.ddim,
123 | attention_heads=hp.model.aheads,
124 | linear_units=hp.model.dunits,
125 | num_blocks=hp.model.dlayers,
126 | input_layer="linear",
127 | dropout_rate=0.2,
128 | positional_dropout_rate=0.2,
129 | attention_dropout_rate=0.2,
130 | pos_enc_class=pos_enc_class,
131 | normalize_before=hp.model.decoder_normalize_before,
132 | concat_after=hp.model.decoder_concat_after,
133 | positionwise_layer_type=hp.model.positionwise_layer_type,
134 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
135 | )
136 |
137 | # define postnet
138 | self.postnet = (
139 | None
140 | if hp.model.postnet_layers == 0
141 | else Postnet(
142 | idim=idim,
143 | odim=odim,
144 | n_layers=hp.model.postnet_layers,
145 | n_chans=hp.model.postnet_chans,
146 | n_filts=hp.model.postnet_filts,
147 | use_batch_norm=hp.model.use_batch_norm,
148 | dropout_rate=hp.model.postnet_dropout_rate,
149 | )
150 | )
151 |
152 | # define final projection
153 | self.feat_out = torch.nn.Linear(hp.model.ddim, odim * hp.model.reduction_factor)
154 |
155 | # initialize parameters
156 | self._reset_parameters(
157 | init_type=hp.model.transformer_init,
158 | init_enc_alpha=hp.model.initial_encoder_alpha,
159 | init_dec_alpha=hp.model.initial_decoder_alpha,
160 | )
161 |
162 | # define criterions
163 | self.duration_criterion = DurationPredictorLoss()
164 | self.energy_criterion = EnergyPredictorLoss()
165 | self.pitch_criterion = PitchPredictorLoss()
166 | self.criterion = torch.nn.L1Loss(reduction="mean")
167 | self.use_weighted_masking = hp.model.use_weighted_masking
168 |
169 | def _forward(
170 | self,
171 | xs: torch.Tensor,
172 | ilens: torch.Tensor,
173 | olens: torch.Tensor = None,
174 | ds: torch.Tensor = None,
175 | es: torch.Tensor = None,
176 | ps: torch.Tensor = None,
177 | is_inference: bool = False,
178 | ) -> Sequence[torch.Tensor]:
179 | # forward encoder
180 | x_masks = self._source_mask(
181 | ilens
182 | ) # (B, Tmax, Tmax) -> torch.Size([32, 121, 121])
183 |
184 | hs, _ = self.encoder(
185 | xs, x_masks
186 | ) # (B, Tmax, adim) -> torch.Size([32, 121, 256])
187 | # print("ys :", ys.shape)
188 |
189 | # forward duration predictor and length regulator
190 | d_masks = make_pad_mask(ilens).to(xs.device)
191 |
192 | if is_inference:
193 | d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax)
194 | hs = self.length_regulator(hs, d_outs, ilens) # (B, Lmax, adim)
195 | one_hot_energy = self.energy_predictor.inference(hs) # (B, Lmax, adim)
196 | one_hot_pitch = self.pitch_predictor.inference(hs) # (B, Lmax, adim)
197 | else:
198 | with torch.no_grad():
199 | # ds = self.duration_calculator(xs, ilens, ys, olens) # (B, Tmax)
200 | one_hot_energy = self.energy_predictor.to_one_hot(
201 | es
202 | ) # (B, Lmax, adim) torch.Size([32, 868, 256])
203 | # print("one_hot_energy:", one_hot_energy.shape)
204 | one_hot_pitch = self.pitch_predictor.to_one_hot(
205 | ps
206 | ) # (B, Lmax, adim) torch.Size([32, 868, 256])
207 | # print("one_hot_pitch:", one_hot_pitch.shape)
208 | mel_masks = make_pad_mask(olens).to(xs.device)
209 | # print("Before Hs:", hs.shape) # torch.Size([32, 121, 256])
210 | d_outs = self.duration_predictor(hs, d_masks) # (B, Tmax)
211 | # print("d_outs:", d_outs.shape) # torch.Size([32, 121])
212 | hs = self.length_regulator(hs, ds, ilens) # (B, Lmax, adim)
213 | # print("After Hs:",hs.shape) #torch.Size([32, 868, 256])
214 | e_outs = self.energy_predictor(hs, mel_masks)
215 | # print("e_outs:", e_outs.shape) #torch.Size([32, 868])
216 | p_outs = self.pitch_predictor(hs, mel_masks)
217 | # print("p_outs:", p_outs.shape) #torch.Size([32, 868])
218 | hs = hs + self.pitch_embed(one_hot_pitch) # (B, Lmax, adim)
219 | hs = hs + self.energy_embed(one_hot_energy) # (B, Lmax, adim)
220 | # forward decoder
221 | if olens is not None:
222 | h_masks = self._source_mask(olens)
223 | else:
224 | h_masks = None
225 |
226 | zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim)
227 |
228 | before_outs = self.feat_out(zs).view(
229 | zs.size(0), -1, self.odim
230 | ) # (B, Lmax, odim)
231 |
232 | # postnet -> (B, Lmax//r * r, odim)
233 | if self.postnet is None:
234 | after_outs = before_outs
235 | else:
236 | after_outs = before_outs + self.postnet(
237 | before_outs.transpose(1, 2)
238 | ).transpose(1, 2)
239 |
240 | if is_inference:
241 | return before_outs, after_outs, d_outs, one_hot_energy, one_hot_pitch
242 | else:
243 | return before_outs, after_outs, d_outs, e_outs, p_outs
244 |
245 | def forward(
246 | self,
247 | xs: torch.Tensor,
248 | ilens: torch.Tensor,
249 | ys: torch.Tensor,
250 | olens: torch.Tensor,
251 | ds: torch.Tensor,
252 | es: torch.Tensor,
253 | ps: torch.Tensor,
254 | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
255 | """Calculate forward propagation.
256 | Args:
257 | xs (Tensor): Batch of padded character ids (B, Tmax).
258 | ilens (LongTensor): Batch of lengths of each input batch (B,).
259 | ys (Tensor): Batch of padded target features (B, Lmax, odim).
260 | olens (LongTensor): Batch of the lengths of each target (B,).
261 | spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
262 | Returns:
263 | Tensor: Loss value.
264 | """
265 | # remove unnecessary padded part (for multi-gpus)
266 | xs = xs[:, : max(ilens)] # torch.Size([32, 121]) -> [B, Tmax]
267 | ys = ys[:, : max(olens)] # torch.Size([32, 868, 80]) -> [B, Lmax, odim]
268 |
269 | # forward propagation
270 | before_outs, after_outs, d_outs, e_outs, p_outs = self._forward(
271 | xs, ilens, olens, ds, es, ps, is_inference=False
272 | )
273 |
274 | # modifiy mod part of groundtruth
275 | # if hp.model.reduction_factor > 1:
276 | # olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
277 | # max_olen = max(olens)
278 | # ys = ys[:, :max_olen]
279 |
280 | # apply mask to remove padded part
281 | if self.use_masking:
282 | in_masks = make_non_pad_mask(ilens).to(xs.device)
283 | d_outs = d_outs.masked_select(in_masks)
284 | ds = ds.masked_select(in_masks)
285 | out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
286 | mel_masks = make_non_pad_mask(olens).to(ys.device)
287 | before_outs = before_outs.masked_select(out_masks)
288 | es = es.masked_select(mel_masks) # Write size
289 | ps = ps.masked_select(mel_masks) # Write size
290 | e_outs = e_outs.masked_select(mel_masks) # Write size
291 | p_outs = p_outs.masked_select(mel_masks) # Write size
292 | after_outs = (
293 | after_outs.masked_select(out_masks) if after_outs is not None else None
294 | )
295 | ys = ys.masked_select(out_masks)
296 |
297 | # calculate loss
298 | before_loss = self.criterion(before_outs, ys)
299 | after_loss = 0
300 | if after_outs is not None:
301 | after_loss = self.criterion(after_outs, ys)
302 | l1_loss = before_loss + after_loss
303 | duration_loss = self.duration_criterion(d_outs, ds)
304 | energy_loss = self.energy_criterion(e_outs, es)
305 | pitch_loss = self.pitch_criterion(p_outs, ps)
306 |
307 | # make weighted mask and apply it
308 | if self.use_weighted_masking:
309 | out_masks = make_non_pad_mask(olens).unsqueeze(-1).to(ys.device)
310 | out_weights = out_masks.float() / out_masks.sum(dim=1, keepdim=True).float()
311 | out_weights /= ys.size(0) * ys.size(2)
312 | duration_masks = make_non_pad_mask(ilens).to(ys.device)
313 | duration_weights = (
314 | duration_masks.float() / duration_masks.sum(dim=1, keepdim=True).float()
315 | )
316 | duration_weights /= ds.size(0)
317 |
318 | # apply weight
319 | l1_loss = l1_loss.mul(out_weights).masked_select(out_masks).sum()
320 | duration_loss = (
321 | duration_loss.mul(duration_weights).masked_select(duration_masks).sum()
322 | )
323 |
324 | loss = l1_loss + duration_loss + energy_loss + pitch_loss
325 | report_keys = [
326 | {"l1_loss": l1_loss.item()},
327 | {"before_loss": before_loss.item()},
328 | {"after_loss": after_loss.item()},
329 | {"duration_loss": duration_loss.item()},
330 | {"energy_loss": energy_loss.item()},
331 | {"pitch_loss": pitch_loss.item()},
332 | {"loss": loss.item()},
333 | ]
334 |
335 | # self.reporter.report(report_keys)
336 |
337 | return loss, report_keys
338 |
339 | def inference(self, x: torch.Tensor) -> torch.Tensor:
340 | """Generate the sequence of features given the sequences of characters.
341 | Args:
342 | x (Tensor): Input sequence of characters (T,).
343 | inference_args (Namespace): Dummy for compatibility.
344 | spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
345 | Returns:
346 | Tensor: Output sequence of features (1, L, odim).
347 | None: Dummy for compatibility.
348 | None: Dummy for compatibility.
349 | """
350 | # setup batch axis
351 | ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
352 | xs = x.unsqueeze(0)
353 |
354 | # inference
355 | _, outs, _, _, _ = self._forward(xs, ilens, is_inference=True) # (L, odim)
356 |
357 | return outs[0]
358 |
359 | def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
360 | """Make masks for self-attention.
361 | Examples:
362 | >>> ilens = [5, 3]
363 | >>> self._source_mask(ilens)
364 | tensor([[[1, 1, 1, 1, 1],
365 | [1, 1, 1, 1, 1],
366 | [1, 1, 1, 1, 1],
367 | [1, 1, 1, 1, 1],
368 | [1, 1, 1, 1, 1]],
369 | [[1, 1, 1, 0, 0],
370 | [1, 1, 1, 0, 0],
371 | [1, 1, 1, 0, 0],
372 | [0, 0, 0, 0, 0],
373 | [0, 0, 0, 0, 0]]], dtype=torch.uint8)
374 | """
375 | x_masks = make_non_pad_mask(ilens).to(device=next(self.parameters()).device)
376 | return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)
377 |
378 | def _reset_parameters(
379 | self, init_type: str, init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0
380 | ):
381 | # initialize parameters
382 | initialize(self, init_type)
383 |
384 | # initialize alpha in scaled positional encoding
385 | if self.use_scaled_pos_enc:
386 | self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
387 | self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
388 |
--------------------------------------------------------------------------------
/filelists/valid_filelist.txt:
--------------------------------------------------------------------------------
1 | printing in the only sense with which we are at present concerned differs from most if not from all the arts and crafts represented in the exhibition|-2 1 4 14 16 22 35 56 66 78 84 87 97 106 114 117 126 140 150 158 166 171 174 179 182 187 195 203 209 218 227 234 239 248 252 255 261 268 271 274 282 286 290 296 308 329 334 341 377 384 390 399 413 431 433 439 441 444 452 457 472 481 487 490 500 507 516 525 529 532 535 542 566 571 576 584 592 596 605 610 614 617 619 630 633 647 656 662 670 679 683 690 693 696 705 709 712 718 723 727 742 745 748 757 762 770 777 781 787 796 805 811 829|3 3 10 2 6 13 21 10 12 6 3 10 9 8 3 9 14 10 8 8 5 3 5 3 5 8 8 6 9 9 7 5 9 4 3 6 7 3 3 8 4 4 6 12 21 5 7 36 7 6 9 14 18 2 6 2 3 8 5 15 9 6 3 10 7 9 9 4 3 3 7 24 5 5 8 8 4 9 5 4 3 2 11 3 14 9 6 8 9 4 7 3 3 9 4 3 6 5 4 15 3 3 9 5 8 7 4 6 9 9 6 18|P R IH1 N T IH0 NG pau IH1 N DH IY0 OW1 N L IY0 S EH1 N S W IH1 DH pau W IH1 CH W IY1 AA1 R AE1 T P R EH1 Z AH0 N T K AH0 N S ER1 N D pau D IH1 F ER0 Z pau F R AH1 M M OW1 S T IH1 F N AA1 T F R AH1 M AO1 L DH IY0 AA1 R T S AH0 N D K R AE1 F T S R EH2 P R IH0 Z EH1 N T IH0 D IH1 N DH IY0 EH2 K S AH0 B IH1 SH AH0 N|LJ001-0001.wav
2 | in being comparatively modern |-2 4 10 14 25 28 32 38 41 48 58 65 71 75 83 88 93 103 107 118 132 137 142 155 159 161|6 6 4 11 3 4 6 3 7 10 7 6 4 8 5 5 10 4 11 14 5 5 13 4 2|IH1 N B IY1 IH0 NG K AH0 M P EH1 R AH0 T IH0 V L IY0 M AA1 D ER0 N pau sil|LJ001-0002.wav
3 | for although the chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the netherlands by a similar process|-2 2 8 15 21 26 30 43 46 48 52 66 76 82 97 109 115 119 125 135 141 146 150 156 165 170 180 196 199 207 210 213 221 230 234 244 247 254 271 282 297 313 329 335 339 347 362 367 375 380 386 389 394 402 413 423 426 433 439 451 458 463 471 479 487 493 497 503 512 517 523 526 532 540 545 552 561 569 575 585 593 598 606 608 612 619 628 633 642 650 656 663 668 677 701 709 718 727 737 741 746 749 759 763 773 778 789 800 810 830|4 6 7 6 5 4 13 3 2 4 14 10 6 15 12 6 4 6 10 6 5 4 6 9 5 10 16 3 8 3 3 8 9 4 10 3 7 17 11 15 16 16 6 4 8 15 5 8 5 6 3 5 8 11 10 3 7 6 12 7 5 8 8 8 6 4 6 9 5 6 3 6 8 5 7 9 8 6 10 8 5 8 2 4 7 9 5 9 8 6 7 5 9 24 8 9 9 10 4 5 3 10 4 10 5 11 11 10 20|F AO1 R AO2 L DH OW1 pau DH AH1 CH AY0 N IY1 Z T UH1 K IH0 M P R EH1 SH AH0 N Z pau F R AH1 M W UH1 D B L AA1 K S pau IH0 N G R EY1 V D IH1 N R IH0 L IY1 F pau F ER0 S EH1 N CH ER0 IY0 Z B IH0 F AO1 R DH AH0 W UH1 D K AH2 T ER0 Z AH1 V DH AH0 N EH1 DH ER0 L AH0 N D Z pau B AY1 AH0 S IH1 M AH0 L ER0 P R AA1 S EH2 S|LJ001-0003.wav
4 | produced the block books which were the immediate predecessors of the true printed book |-2 1 4 8 15 27 32 47 50 54 60 66 80 88 92 109 116 135 147 153 159 165 171 179 183 196 201 211 221 223 233 235 245 249 255 259 263 268 278 286 297 305 312 317 324 327 331 341 348 361 368 371 377 381 387 391 397 406 426 437 440|3 3 4 7 12 5 15 3 4 6 6 14 8 4 17 7 19 12 6 6 6 6 8 4 13 5 10 10 2 10 2 10 4 6 4 4 5 10 8 11 8 7 5 7 3 4 10 7 13 7 3 6 4 6 4 6 9 20 11 3|P R AH0 D UW1 S T DH AH0 B L AA1 K B UH1 K S pau W IH1 CH W ER1 DH IY0 IH0 M IY1 D IY0 AH0 T P R EH1 D AH0 S EH2 S ER0 Z AH0 V DH AH0 T R UW1 P R IH1 N T AH0 D B UH1 K pau|LJ001-0004.wav
5 | the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing |-2 0 9 11 16 21 29 34 43 46 51 56 63 72 79 83 87 91 94 103 112 120 124 127 135 143 152 161 172 181 186 189 191 194 202 207 209 214 221 224 231 234 239 246 252 260 267 276 280 287 296 303 308 317 325 343 362 370 380 390 399 407 413 418 424 430 439 446 451 456 464 471 478 485 495 517 530 537 541 550 555 560 566 574 578 587 590 594 599 603 606 614 625 631 634 637 646 653 656 662 668 674 680 693 696|2 9 2 5 5 8 5 9 3 5 5 7 9 7 4 4 4 3 9 9 8 4 3 8 8 9 9 11 9 5 3 2 3 8 5 2 5 7 3 7 3 5 7 6 8 7 9 4 7 9 7 5 9 8 18 19 8 10 10 9 8 6 5 6 6 9 7 5 5 8 7 7 7 10 22 13 7 4 9 5 5 6 8 4 9 3 4 5 4 3 8 11 6 3 3 9 7 3 6 6 6 6 13 3|DH IY0 IH0 N V EH1 N SH AH0 N AH0 V M UW1 V AH0 B AH0 L M EH1 T AH0 L L EH1 T ER0 Z IH0 N DH AH0 M IH1 D AH0 L AH1 V DH AH0 F IH0 F T IY1 N TH S EH1 N CH ER0 IY0 pau M EY1 JH AH1 S T L IY0 B IY1 K AH0 N S IH1 D ER0 D pau AE1 Z DH IY0 IH0 N V EH1 N SH AH0 N AH0 V DH IY0 AA1 R T AH0 V P R IH1 N T IH0 NG pau|LJ001-0005.wav
6 | and it is worth mention in passing that as an example of fine typography |-2 18 22 34 40 53 56 65 72 83 91 100 105 111 116 126 132 137 145 151 162 176 185 194 210 236 242 254 261 264 277 283 285 291 295 302 311 326 332 334 337 355 358 369 376 391 398 405 415 423 434 442 446 452 461 484 487|20 4 12 6 13 3 9 7 11 8 9 5 6 5 10 6 5 8 6 11 14 9 9 16 26 6 12 7 3 13 6 2 6 4 7 9 15 6 2 3 18 3 11 7 15 7 7 10 8 11 8 4 6 9 23 3|AE1 N D pau IH1 T IH1 Z W ER1 TH M EH1 N SH AH0 N IH1 N P AE1 S IH0 NG pau DH AE1 T pau AE1 Z AE1 N IH0 G Z AE1 M P AH0 L AH1 V F AY1 N T AH0 P AA1 G R AH0 F IY0 pau|LJ001-0006.wav
7 | the earliest book printed with movable types the gutenberg or forty two line bible of about fourteen fifty five |-2 0 10 27 35 45 49 55 63 68 87 103 109 112 117 122 128 134 139 142 148 157 164 172 176 182 186 192 199 209 227 235 252 273 277 287 295 304 310 317 320 326 346 358 363 377 394 396 404 409 414 419 427 436 446 461 476 487 493 508 513 516 535 544 554 558 563 569 582 594 600 605 609 623 633 639 646 652 659 666 672 684 707 718 720|2 10 17 8 10 4 6 8 5 19 16 6 3 5 5 6 6 5 3 6 9 7 8 4 6 4 6 7 10 18 8 17 21 4 10 8 9 6 7 3 6 20 12 5 14 17 2 8 5 5 5 8 9 10 15 15 11 6 15 5 3 19 9 10 4 5 6 13 12 6 5 4 14 10 6 7 6 7 7 6 12 23 11 2|DH IY0 ER1 L IY0 AH0 S T B UH1 K P R IH1 N T IH0 D W IH1 TH M UW1 V AH0 B AH0 L T AY1 P S pau DH IY0 G UW1 T AH0 N B ER0 G pau AO1 R pau F AO1 R T IY0 T UW1 L AY1 N B AY1 B AH0 L pau AH1 V AH0 B AW1 T F AO1 R T IY1 N F IH1 F T IY0 F AY1 V pau|LJ001-0007.wav
8 | has never been surpassed |-2 0 4 13 21 29 35 42 48 57 63 75 80 91 118 134 147 150 151|2 4 9 8 8 6 7 6 9 6 12 5 11 27 16 13 3 1|HH AE1 Z N EH1 V ER0 B IH1 N S ER0 P AE1 S T pau sil|LJ001-0008.wav
9 | printing then for our purpose may be considered as the art of making books by means of movable types|-2 0 3 15 17 22 30 42 47 64 73 82 84 91 110 116 128 139 144 157 181 202 209 217 221 229 237 241 246 258 265 270 278 283 292 300 303 314 329 335 346 352 358 366 374 382 385 399 405 420 426 452 467 474 487 494 507 513 520 528 537 547 556 561 566 571 575 584 598 612 624 648|2 3 12 2 5 8 12 5 17 9 9 2 7 19 6 12 11 5 13 24 21 7 8 4 8 8 4 5 12 7 5 8 5 9 8 3 11 15 6 11 6 6 8 8 8 3 14 6 15 6 26 15 7 13 7 13 6 7 8 9 10 9 5 5 5 4 9 14 14 12 24|P R IH1 N T IH0 NG DH EH1 N F AO1 R AW1 ER0 P ER1 P AH0 S pau M EY1 B IY1 K AH0 N S IH1 D ER0 D EH1 Z DH IY0 AA1 R T AH1 V M EY1 K IH0 NG B UH1 K S pau B AY1 M IY1 N Z AH0 V M UW1 V AH0 B AH0 L T AY1 P S|LJ001-0009.wav
10 | now as all books not primarily intended as picture books consist principally of types composed to form letterpress |-2 7 41 66 83 93 115 123 128 142 147 163 165 176 190 199 202 209 214 221 227 232 236 244 256 259 266 273 281 285 290 296 302 308 319 321 324 331 338 345 351 359 375 381 397 420 427 432 438 446 453 461 469 475 477 482 488 493 497 507 513 525 530 537 545 558 566 574 581 585 589 597 614 623 631 636 643 647 658 665 672 675 683 690 696 702 711 718 730 752 757|9 34 25 17 10 22 8 5 14 5 16 2 11 14 9 3 7 5 7 6 5 4 8 12 3 7 7 8 4 5 6 6 6 11 2 3 7 7 7 6 8 16 6 16 23 7 5 6 8 7 8 8 6 2 5 6 5 4 10 6 12 5 7 8 13 8 8 7 4 4 8 17 9 8 5 7 4 11 7 7 3 8 7 6 6 9 7 12 22 5|N AW1 pau AE1 Z AO1 L B UH1 K S pau N AA1 T P R AY0 M EH1 R AH0 L IY0 IH0 N T EH1 N D IH0 D EH1 Z pau P IH1 K CH ER0 B UH1 K S pau K AH0 N S IH1 S T P R IH1 N S IH0 P L IY0 AH0 V T AY1 P S K AH0 M P OW1 Z D pau T AH0 F AO1 R M L EH1 T ER0 P R EH2 S pau|LJ001-0010.wav
11 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | """TTS Inference script."""
2 |
3 | import configargparse
4 | import logging
5 | import os
6 | import torch
7 | import sys
8 | from utils.util import set_deterministic_pytorch
9 | from fastspeech import FeedForwardTransformer
10 | from dataset.texts import phonemes_to_sequence
11 | import time
12 | from dataset.audio_processing import griffin_lim
13 | import numpy as np
14 | from utils.stft import STFT
15 | from scipy.io.wavfile import write
16 | from dataset.texts import valid_symbols
17 | from utils.hparams import HParam, load_hparam_str
18 | from dataset.texts.cleaners import english_cleaners, punctuation_removers
19 | import matplotlib.pyplot as plt
20 | from g2p_en import G2p
21 |
22 |
23 | def synthesis(args, text, hp):
24 | """Decode with E2E-TTS model."""
25 | set_deterministic_pytorch(args)
26 | # read training config
27 | idim = hp.symbol_len
28 | odim = hp.num_mels
29 | model = FeedForwardTransformer(idim, odim, hp)
30 | print(model)
31 |
32 | if os.path.exists(args.path):
33 | print("\nSynthesis Session...\n")
34 | model.load_state_dict(torch.load(args.path), strict=False)
35 | else:
36 | print("Checkpoint not exixts")
37 | return None
38 |
39 | model.eval()
40 |
41 | # set torch device
42 | device = torch.device("cuda" if args.ngpu > 0 else "cpu")
43 | model = model.to(device)
44 |
45 | input = np.asarray(phonemes_to_sequence(text.split()))
46 | text = torch.LongTensor(input)
47 | text = text.cuda()
48 | # [num_char]
49 |
50 | with torch.no_grad():
51 | # decode and write
52 | idx = input[:5]
53 | start_time = time.time()
54 | print("text :", text.size())
55 | outs, probs, att_ws = model.inference(text, hp)
56 | print("Out size : ", outs.size())
57 |
58 | logging.info(
59 | "inference speed = %s msec / frame."
60 | % ((time.time() - start_time) / (int(outs.size(0)) * 1000))
61 | )
62 | if outs.size(0) == text.size(0) * args.maxlenratio:
63 | logging.warning("output length reaches maximum length .")
64 |
65 | print("mels", outs.size())
66 | mel = outs.cpu().numpy() # [T_out, num_mel]
67 | print("numpy ", mel.shape)
68 |
69 | return mel
70 |
71 |
72 | ### for direct text/para input ###
73 |
74 |
75 | g2p = G2p()
76 |
77 |
78 | def plot_mel(mels):
79 | melspec = mels.reshape(1, 80, -1)
80 | plt.imshow(melspec.detach().cpu()[0], aspect="auto", origin="lower")
81 | plt.savefig("mel.png")
82 |
83 |
84 | def preprocess(text):
85 |
86 | # input - line of text
87 | # output - list of phonemes
88 | str1 = " "
89 | clean_content = english_cleaners(text)
90 | clean_content = punctuation_removers(clean_content)
91 | phonemes = g2p(clean_content)
92 |
93 | phonemes = ["" if x == " " else x for x in phonemes]
94 | phonemes = ["pau" if x == "," else x for x in phonemes]
95 | phonemes = ["pau" if x == "." else x for x in phonemes]
96 | phonemes = str1.join(phonemes)
97 |
98 | return phonemes
99 |
100 |
101 | def process_paragraph(para):
102 | # input - paragraph with lines seperated by "."
103 | # output - list with each item as lines of paragraph seperated by suitable padding
104 | text = []
105 | for lines in para.split("."):
106 | text.append(lines)
107 |
108 | return text
109 |
110 |
111 | def synth(text, model, hp):
112 | """Decode with E2E-TTS model."""
113 |
114 | print("TTS synthesis")
115 |
116 | model.eval()
117 | # set torch device
118 | device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")
119 | model = model.to(device)
120 |
121 | input = np.asarray(phonemes_to_sequence(text))
122 |
123 | text = torch.LongTensor(input)
124 | text = text.to(device)
125 |
126 | with torch.no_grad():
127 | print("predicting")
128 | outs = model.inference(text) # model(text) for jit script
129 | mel = outs
130 | return mel
131 |
132 |
133 | def main(args):
134 | """Run deocding."""
135 | para_mel = []
136 | parser = get_parser()
137 | args = parser.parse_args(args)
138 |
139 | logging.info("python path = " + os.environ.get("PYTHONPATH", "(None)"))
140 |
141 | print("Text : ", args.text)
142 | print("Checkpoint : ", args.checkpoint_path)
143 | if os.path.exists(args.checkpoint_path):
144 | checkpoint = torch.load(args.checkpoint_path)
145 | else:
146 | logging.info("Checkpoint not exixts")
147 | return None
148 |
149 | if args.config is not None:
150 | hp = HParam(args.config)
151 | else:
152 | hp = load_hparam_str(checkpoint["hp_str"])
153 |
154 | idim = len(valid_symbols)
155 | odim = hp.audio.num_mels
156 | model = FeedForwardTransformer(
157 | idim, odim, hp
158 | ) # torch.jit.load("./etc/fastspeech_scrip_new.pt")
159 |
160 | os.makedirs(args.out, exist_ok=True)
161 | if args.old_model:
162 | logging.info("\nSynthesis Session...\n")
163 | model.load_state_dict(checkpoint, strict=False)
164 | else:
165 | checkpoint = torch.load(args.checkpoint_path)
166 | model.load_state_dict(checkpoint["model"])
167 |
168 | text = process_paragraph(args.text)
169 |
170 | for i in range(0, len(text)):
171 | txt = preprocess(text[i])
172 | audio = synth(txt, model, hp)
173 | m = audio.T
174 | para_mel.append(m)
175 |
176 | m = torch.cat(para_mel, dim=1)
177 | np.save("mel.npy", m.cpu().numpy())
178 | plot_mel(m)
179 |
180 | if hp.train.melgan_vocoder:
181 | m = m.unsqueeze(0)
182 | print("Mel shape: ", m.shape)
183 | vocoder = torch.hub.load("seungwonpark/melgan", "melgan")
184 | vocoder.eval()
185 | if torch.cuda.is_available():
186 | vocoder = vocoder.cuda()
187 | mel = m.cuda()
188 |
189 | with torch.no_grad():
190 | wav = vocoder.inference(
191 | mel
192 | ) # mel ---> batch, num_mels, frames [1, 80, 234]
193 | wav = wav.cpu().float().numpy()
194 | else:
195 | stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
196 | print(m.size())
197 | m = m.unsqueeze(0)
198 | wav = griffin_lim(m, stft, 30)
199 | wav = wav.cpu().numpy()
200 | save_path = "{}/test_tts.wav".format(args.out)
201 | write(save_path, hp.audio.sample_rate, wav.astype("int16"))
202 |
203 |
204 | # NOTE: you need this func to generate our sphinx doc
205 | def get_parser():
206 | """Get parser of decoding arguments."""
207 | parser = configargparse.ArgumentParser(
208 | description="Synthesize speech from text using a TTS model on one CPU",
209 | config_file_parser_class=configargparse.YAMLConfigFileParser,
210 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
211 | )
212 | # general configuration
213 |
214 | parser.add_argument(
215 | "-c", "--config", type=str, required=True, help="yaml file for configuration"
216 | )
217 | parser.add_argument(
218 | "-p",
219 | "--checkpoint_path",
220 | type=str,
221 | default=None,
222 | help="path of checkpoint pt file to resume training",
223 | )
224 | parser.add_argument("--out", type=str, required=True, help="Output filename")
225 | parser.add_argument(
226 | "-o", "--old_model", action="store_true", help="Resume Old model "
227 | )
228 | # task related
229 | parser.add_argument(
230 | "--text", type=str, required=True, help="Filename of train label data (json)"
231 | )
232 | parser.add_argument(
233 | "--pad", default=2, type=int, help="padd value at the end of each sentence"
234 | )
235 | return parser
236 |
237 |
238 | if __name__ == "__main__":
239 | print("Starting")
240 | main(sys.argv[1:])
241 |
--------------------------------------------------------------------------------
/nvidia_preprocessing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import tqdm
4 | import torch
5 | import argparse
6 | import numpy as np
7 | from utils.stft import TacotronSTFT
8 | from utils.util import read_wav_np
9 | from dataset.audio_processing import pitch
10 | from utils.hparams import HParam
11 |
12 |
13 | def main(args, hp):
14 | stft = TacotronSTFT(
15 | filter_length=hp.audio.n_fft,
16 | hop_length=hp.audio.hop_length,
17 | win_length=hp.audio.win_length,
18 | n_mel_channels=hp.audio.n_mels,
19 | sampling_rate=hp.audio.sample_rate,
20 | mel_fmin=hp.audio.fmin,
21 | mel_fmax=hp.audio.fmax,
22 | )
23 |
24 | wav_files = glob.glob(os.path.join(args.data_path, "**", "*.wav"), recursive=True)
25 | mel_path = os.path.join(hp.data.data_dir, "mels")
26 | energy_path = os.path.join(hp.data.data_dir, "energy")
27 | pitch_path = os.path.join(hp.data.data_dir, "pitch")
28 | os.makedirs(mel_path, exist_ok=True)
29 | os.makedirs(energy_path, exist_ok=True)
30 | os.makedirs(pitch_path, exist_ok=True)
31 | print("Sample Rate : ", hp.audio.sample_rate)
32 | for wavpath in tqdm.tqdm(wav_files, desc="preprocess wav to mel"):
33 | sr, wav = read_wav_np(wavpath, hp.audio.sample_rate)
34 | p = pitch(wav, hp) # [T, ] T = Number of frames
35 | wav = torch.from_numpy(wav).unsqueeze(0)
36 | mel, mag = stft.mel_spectrogram(wav) # mel [1, 80, T] mag [1, num_mag, T]
37 | mel = mel.squeeze(0) # [num_mel, T]
38 | mag = mag.squeeze(0) # [num_mag, T]
39 | e = torch.norm(mag, dim=0) # [T, ]
40 | p = p[: mel.shape[1]]
41 | id = os.path.basename(wavpath).split(".")[0]
42 | np.save("{}/{}.npy".format(mel_path, id), mel.numpy(), allow_pickle=False)
43 | np.save("{}/{}.npy".format(energy_path, id), e.numpy(), allow_pickle=False)
44 | np.save("{}/{}.npy".format(pitch_path, id), p, allow_pickle=False)
45 |
46 |
47 | if __name__ == "__main__":
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument(
50 | "-d", "--data_path", type=str, required=True, help="root directory of wav files"
51 | )
52 | parser.add_argument(
53 | "-c", "--config", type=str, required=True, help="yaml file for configuration"
54 | )
55 | args = parser.parse_args()
56 |
57 | hp = HParam(args.config)
58 |
59 | main(args, hp)
60 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.16.3
2 | librosa==0.7.0
3 | numba==0.48
4 | matplotlib
5 | unidecode
6 | inflect
7 | nltk
8 | tqdm
9 | pyyaml
10 | pyworld==0.2.10
11 | configargparse
12 | tensorboardX
13 | typeguard==2.9.1
14 | g2p_en
15 |
16 |
--------------------------------------------------------------------------------
/sample/generated_mel_58k.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/sample/generated_mel_58k.npy
--------------------------------------------------------------------------------
/sample/sample2_58k.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/sample/sample2_58k.wav
--------------------------------------------------------------------------------
/sample/sample_102k_melgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/sample/sample_102k_melgan.wav
--------------------------------------------------------------------------------
/sample/sample_102k_waveglow.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/sample/sample_102k_waveglow.wav
--------------------------------------------------------------------------------
/sample/sample_58k.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/sample/sample_58k.wav
--------------------------------------------------------------------------------
/sample/sample_74k_melgan.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/sample/sample_74k_melgan.wav
--------------------------------------------------------------------------------
/sample/sample_74k_waveglow.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/sample/sample_74k_waveglow.wav
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test_fastspeech2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils.hparams import HParam
3 | from dataset.texts import valid_symbols
4 | from fastspeech import FeedForwardTransformer
5 |
6 |
7 | def test_fastspeech():
8 | idim = len(valid_symbols)
9 | hp = HParam("configs/default.yaml")
10 | hp.train.ngpu = 0
11 | odim = hp.audio.num_mels
12 | model = FeedForwardTransformer(idim, odim, hp)
13 | x = torch.ones(2, 100).to(dtype=torch.int64)
14 | input_length = torch.tensor([100, 100])
15 | y = torch.ones(2, 100, 80)
16 | out_length = torch.tensor([100, 100])
17 | dur = torch.ones(2, 100)
18 | e = torch.ones(2, 100)
19 | p = torch.ones(2, 100)
20 | loss, report_dict = model(x, input_length, y, out_length, dur, e, p)
21 |
--------------------------------------------------------------------------------
/train_fastspeech.py:
--------------------------------------------------------------------------------
1 | import fastspeech
2 | from tensorboardX import SummaryWriter
3 | import torch
4 | from dataset import dataloader as loader
5 | import logging
6 | import math
7 | import os
8 | import sys
9 | import numpy as np
10 | import configargparse
11 | import random
12 | import tqdm
13 | import time
14 | from evaluation import evaluate
15 | from utils.plot import generate_audio, plot_spectrogram_to_numpy
16 | from core.optimizer import get_std_opt
17 | from utils.util import read_wav_np
18 | from dataset.texts import valid_symbols
19 | from utils.util import get_commit_hash
20 | from utils.hparams import HParam
21 |
22 | BATCH_COUNT_CHOICES = ["auto", "seq", "bin", "frame"]
23 | BATCH_SORT_KEY_CHOICES = ["input", "output", "shuffle"]
24 |
25 |
26 | def train(args, hp, hp_str, logger, vocoder):
27 | os.makedirs(os.path.join(hp.train.chkpt_dir, args.name), exist_ok=True)
28 | os.makedirs(os.path.join(args.outdir, args.name), exist_ok=True)
29 | os.makedirs(os.path.join(args.outdir, args.name, "assets"), exist_ok=True)
30 | device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")
31 |
32 | dataloader = loader.get_tts_dataset(hp.data.data_dir, hp.train.batch_size, hp)
33 | validloader = loader.get_tts_dataset(hp.data.data_dir, 1, hp, True)
34 |
35 | idim = len(valid_symbols)
36 | odim = hp.audio.num_mels
37 | model = fastspeech.FeedForwardTransformer(idim, odim, hp)
38 | # set torch device
39 | model = model.to(device)
40 | print("Model is loaded ...")
41 | githash = get_commit_hash()
42 | if args.checkpoint_path is not None:
43 | if os.path.exists(args.checkpoint_path):
44 | logger.info("Resuming from checkpoint: %s" % args.checkpoint_path)
45 | checkpoint = torch.load(args.checkpoint_path)
46 | model.load_state_dict(checkpoint["model"])
47 | optimizer = get_std_opt(
48 | model,
49 | hp.model.adim,
50 | hp.model.transformer_warmup_steps,
51 | hp.model.transformer_lr,
52 | )
53 | optimizer.load_state_dict(checkpoint["optim"])
54 | global_step = checkpoint["step"]
55 |
56 | if hp_str != checkpoint["hp_str"]:
57 | logger.warning(
58 | "New hparams is different from checkpoint. Will use new."
59 | )
60 |
61 | if githash != checkpoint["githash"]:
62 | logger.warning("Code might be different: git hash is different.")
63 | logger.warning("%s -> %s" % (checkpoint["githash"], githash))
64 |
65 | else:
66 | print("Checkpoint does not exixts")
67 | global_step = 0
68 | return None
69 | else:
70 | print("New Training")
71 | global_step = 0
72 | optimizer = get_std_opt(
73 | model,
74 | hp.model.adim,
75 | hp.model.transformer_warmup_steps,
76 | hp.model.transformer_lr,
77 | )
78 |
79 | print("Batch Size :", hp.train.batch_size)
80 |
81 | num_params(model)
82 |
83 | os.makedirs(os.path.join(hp.train.log_dir, args.name), exist_ok=True)
84 | writer = SummaryWriter(os.path.join(hp.train.log_dir, args.name))
85 | model.train()
86 | forward_count = 0
87 | # print(model)
88 | for epoch in range(hp.train.epochs):
89 | start = time.time()
90 | running_loss = 0
91 | j = 0
92 |
93 | pbar = tqdm.tqdm(dataloader, desc="Loading train data")
94 | for data in pbar:
95 | global_step += 1
96 | x, input_length, y, _, out_length, _, dur, e, p = data
97 | # x : [batch , num_char], input_length : [batch], y : [batch, T_in, num_mel]
98 | # # stop_token : [batch, T_in], out_length : [batch]
99 |
100 | loss, report_dict = model(
101 | x.cuda(),
102 | input_length.cuda(),
103 | y.cuda(),
104 | out_length.cuda(),
105 | dur.cuda(),
106 | e.cuda(),
107 | p.cuda(),
108 | )
109 | loss = loss.mean() / hp.train.accum_grad
110 | running_loss += loss.item()
111 |
112 | loss.backward()
113 |
114 | # update parameters
115 | forward_count += 1
116 | j = j + 1
117 | if forward_count != hp.train.accum_grad:
118 | continue
119 | forward_count = 0
120 | step = global_step
121 |
122 | # compute the gradient norm to check if it is normal or not
123 | grad_norm = torch.nn.utils.clip_grad_norm_(
124 | model.parameters(), hp.train.grad_clip
125 | )
126 | logging.debug("grad norm={}".format(grad_norm))
127 | if math.isnan(grad_norm):
128 | logging.warning("grad norm is nan. Do not update model.")
129 | else:
130 | optimizer.step()
131 | optimizer.zero_grad()
132 |
133 | if step % hp.train.summary_interval == 0:
134 | pbar.set_description(
135 | "Average Loss %.04f Loss %.04f | step %d"
136 | % (running_loss / j, loss.item(), step)
137 | )
138 |
139 | for r in report_dict:
140 | for k, v in r.items():
141 | if k is not None and v is not None:
142 | if "cupy" in str(type(v)):
143 | v = v.get()
144 | if "cupy" in str(type(k)):
145 | k = k.get()
146 | writer.add_scalar("main/{}".format(k), v, step)
147 |
148 | if step % hp.train.validation_step == 0:
149 |
150 | for valid in validloader:
151 | x_, input_length_, y_, _, out_length_, ids_, dur_, e_, p_ = valid
152 | model.eval()
153 | with torch.no_grad():
154 | loss_, report_dict_ = model(
155 | x_.cuda(),
156 | input_length_.cuda(),
157 | y_.cuda(),
158 | out_length_.cuda(),
159 | dur_.cuda(),
160 | e_.cuda(),
161 | p_.cuda(),
162 | )
163 |
164 | mels_ = model.inference(x_[-1].cuda()) # [T, num_mel]
165 |
166 | model.train()
167 | for r in report_dict_:
168 | for k, v in r.items():
169 | if k is not None and v is not None:
170 | if "cupy" in str(type(v)):
171 | v = v.get()
172 | if "cupy" in str(type(k)):
173 | k = k.get()
174 | writer.add_scalar("validation/{}".format(k), v, step)
175 |
176 | mels_ = mels_.T # Out: [num_mels, T]
177 | writer.add_image(
178 | "melspectrogram_target_{}".format(ids_[-1]),
179 | plot_spectrogram_to_numpy(
180 | y_[-1].T.data.cpu().numpy()[:, : out_length_[-1]]
181 | ),
182 | step,
183 | dataformats="HWC",
184 | )
185 | writer.add_image(
186 | "melspectrogram_prediction_{}".format(ids_[-1]),
187 | plot_spectrogram_to_numpy(mels_.data.cpu().numpy()),
188 | step,
189 | dataformats="HWC",
190 | )
191 |
192 | # print(mels.unsqueeze(0).shape)
193 |
194 | audio = generate_audio(
195 | mels_.unsqueeze(0), vocoder
196 | ) # selecting the last data point to match mel generated above
197 | audio = audio.cpu().float().numpy()
198 | audio = audio / (
199 | audio.max() - audio.min()
200 | ) # get values between -1 and 1
201 |
202 | writer.add_audio(
203 | tag="generated_audio_{}".format(ids_[-1]),
204 | snd_tensor=torch.Tensor(audio),
205 | global_step=step,
206 | sample_rate=hp.audio.sample_rate,
207 | )
208 |
209 | _, target = read_wav_np(
210 | hp.data.wav_dir + f"{ids_[-1]}.wav",
211 | sample_rate=hp.audio.sample_rate,
212 | )
213 |
214 | writer.add_audio(
215 | tag=" target_audio_{}".format(ids_[-1]),
216 | snd_tensor=torch.Tensor(target),
217 | global_step=step,
218 | sample_rate=hp.audio.sample_rate,
219 | )
220 |
221 | ##
222 | if step % hp.train.save_interval == 0:
223 | avg_p, avg_e, avg_d = evaluate(hp, validloader, model)
224 | writer.add_scalar("evaluation/Pitch_Loss", avg_p, step)
225 | writer.add_scalar("evaluation/Energy_Loss", avg_e, step)
226 | writer.add_scalar("evaluation/Dur_Loss", avg_d, step)
227 | save_path = os.path.join(
228 | hp.train.chkpt_dir,
229 | args.name,
230 | "{}_fastspeech_{}_{}k_steps.pyt".format(
231 | args.name, githash, step // 1000
232 | ),
233 | )
234 |
235 | torch.save(
236 | {
237 | "model": model.state_dict(),
238 | "optim": optimizer.state_dict(),
239 | "step": step,
240 | "hp_str": hp_str,
241 | "githash": githash,
242 | },
243 | save_path,
244 | )
245 | logger.info("Saved checkpoint to: %s" % save_path)
246 | print(
247 | "Time taken for epoch {} is {} sec\n".format(
248 | epoch + 1, int(time.time() - start)
249 | )
250 | )
251 |
252 |
253 | def num_params(model, print_out=True):
254 | parameters = filter(lambda p: p.requires_grad, model.parameters())
255 | parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
256 | if print_out:
257 | print("Trainable Parameters: %.3fM" % parameters)
258 |
259 |
260 | def create_gta(args, hp, hp_str, logger):
261 | os.makedirs(os.path.join(hp.data.data_dir, "gta"), exist_ok=True)
262 | device = torch.device("cuda" if hp.train.ngpu > 0 else "cpu")
263 |
264 | dataloader = loader.get_tts_dataset(hp.data.data_dir, 1)
265 | validloader = loader.get_tts_dataset(hp.data.data_dir, 1, True)
266 | global_step = 0
267 | idim = len(valid_symbols)
268 | odim = hp.audio.num_mels
269 | model = fastspeech.FeedForwardTransformer(idim, odim, args)
270 | # set torch device
271 | if os.path.exists(args.checkpoint_path):
272 | print("\nSynthesis GTA Session...\n")
273 | checkpoint = torch.load(args.checkpoint_path)
274 | model.load_state_dict(checkpoint["model"])
275 | else:
276 | print("Checkpoint not exixts")
277 | return None
278 | model.eval()
279 | model = model.to(device)
280 | print("Model is loaded ...")
281 | print("Batch Size :", hp.train.batch_size)
282 | num_params(model)
283 | onlyValidation = False
284 | if not onlyValidation:
285 | pbar = tqdm.tqdm(dataloader, desc="Loading train data")
286 | for data in pbar:
287 | # start_b = time.time()
288 | global_step += 1
289 | x, input_length, y, _, out_length, ids = data
290 | with torch.no_grad():
291 | _, gta, _, _, _ = model._forward(
292 | x.cuda(), input_length.cuda(), y.cuda(), out_length.cuda()
293 | )
294 | # gta = model._forward(x.cuda(), input_length.cuda(), is_inference=False)
295 | gta = gta.cpu().numpy()
296 |
297 | for j in range(len(ids)):
298 | mel = gta[j]
299 | mel = mel.T
300 | mel = mel[:, : out_length[j]]
301 | mel = (mel + 4) / 8
302 | id = ids[j]
303 | np.save(
304 | "{}/{}.npy".format(os.path.join(hp.data.data_dir, "gta"), id),
305 | mel,
306 | allow_pickle=False,
307 | )
308 |
309 | pbar = tqdm.tqdm(validloader, desc="Loading Valid data")
310 | for data in pbar:
311 | # start_b = time.time()
312 | global_step += 1
313 | x, input_length, y, _, out_length, ids = data
314 | with torch.no_grad():
315 | gta, _, _ = model._forward(
316 | x.cuda(), input_length.cuda(), y.cuda(), out_length.cuda()
317 | )
318 | # gta = model._forward(x.cuda(), input_length.cuda(), is_inference=True)
319 | gta = gta.cpu().numpy()
320 |
321 | for j in range(len(ids)):
322 | print("Actual mel specs : {} = {}".format(ids[j], y[j].shape))
323 | print("Out length:", out_length[j])
324 | print("GTA size: {} = {}".format(ids[j], gta[j].shape))
325 | mel = gta[j]
326 | mel = mel.T
327 | mel = mel[:, : out_length[j]]
328 | mel = (mel + 4) / 8
329 | print("Mel size: {} = {}".format(ids[j], mel.shape))
330 | id = ids[j]
331 | np.save(
332 | "{}/{}.npy".format(os.path.join(hp.data.data_dir, "gta"), id),
333 | mel,
334 | allow_pickle=False,
335 | )
336 |
337 |
338 | # define function for plot prob and att_ws
339 | def _plot_and_save(array, figname, figsize=(6, 4), dpi=150):
340 | import matplotlib.pyplot as plt
341 |
342 | shape = array.shape
343 | if len(shape) == 1:
344 | # for eos probability
345 | fig = plt.figure(figsize=figsize, dpi=dpi)
346 | plt.plot(array)
347 | plt.xlabel("Frame")
348 | plt.ylabel("Probability")
349 | plt.ylim([0, 1])
350 | elif len(shape) == 2:
351 | # for tacotron 2 attention weights, whose shape is (out_length, in_length)
352 | fig = plt.figure(figsize=figsize, dpi=dpi)
353 | plt.imshow(array, aspect="auto")
354 | plt.xlabel("Input")
355 | plt.ylabel("Output")
356 | elif len(shape) == 4:
357 | # for transformer attention weights, whose shape is (#leyers, #heads, out_length, in_length)
358 | fig = plt.figure(
359 | figsize=(figsize[0] * shape[0], figsize[1] * shape[1]), dpi=dpi
360 | )
361 | for idx1, xs in enumerate(array):
362 | for idx2, x in enumerate(xs, 1):
363 | plt.subplot(shape[0], shape[1], idx1 * shape[1] + idx2)
364 | plt.imshow(x.cpu().detach().numpy(), aspect="auto")
365 | plt.xlabel("Input")
366 | plt.ylabel("Output")
367 | else:
368 | raise NotImplementedError("Support only from 1D to 4D array.")
369 | plt.tight_layout()
370 | if not os.path.exists(os.path.dirname(figname)):
371 | # NOTE: exist_ok = True is needed for parallel process decoding
372 | os.makedirs(os.path.dirname(figname), exist_ok=True)
373 | plt.savefig(figname)
374 | plt.close()
375 | return fig
376 |
377 |
378 | # NOTE: you need this func to generate our sphinx doc
379 | def get_parser():
380 | """Get parser of training arguments."""
381 | parser = configargparse.ArgumentParser(
382 | description="Train a new text-to-speech (TTS) model on one CPU, one or multiple GPUs",
383 | config_file_parser_class=configargparse.YAMLConfigFileParser,
384 | formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
385 | )
386 |
387 | parser.add_argument(
388 | "-c", "--config", type=str, required=True, help="yaml file for configuration"
389 | )
390 | parser.add_argument(
391 | "-p",
392 | "--checkpoint_path",
393 | type=str,
394 | default=None,
395 | help="path of checkpoint pt file to resume training",
396 | )
397 | parser.add_argument(
398 | "-n",
399 | "--name",
400 | type=str,
401 | required=True,
402 | help="name of the model for logging, saving checkpoint",
403 | )
404 | parser.add_argument("--outdir", type=str, required=True, help="Output directory")
405 |
406 | return parser
407 |
408 |
409 | def main(cmd_args):
410 | """Run training."""
411 | parser = get_parser()
412 | args, _ = parser.parse_known_args(cmd_args)
413 |
414 | args = parser.parse_args(cmd_args)
415 |
416 | hp = HParam(args.config)
417 | with open(args.config, "r") as f:
418 | hp_str = "".join(f.readlines())
419 |
420 | # logging info
421 | os.makedirs(hp.train.log_dir, exist_ok=True)
422 | logging.basicConfig(
423 | level=logging.INFO,
424 | format="%(asctime)s - %(levelname)s - %(message)s",
425 | handlers=[
426 | logging.FileHandler(
427 | os.path.join(hp.train.log_dir, "%s-%d.log" % (args.name, time.time()))
428 | ),
429 | logging.StreamHandler(),
430 | ],
431 | )
432 | logger = logging.getLogger()
433 |
434 | # If --ngpu is not given,
435 | # 1. if CUDA_VISIBLE_DEVICES is set, all visible devices
436 | # 2. if nvidia-smi exists, use all devices
437 | # 3. else ngpu=0
438 | ngpu = hp.train.ngpu
439 | logger.info(f"ngpu: {ngpu}")
440 |
441 | # set random seed
442 | logger.info("random seed = %d" % hp.train.seed)
443 | random.seed(hp.train.seed)
444 | np.random.seed(hp.train.seed)
445 |
446 | vocoder = torch.hub.load(
447 | "seungwonpark/melgan", "melgan"
448 | ) # load the vocoder for validation
449 |
450 | if hp.train.GTA:
451 | create_gta(args, hp, hp_str, logger)
452 | else:
453 | train(args, hp, hp_str, logger, vocoder)
454 |
455 |
456 | if __name__ == "__main__":
457 | main(sys.argv[1:])
458 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rishikksh20/FastSpeech2/b0e5b931e1f66a3594e03aa5aa7bfc511eaff23c/utils/__init__.py
--------------------------------------------------------------------------------
/utils/display.py:
--------------------------------------------------------------------------------
1 | import time
2 | import sys
3 | import matplotlib
4 |
5 | matplotlib.use("Agg")
6 |
7 |
8 | def progbar(i, n, size=16):
9 | done = (i * size) // n
10 | bar = ""
11 | for i in range(size):
12 | bar += "█" if i <= done else "░"
13 | return bar
14 |
15 |
16 | def stream(message):
17 | sys.stdout.write(f"\r{message}")
18 |
19 |
20 | def simple_table(item_tuples):
21 | border_pattern = "+---------------------------------------"
22 | whitespace = " "
23 |
24 | headings, cells, = (
25 | [],
26 | [],
27 | )
28 |
29 | for item in item_tuples:
30 |
31 | heading, cell = str(item[0]), str(item[1])
32 |
33 | pad_head = True if len(heading) < len(cell) else False
34 |
35 | pad = abs(len(heading) - len(cell))
36 | pad = whitespace[:pad]
37 |
38 | pad_left = pad[: len(pad) // 2]
39 | pad_right = pad[len(pad) // 2 :]
40 |
41 | if pad_head:
42 | heading = pad_left + heading + pad_right
43 | else:
44 | cell = pad_left + cell + pad_right
45 |
46 | headings += [heading]
47 | cells += [cell]
48 |
49 | border, head, body = "", "", ""
50 |
51 | for i in range(len(item_tuples)):
52 |
53 | temp_head = f"| {headings[i]} "
54 | temp_body = f"| {cells[i]} "
55 |
56 | border += border_pattern[: len(temp_head)]
57 | head += temp_head
58 | body += temp_body
59 |
60 | if i == len(item_tuples) - 1:
61 | head += "|"
62 | body += "|"
63 | border += "+"
64 |
65 | print(border)
66 | print(head)
67 | print(border)
68 | print(body)
69 | print(border)
70 | print(" ")
71 |
72 |
73 | def time_since(started):
74 | elapsed = time.time() - started
75 | m = int(elapsed // 60)
76 | s = int(elapsed % 60)
77 | if m >= 60:
78 | h = int(m // 60)
79 | m = m % 60
80 | return f"{h}h {m}m {s}s"
81 | else:
82 | return f"{m}m {s}s"
83 |
--------------------------------------------------------------------------------
/utils/fastspeech2_script.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 |
4 | # Copyright 2019 Tomoki Hayashi
5 | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
6 |
7 | """FastSpeech related loss."""
8 |
9 | import logging
10 |
11 | import torch
12 | from core.duration_modeling.duration_predictor import DurationPredictor
13 | from core.duration_modeling.duration_predictor import DurationPredictorLoss
14 | from core.variance_predictor import EnergyPredictor, EnergyPredictorLoss
15 | from core.variance_predictor import PitchPredictor, PitchPredictorLoss
16 | from core.duration_modeling.length_regulator import LengthRegulator
17 | from utils.util import make_non_pad_mask_script
18 | from utils.util import make_pad_mask_script
19 | from core.embedding import PositionalEncoding
20 | from core.embedding import ScaledPositionalEncoding
21 | from core.encoder import Encoder
22 | from core.modules import initialize
23 | from core.modules import Postnet
24 | from typeguard import check_argument_types
25 | from typing import Dict, Tuple, Sequence
26 |
27 |
28 | class FeedForwardTransformer(torch.nn.Module):
29 | def __init__(self, idim: int, odim: int, hp: Dict):
30 | """Initialize feed-forward Transformer module.
31 | Args:
32 | idim (int): Dimension of the inputs.
33 | odim (int): Dimension of the outputs.
34 | """
35 | # initialize base classes
36 | assert check_argument_types()
37 | torch.nn.Module.__init__(self)
38 |
39 | # fill missing arguments
40 |
41 | # store hyperparameters
42 | self.idim = idim
43 | self.odim = odim
44 |
45 | self.use_scaled_pos_enc = hp.model.use_scaled_pos_enc
46 | self.use_masking = hp.model.use_masking
47 |
48 | # use idx 0 as padding idx
49 | padding_idx = 0
50 |
51 | # get positional encoding class
52 | pos_enc_class = (
53 | ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
54 | )
55 |
56 | # define encoder
57 | encoder_input_layer = torch.nn.Embedding(
58 | num_embeddings=idim, embedding_dim=hp.model.adim, padding_idx=padding_idx
59 | )
60 | self.encoder = Encoder(
61 | idim=idim,
62 | attention_dim=hp.model.adim,
63 | attention_heads=hp.model.aheads,
64 | linear_units=hp.model.eunits,
65 | num_blocks=hp.model.elayers,
66 | input_layer=encoder_input_layer,
67 | dropout_rate=0.2,
68 | positional_dropout_rate=0.2,
69 | attention_dropout_rate=0.2,
70 | pos_enc_class=pos_enc_class,
71 | normalize_before=hp.model.encoder_normalize_before,
72 | concat_after=hp.model.encoder_concat_after,
73 | positionwise_layer_type=hp.model.positionwise_layer_type,
74 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
75 | )
76 |
77 | self.duration_predictor = DurationPredictor(
78 | idim=hp.model.adim,
79 | n_layers=hp.model.duration_predictor_layers,
80 | n_chans=hp.model.duration_predictor_chans,
81 | kernel_size=hp.model.duration_predictor_kernel_size,
82 | dropout_rate=hp.model.duration_predictor_dropout_rate,
83 | )
84 |
85 | self.energy_predictor = EnergyPredictor(
86 | idim=hp.model.adim,
87 | n_layers=hp.model.duration_predictor_layers,
88 | n_chans=hp.model.duration_predictor_chans,
89 | kernel_size=hp.model.duration_predictor_kernel_size,
90 | dropout_rate=hp.model.duration_predictor_dropout_rate,
91 | min=hp.data.e_min,
92 | max=hp.data.e_max,
93 | )
94 | self.energy_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)
95 |
96 | self.pitch_predictor = PitchPredictor(
97 | idim=hp.model.adim,
98 | n_layers=hp.model.duration_predictor_layers,
99 | n_chans=hp.model.duration_predictor_chans,
100 | kernel_size=hp.model.duration_predictor_kernel_size,
101 | dropout_rate=hp.model.duration_predictor_dropout_rate,
102 | min=hp.data.p_min,
103 | max=hp.data.p_max,
104 | )
105 | self.pitch_embed = torch.nn.Linear(hp.model.adim, hp.model.adim)
106 |
107 | # define length regulator
108 | self.length_regulator = LengthRegulator()
109 |
110 | # define decoder
111 | # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
112 | self.decoder = Encoder(
113 | idim=256,
114 | attention_dim=256,
115 | attention_heads=hp.model.aheads,
116 | linear_units=hp.model.dunits,
117 | num_blocks=hp.model.dlayers,
118 | input_layer=None,
119 | dropout_rate=0.2,
120 | positional_dropout_rate=0.2,
121 | attention_dropout_rate=0.2,
122 | pos_enc_class=pos_enc_class,
123 | normalize_before=hp.model.decoder_normalize_before,
124 | concat_after=hp.model.decoder_concat_after,
125 | positionwise_layer_type=hp.model.positionwise_layer_type,
126 | positionwise_conv_kernel_size=hp.model.positionwise_conv_kernel_size,
127 | )
128 |
129 | # define postnet
130 | self.postnet = (
131 | None
132 | if hp.model.postnet_layers == 0
133 | else Postnet(
134 | idim=idim,
135 | odim=odim,
136 | n_layers=hp.model.postnet_layers,
137 | n_chans=hp.model.postnet_chans,
138 | n_filts=hp.model.postnet_filts,
139 | use_batch_norm=hp.model.use_batch_norm,
140 | dropout_rate=hp.model.postnet_dropout_rate,
141 | )
142 | )
143 |
144 | # define final projection
145 | self.feat_out = torch.nn.Linear(hp.model.adim, odim * hp.model.reduction_factor)
146 |
147 | # initialize parameters
148 | self._reset_parameters(
149 | init_type=hp.model.transformer_init,
150 | init_enc_alpha=hp.model.initial_encoder_alpha,
151 | init_dec_alpha=hp.model.initial_decoder_alpha,
152 | )
153 |
154 | # define criterions
155 | self.duration_criterion = DurationPredictorLoss()
156 | self.energy_criterion = EnergyPredictorLoss()
157 | self.pitch_criterion = PitchPredictorLoss()
158 | self.criterion = torch.nn.L1Loss(reduction="mean")
159 | self.use_weighted_masking = hp.model.use_weighted_masking
160 |
161 | def _forward(self, xs: torch.Tensor, ilens: torch.Tensor):
162 | # forward encoder
163 | x_masks = self._source_mask(
164 | ilens
165 | ) # (B, Tmax, Tmax) -> torch.Size([32, 121, 121])
166 |
167 | hs, _ = self.encoder(
168 | xs, x_masks
169 | ) # (B, Tmax, adim) -> torch.Size([32, 121, 256])
170 | # print("ys :", ys.shape)
171 |
172 | # # forward duration predictor and length regulator
173 | d_masks = make_pad_mask_script(ilens).to(xs.device)
174 |
175 | d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax)
176 | hs = self.length_regulator(hs, d_outs, ilens) # (B, Lmax, adim)
177 |
178 | one_hot_energy = self.energy_predictor.inference(hs) # (B, Lmax, adim)
179 |
180 | one_hot_pitch = self.pitch_predictor.inference(hs) # (B, Lmax, adim)
181 |
182 | hs = hs + self.pitch_embed(one_hot_pitch) # (B, Lmax, adim)
183 | hs = hs + self.energy_embed(one_hot_energy) # (B, Lmax, adim)
184 |
185 | # # forward decoder
186 | # h_masks = self._source_mask(olens) we can find olens from length regulator and then calculate mask
187 | # h_masks = torch.empty(0)
188 |
189 | zs, _ = self.decoder(hs, None) # (B, Lmax, adim)
190 |
191 | before_outs = self.feat_out(zs).view(
192 | zs.size(0), -1, self.odim
193 | ) # (B, Lmax, odim)
194 |
195 | # postnet -> (B, Lmax//r * r, odim)
196 | after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(
197 | 1, 2
198 | )
199 | return after_outs
200 |
201 | def forward(self, x: torch.Tensor) -> torch.Tensor:
202 | """Generate the sequence of features given the sequences of characters.
203 | Args:
204 | x (Tensor): Input sequence of characters (T,).
205 | inference_args (Namespace): Dummy for compatibility.
206 | spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
207 | Returns:
208 | Tensor: Output sequence of features (1, L, odim).
209 | None: Dummy for compatibility.
210 | None: Dummy for compatibility.
211 | """
212 | # setup batch axis
213 | ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
214 | xs = x.unsqueeze(0)
215 |
216 | # inference
217 | outs = self._forward(xs, ilens) # (L, odim)
218 |
219 | return outs[0]
220 |
221 | def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
222 | """Make masks for self-attention.
223 | Examples:
224 | >>> ilens = [5, 3]
225 | >>> self._source_mask(ilens)
226 | tensor([[[1, 1, 1, 1, 1],
227 | [1, 1, 1, 1, 1],
228 | [1, 1, 1, 1, 1],
229 | [1, 1, 1, 1, 1],
230 | [1, 1, 1, 1, 1]],
231 | [[1, 1, 1, 0, 0],
232 | [1, 1, 1, 0, 0],
233 | [1, 1, 1, 0, 0],
234 | [0, 0, 0, 0, 0],
235 | [0, 0, 0, 0, 0]]], dtype=torch.uint8)
236 | """
237 | x_masks = make_non_pad_mask_script(ilens)
238 | return x_masks.unsqueeze(-2) & x_masks.unsqueeze(-1)
239 |
240 | def _reset_parameters(
241 | self, init_type: str, init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0
242 | ):
243 | # initialize parameters
244 | initialize(self, init_type)
245 | #
246 | # initialize alpha in scaled positional encoding
247 | if self.use_scaled_pos_enc:
248 | self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
249 | self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
250 |
--------------------------------------------------------------------------------
/utils/hparams.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 |
4 |
5 | def load_hparam_str(hp_str):
6 | path = "temp-restore.yaml"
7 | with open(path, "w") as f:
8 | f.write(hp_str)
9 | ret = HParam(path)
10 | os.remove(path)
11 | return ret
12 |
13 |
14 | def load_hparam(filename):
15 | stream = open(filename, "r")
16 | docs = yaml.load_all(stream, Loader=yaml.Loader)
17 | hparam_dict = dict()
18 | for doc in docs:
19 | for k, v in doc.items():
20 | hparam_dict[k] = v
21 | return hparam_dict
22 |
23 |
24 | def merge_dict(user, default):
25 | if isinstance(user, dict) and isinstance(default, dict):
26 | for k, v in default.items():
27 | if k not in user:
28 | user[k] = v
29 | else:
30 | user[k] = merge_dict(user[k], v)
31 | return user
32 |
33 |
34 | class Dotdict(dict):
35 | """
36 | a dictionary that supports dot notation
37 | as well as dictionary access notation
38 | usage: d = DotDict() or d = DotDict({'val1':'first'})
39 | set attributes: d.val2 = 'second' or d['val2'] = 'second'
40 | get attributes: d.val2 or d['val2']
41 | """
42 |
43 | __getattr__ = dict.__getitem__
44 | __setattr__ = dict.__setitem__
45 | __delattr__ = dict.__delitem__
46 |
47 | def __init__(self, dct=None):
48 | dct = dict() if not dct else dct
49 | for key, value in dct.items():
50 | if hasattr(value, "keys"):
51 | value = Dotdict(value)
52 | self[key] = value
53 |
54 |
55 | class HParam(Dotdict):
56 | def __init__(self, file):
57 | super(Dotdict, self).__init__()
58 | hp_dict = load_hparam(file)
59 | hp_dotdict = Dotdict(hp_dict)
60 | for k, v in hp_dotdict.items():
61 | setattr(self, k, v)
62 |
63 | __getattr__ = Dotdict.__getitem__
64 | __setattr__ = Dotdict.__setitem__
65 | __delattr__ = Dotdict.__delitem__
66 |
--------------------------------------------------------------------------------
/utils/plot.py:
--------------------------------------------------------------------------------
1 | import matplotlib
2 | import numpy as np
3 | import torch
4 |
5 | matplotlib.use("Agg")
6 | from matplotlib import pyplot as plt
7 |
8 |
9 | def save_attention(attn, path):
10 | fig = plt.figure(figsize=(12, 6))
11 | plt.imshow(attn.T, interpolation="nearest", aspect="auto")
12 | fig.savefig(f"{path}.png", bbox_inches="tight")
13 | plt.close(fig)
14 |
15 |
16 | def save_spectrogram(M, path, length=None):
17 | M = np.flip(M, axis=0)
18 | if length:
19 | M = M[:, :length]
20 | fig = plt.figure(figsize=(12, 6))
21 | plt.imshow(M, interpolation="nearest", aspect="auto")
22 | fig.savefig(f"{path}.png", bbox_inches="tight")
23 | plt.close(fig)
24 |
25 |
26 | def plot(array):
27 | fig = plt.figure(figsize=(30, 5))
28 | ax = fig.add_subplot(111)
29 | ax.xaxis.label.set_color("grey")
30 | ax.yaxis.label.set_color("grey")
31 | ax.xaxis.label.set_fontsize(23)
32 | ax.yaxis.label.set_fontsize(23)
33 | ax.tick_params(axis="x", colors="grey", labelsize=23)
34 | ax.tick_params(axis="y", colors="grey", labelsize=23)
35 | plt.plot(array)
36 |
37 |
38 | def plot_spec(M):
39 | M = np.flip(M, axis=0)
40 | plt.figure(figsize=(18, 4))
41 | plt.imshow(M, interpolation="nearest", aspect="auto")
42 | plt.show()
43 |
44 |
45 | def plot_image(target, melspec, mel_lengths): # , alignments
46 | fig, axes = plt.subplots(2, 1, figsize=(20, 20))
47 | T = mel_lengths[-1]
48 |
49 | axes[0].imshow(target[-1].T.detach().cpu()[:, :T], origin="lower", aspect="auto")
50 |
51 | axes[1].imshow(melspec.cpu()[:, :T], origin="lower", aspect="auto")
52 |
53 | return fig
54 |
55 |
56 | def save_figure_to_numpy(fig, spectrogram=False):
57 | # save it to a numpy array.
58 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
59 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
60 | if spectrogram:
61 | return data
62 | data = np.transpose(data, (2, 0, 1))
63 | return data
64 |
65 |
66 | def plot_waveform_to_numpy(waveform):
67 | fig, ax = plt.subplots(figsize=(12, 3))
68 | ax.plot()
69 | ax.plot(range(len(waveform)), waveform, linewidth=0.1, alpha=0.7, color="blue")
70 |
71 | plt.xlabel("Samples")
72 | plt.ylabel("Amplitude")
73 | plt.ylim(-1, 1)
74 | plt.tight_layout()
75 |
76 | fig.canvas.draw()
77 | data = save_figure_to_numpy(fig)
78 | plt.close()
79 | return data
80 |
81 |
82 | def plot_spectrogram_to_numpy(spectrogram):
83 | fig, ax = plt.subplots(figsize=(12, 3))
84 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
85 | plt.colorbar(im, ax=ax)
86 | plt.xlabel("Frames")
87 | plt.ylabel("Channels")
88 | plt.tight_layout()
89 |
90 | fig.canvas.draw()
91 | data = save_figure_to_numpy(fig, True)
92 | plt.close()
93 | return data
94 |
95 |
96 | def generate_audio(mel, vocoder):
97 | # input mel shape - [1,80,T]
98 | vocoder.eval()
99 | if torch.cuda.is_available():
100 | vocoder = vocoder.cuda()
101 | mel = mel.cuda()
102 |
103 | with torch.no_grad():
104 | audio = vocoder.inference(mel)
105 | return audio
106 |
--------------------------------------------------------------------------------
/utils/stft.py:
--------------------------------------------------------------------------------
1 | """
2 | BSD 3-Clause License
3 | Copyright (c) 2017, Prem Seetharaman
4 | All rights reserved.
5 | * Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 | * Redistributions of source code must retain the above copyright notice,
8 | this list of conditions and the following disclaimer.
9 | * Redistributions in binary form must reproduce the above copyright notice, this
10 | list of conditions and the following disclaimer in the
11 | documentation and/or other materials provided with the distribution.
12 | * Neither the name of the copyright holder nor the names of its
13 | contributors may be used to endorse or promote products derived from this
14 | software without specific prior written permission.
15 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25 | """
26 |
27 | import torch
28 | import numpy as np
29 | import torch.nn.functional as F
30 | from torch.autograd import Variable
31 | from scipy.signal import get_window
32 | from librosa.util import pad_center, tiny
33 | from dataset.audio_processing import (
34 | window_sumsquare,
35 | dynamic_range_compression,
36 | dynamic_range_decompression,
37 | )
38 | from librosa.filters import mel as librosa_mel_fn
39 |
40 |
41 | class STFT(torch.nn.Module):
42 | """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
43 |
44 | def __init__(
45 | self, filter_length=800, hop_length=200, win_length=800, window="hann"
46 | ):
47 | super(STFT, self).__init__()
48 | self.filter_length = filter_length
49 | self.hop_length = hop_length
50 | self.win_length = win_length
51 | self.window = window
52 | self.forward_transform = None
53 | scale = self.filter_length / self.hop_length
54 | fourier_basis = np.fft.fft(np.eye(self.filter_length))
55 |
56 | cutoff = int((self.filter_length / 2 + 1))
57 | fourier_basis = np.vstack(
58 | [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
59 | )
60 |
61 | forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
62 | inverse_basis = torch.FloatTensor(
63 | np.linalg.pinv(scale * fourier_basis).T[:, None, :]
64 | )
65 |
66 | if window is not None:
67 | assert filter_length >= win_length
68 | # get window and zero center pad it to filter_length
69 | fft_window = get_window(window, win_length, fftbins=True)
70 | fft_window = pad_center(fft_window, filter_length)
71 | fft_window = torch.from_numpy(fft_window).float()
72 |
73 | # window the bases
74 | forward_basis *= fft_window
75 | inverse_basis *= fft_window
76 |
77 | self.register_buffer("forward_basis", forward_basis.float())
78 | self.register_buffer("inverse_basis", inverse_basis.float())
79 |
80 | def transform(self, input_data):
81 | num_batches = input_data.size(0)
82 | num_samples = input_data.size(1)
83 |
84 | self.num_samples = num_samples
85 |
86 | # similar to librosa, reflect-pad the input
87 | input_data = input_data.view(num_batches, 1, num_samples)
88 | input_data = F.pad(
89 | input_data.unsqueeze(1),
90 | (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
91 | mode="reflect",
92 | )
93 | input_data = input_data.squeeze(1)
94 |
95 | # https://github.com/NVIDIA/tacotron2/issues/125
96 | forward_transform = F.conv1d(
97 | input_data.cuda(),
98 | Variable(self.forward_basis, requires_grad=False).cuda(),
99 | stride=self.hop_length,
100 | padding=0,
101 | ).cpu()
102 |
103 | cutoff = int((self.filter_length / 2) + 1)
104 | real_part = forward_transform[:, :cutoff, :]
105 | imag_part = forward_transform[:, cutoff:, :]
106 |
107 | magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
108 | phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
109 |
110 | return magnitude, phase
111 |
112 | def inverse(self, magnitude, phase):
113 | recombine_magnitude_phase = torch.cat(
114 | [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
115 | )
116 |
117 | inverse_transform = F.conv_transpose1d(
118 | recombine_magnitude_phase,
119 | Variable(self.inverse_basis, requires_grad=False),
120 | stride=self.hop_length,
121 | padding=0,
122 | )
123 |
124 | if self.window is not None:
125 | window_sum = window_sumsquare(
126 | self.window,
127 | magnitude.size(-1),
128 | hop_length=self.hop_length,
129 | win_length=self.win_length,
130 | n_fft=self.filter_length,
131 | dtype=np.float32,
132 | )
133 | # remove modulation effects
134 | approx_nonzero_indices = torch.from_numpy(
135 | np.where(window_sum > tiny(window_sum))[0]
136 | )
137 | window_sum = torch.autograd.Variable(
138 | torch.from_numpy(window_sum), requires_grad=False
139 | )
140 | window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
141 | inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
142 | approx_nonzero_indices
143 | ]
144 |
145 | # scale by hop ratio
146 | inverse_transform *= float(self.filter_length) / self.hop_length
147 |
148 | inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
149 | inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]
150 |
151 | return inverse_transform
152 |
153 | def forward(self, input_data):
154 | self.magnitude, self.phase = self.transform(input_data)
155 | reconstruction = self.inverse(self.magnitude, self.phase)
156 | return reconstruction
157 |
158 |
159 | class TacotronSTFT(torch.nn.Module):
160 | def __init__(
161 | self,
162 | filter_length=1024,
163 | hop_length=256,
164 | win_length=1024,
165 | n_mel_channels=80,
166 | sampling_rate=22050,
167 | mel_fmin=0.0,
168 | mel_fmax=None,
169 | ):
170 | super(TacotronSTFT, self).__init__()
171 | self.n_mel_channels = n_mel_channels
172 | self.sampling_rate = sampling_rate
173 | self.stft_fn = STFT(filter_length, hop_length, win_length)
174 | mel_basis = librosa_mel_fn(
175 | sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
176 | )
177 | mel_basis = torch.from_numpy(mel_basis).float()
178 | self.register_buffer("mel_basis", mel_basis)
179 |
180 | def spectral_normalize(self, magnitudes):
181 | output = dynamic_range_compression(magnitudes)
182 | return output
183 |
184 | def spectral_de_normalize(self, magnitudes):
185 | output = dynamic_range_decompression(magnitudes)
186 | return output
187 |
188 | def mel_spectrogram(self, y):
189 | """Computes mel-spectrograms from a batch of waves
190 | PARAMS
191 | ------
192 | y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
193 | RETURNS
194 | -------
195 | mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
196 | """
197 | assert torch.min(y.data) >= -1
198 | assert torch.max(y.data) <= 1
199 |
200 | magnitudes, phases = self.stft_fn.transform(y)
201 | magnitudes = magnitudes.data
202 | mel_output = torch.matmul(self.mel_basis, magnitudes)
203 | mel_output = self.spectral_normalize(mel_output)
204 | return mel_output, magnitudes
205 |
--------------------------------------------------------------------------------