├── .gitignore
├── LICENSE
├── README.md
├── api.py
├── checkpoints
│   └── .keep
├── config.py
├── datas
│   ├── __init__.py
│   ├── dataset.py
│   └── sampler.py
├── figures
│   └── structure.jpg
├── filelists
│   └── example.txt
├── inference.ipynb
├── models
│   ├── __init__.py
│   ├── diffusion_transformer.py
│   ├── duration_predictor.py
│   ├── estimator.py
│   ├── flow_matching.py
│   ├── model.py
│   ├── reference_encoder.py
│   └── text_encoder.py
├── monotonic_align
│   ├── __init__.py
│   └── core.py
├── preprocess.py
├── recipes
│   ├── AiSHELL3.py
│   ├── BZNSYP_标贝女声.py
│   ├── VCTK_huggingface.py
│   ├── genshin_en_小虫哥ver.py
│   ├── genshin_zh_小虫哥ver.py
│   ├── hifi_tts.py
│   └── libriTTS.py
├── requirements.txt
├── text
│   ├── LICENSE
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cn2an
│   │   ├── __init__.py
│   │   ├── an2cn.py
│   │   ├── cn2an.py
│   │   ├── conf.py
│   │   └── transform.py
│   ├── cnm3
│   │   └── ds_CNM3.txt
│   ├── custom_pypinyin_dict
│   │   ├── __init__.py
│   │   ├── cc_cedict_0.py
│   │   ├── cc_cedict_1.py
│   │   ├── cc_cedict_2.py
│   │   ├── cc_cedict_3.py
│   │   ├── genshin.py
│   │   └── phrase_pinyin_data.py
│   ├── english.py
│   ├── japanese.py
│   ├── mandarin.py
│   └── symbols.py
├── train.py
├── utils
│   ├── __init__.py
│   ├── audio.py
│   ├── load.py
│   ├── mask.py
│   └── scheduler.py
├── vocoders
│   ├── __init__.py
│   ├── ffgan
│   │   ├── __init__.py
│   │   ├── backbone.py
│   │   ├── head.py
│   │   ├── model.py
│   │   └── unify.py
│   ├── pretrained
│   │   └── .keep
│   └── vocos
│       ├── README.md
│       ├── __init__.py
│       ├── config.py
│       ├── dataset.py
│       ├── inference.ipynb
│       ├── models
│       │   ├── __init__.py
│       │   ├── backbone.py
│       │   ├── discriminator.py
│       │   ├── head.py
│       │   ├── loss.py
│       │   ├── model.py
│       │   └── module.py
│       ├── preprocess.py
│       ├── requirements.txt
│       ├── train.py
│       └── utils
│           ├── __init__.py
│           ├── audio.py
│           ├── load.py
│           └── scheduler.py
└── webui.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 KdaiP
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # StableTTS
4 |
5 | Next-generation TTS model using flow-matching and DiT, inspired by [Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3).
6 |
7 |
8 |
9 |
10 | ## Introduction
11 |
12 | As the first open-source TTS model to try combining flow-matching and DiT, **StableTTS** is a fast and lightweight TTS model for Chinese, English and Japanese speech generation. It has 31M parameters.
13 |
14 | ✨ **Huggingface demo:** [🤗](https://huggingface.co/spaces/KdaiP/StableTTS1.1)
15 |
16 | ## News
17 |
18 | 2024/10: A new autoregressive TTS model is coming soon...
19 |
20 | 2024/9: 🚀 **StableTTS V1.1 Released** ⭐ Audio quality is significantly improved ⭐
21 |
22 | ⭐ **V1.1 Release Highlights:**
23 |
24 | - Fixed critical issues that caused the audio quality to be much lower than expected (mainly in the mel spectrogram and the attention mask).
25 | - Introduced U-Net-like long skip connections to the DiT blocks in the flow-matching decoder.
26 | - Used the cosine timestep scheduler from [CosyVoice](https://github.com/FunAudioLLM/CosyVoice).
27 | - Added support for CFG (Classifier-Free Guidance).
28 | - Added support for the [FireflyGAN vocoder](https://github.com/fishaudio/vocoder/releases/tag/1.0.0).
29 | - Switched to [torchdiffeq](https://github.com/rtqichen/torchdiffeq) for ODE solvers.
30 | - Improved Chinese text frontend (partially based on [gpt-sovits2](https://github.com/RVC-Boss/GPT-SoVITS)).
31 | - Multilingual support (Chinese, English, Japanese) in a single checkpoint.
32 | - Increased parameters: 10M -> 31M.
33 |
34 |
35 | ## Pretrained models
36 |
37 | ### Text-To-Mel model
38 |
39 | Download the model and place it in the `./checkpoints` directory; it is then ready for inference, finetuning and the webui.
40 |
41 | | Model Name | Task Details | Dataset | Download Link |
42 | |:----------:|:------------:|:-------------:|:-------------:|
43 | | StableTTS | text to mel | 600 hours | [🤗](https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/StableTTS/checkpoint_0.pt)|
44 |
45 | ### Mel-To-Wav model
46 |
47 | Choose a vocoder (`vocos` or `firefly-gan`) and place it in the `./vocoders/pretrained` directory.
48 |
49 | | Model Name | Task Details | Dataset | Download Link |
50 | |:----------:|:------------:|:-------------:|:-------------:|
51 | | Vocos | mel to wav | 2k hours | [🤗](https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/vocoders/vocos.pt)|
52 | | firefly-gan-base | mel to wav | HiFi-16kh | [download from fishaudio](https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt)|
53 |
54 | ## Installation
55 |
56 | 1. **Install PyTorch**: Follow the [official PyTorch guide](https://pytorch.org/get-started/locally/) to install PyTorch and torchaudio. We recommend the latest version (tested with PyTorch 2.4 and Python 3.12).
57 |
58 | 2. **Install Dependencies**: Run the following command to install the required Python packages:
59 |
60 | ```bash
61 | pip install -r requirements.txt
62 | ```
63 |
64 | ## Inference
65 |
66 | For detailed inference instructions, please refer to `inference.ipynb`.
67 | 
68 | We also provide a web UI based on Gradio; please refer to `webui.py`. For scripted use, see the `StableTTSAPI` sketch below.
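
Programmatic inference goes through the `StableTTSAPI` class in `api.py`. A minimal sketch based on that file (the text, the reference audio path and the checkpoint paths below are placeholders; adjust them to your setup):

```python
import torch
import torchaudio

from api import StableTTSAPI
from config import MelConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# placeholder paths: point these at the downloaded checkpoints
model = StableTTSAPI('./checkpoints/checkpoint_0.pt', './vocoders/pretrained/vocos.pt', 'vocos')
model.to(device)

audio, mel = model.inference(
    text='Hello, world.',
    ref_audio='./reference.wav',  # short speech clip whose voice/style is cloned
    language='english',           # 'chinese', 'japanese' or 'english'
    step=30,                      # number of ODE steps
    solver='dopri5',              # euler, midpoint or dopri5
    cfg=3.0,                      # classifier-free guidance strength
)
torchaudio.save('output.wav', audio, MelConfig().sample_rate)
```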
69 |
70 | ## Training
71 |
72 | StableTTS is designed to be easy to train. It only needs text and audio pairs, without any speaker IDs or extra feature extraction. Here's how to get started:
73 |
74 | ### Preparing Your Data
75 |
76 | 1. **Generate Text and Audio Pairs**: Generate a filelist of audio and text pairs in the same format as `./filelists/example.txt` (format shown below). Recipes for several open-source datasets can be found in `./recipes`.
77 |
78 | 2. **Run Preprocessing**: Adjust the `DataConfig` in `preprocess.py` to set your input and output paths, then run the script. This will process the audio and text according to your list, outputting a JSON file with paths to mel features and phonemes.
79 |
80 | **Note: Process multilingual data separately by changing the `language` setting in `DataConfig`**
81 |
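
For reference, each line of the filelist is an audio path and its transcript separated by `|`, mirroring `./filelists/example.txt`:

```
./audio1.wav|你好,世界。
./audio2.wav|Hello, world.
```
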
82 | ### Start training
83 |
84 | 1. **Adjust Training Configuration**: In `config.py`, modify `TrainConfig` to set your file list path and adjust training parameters (such as batch_size) as needed.
85 |
86 | 2. **Start the Training Process**: Launch `train.py` to start training your model (example commands below).
87 |
88 | Note: For finetuning, download the pretrained model and place it in the `model_save_path` directory specified in `TrainConfig`. The training script will automatically detect and load the pretrained checkpoint.
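
Assuming `DataConfig` and `TrainConfig` have been edited as described above, a typical single-GPU run might look like this (the exact launch command may differ for multi-GPU setups, e.g. via `torchrun`):

```bash
# extract mel features and phonemes into the JSON filelist
python preprocess.py

# start (or resume / finetune) training; checkpoints are written to model_save_path
python train.py
```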
89 |
90 | ### (Optional) Vocoder training
91 |
92 | The `./vocoders/vocos` folder contains the training and finetuning code for the Vocos vocoder.
93 |
94 | For other types of vocoders, we recommend training with [fishaudio vocoder](https://github.com/fishaudio/vocoder), a unified interface for developing various vocoders. We use the same spectrogram transform, so vocoders trained there are compatible with StableTTS.
95 |
96 | ## Model structure
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 | - We use the Diffusion Convolution Transformer block from [Hierspeech++](https://github.com/sh-lee-prml/HierSpeechpp), which combines the original [DiT](https://github.com/facebookresearch/DiT) with the [FFT](https://arxiv.org/pdf/1905.09263.pdf) (Feed-Forward Transformer from FastSpeech) block for better prosody.
107 |
108 | - In the flow-matching decoder, we add a [FiLM layer](https://arxiv.org/abs/1709.07871) before each DiT block to condition the model on the timestep embedding (a minimal sketch of this conditioning follows).
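
A minimal sketch of the FiLM conditioning, mirroring `FiLMLayer` in `models/estimator.py` (class name and shapes here are illustrative):

```python
import torch
import torch.nn as nn

class FiLM(nn.Module):
    """Predict a per-channel scale (gamma) and shift (beta) from a condition vector."""
    def __init__(self, channels: int, cond_channels: int):
        super().__init__()
        self.proj = nn.Conv1d(cond_channels, channels * 2, 1)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, time); c: (batch, cond_channels), e.g. the timestep embedding
        gamma, beta = self.proj(c.unsqueeze(-1)).chunk(2, dim=1)
        return gamma * x + beta  # gamma/beta broadcast over the time axis
```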
109 |
110 | ## References
111 |
112 | The development of our models heavily relies on insights and code from various projects. We express our heartfelt thanks to the creators of the following:
113 |
114 | ### Direct Inspirations
115 |
116 | [Matcha TTS](https://github.com/shivammehta25/Matcha-TTS): Essential flow-matching code.
117 |
118 | [Grad TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): Diffusion model structure.
119 |
120 | [Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3): Idea of combining flow-matching and DiT.
121 |
122 | [Vits](https://github.com/jaywalnut310/vits): Code style and MAS insights, DistributedBucketSampler.
123 |
124 | ### Additional References:
125 |
126 | [pflowtts-pytorch](https://github.com/p0p4k/pflowtts_pytorch): MAS code used in training
127 | 
128 | [Bert-VITS2](https://github.com/Plachtaa/VITS-fast-fine-tuning): numba version of MAS and modern PyTorch code for VITS
129 | 
130 | [fish-speech](https://github.com/fishaudio/fish-speech): dataclass usage, mel-spectrogram transforms using torchaudio, and the Gradio web UI
131 | 
132 | [gpt-sovits](https://github.com/RVC-Boss/GPT-SoVITS): mel-style encoder for voice cloning
133 | 
134 | [coqui xtts](https://huggingface.co/spaces/coqui/xtts): Gradio web UI
135 | 
136 | Chinese dictionary of DiffSinger: [Multi-langs_Dictionary](https://github.com/colstone/Multi-langs_Dictionary) and [atonyxu's fork](https://github.com/atonyxu/Multi-langs_Dictionary)
137 |
138 | ## TODO
139 |
140 | - [x] Release pretrained models.
141 | - [x] Support Japanese language.
142 | - [x] User friendly preprocess and inference script.
143 | - [x] Enhance documentation and citations.
144 | - [x] Release multilingual checkpoint.
145 |
146 | ## Disclaimer
147 |
148 | Any organization or individual is prohibited from using any technology in this repo to generate or edit someone's speech without their consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
--------------------------------------------------------------------------------
/api.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from dataclasses import asdict
5 |
6 | from utils.audio import LogMelSpectrogram
7 | from config import ModelConfig, MelConfig
8 | from models.model import StableTTS
9 |
10 | from text import symbols
11 | from text import cleaned_text_to_sequence
12 | from text.mandarin import chinese_to_cnm3
13 | from text.english import english_to_ipa2
14 | from text.japanese import japanese_to_ipa2
15 |
16 | from datas.dataset import intersperse
17 | from utils.audio import load_and_resample_audio
18 |
19 | def get_vocoder(model_path, model_name='ffgan') -> nn.Module:
20 | if model_name == 'ffgan':
21 | # training or changing ffgan config is not supported in this repo
22 | # you can train your own model at https://github.com/fishaudio/vocoder
23 | from vocoders.ffgan.model import FireflyGANBaseWrapper
24 | vocoder = FireflyGANBaseWrapper(model_path)
25 |
26 | elif model_name == 'vocos':
27 | from vocoders.vocos.models.model import Vocos
28 | from config import VocosConfig, MelConfig
29 | vocoder = Vocos(VocosConfig(), MelConfig())
30 | vocoder.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu'))
31 | vocoder.eval()
32 |
33 | else:
34 | raise NotImplementedError(f"Unsupported model: {model_name}")
35 |
36 | return vocoder
37 |
38 | class StableTTSAPI(nn.Module):
39 | def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'):
40 | super().__init__()
41 |
42 | self.mel_config = MelConfig()
43 | self.tts_model_config = ModelConfig()
44 |
45 | self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config))
46 |
47 | # text to mel spectrogram
48 | self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config))
49 | self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True))
50 | self.tts_model.eval()
51 |
52 | # mel spectrogram to waveform
53 | self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name)
54 | self.vocoder_model.eval()
55 |
56 | self.g2p_mapping = {
57 | 'chinese': chinese_to_cnm3,
58 | 'japanese': japanese_to_ipa2,
59 | 'english': english_to_ipa2,
60 | }
61 | self.supported_languages = self.g2p_mapping.keys()
62 |
63 |     @torch.inference_mode()
64 | def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0):
65 | device = next(self.parameters()).device
66 | phonemizer = self.g2p_mapping.get(language)
67 |
68 | text = phonemizer(text)
69 | text = torch.tensor(intersperse(cleaned_text_to_sequence(text), item=0), dtype=torch.long, device=device).unsqueeze(0)
70 | text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device)
71 |
72 | ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device)
73 | ref_audio = self.mel_extractor(ref_audio)
74 |
75 | mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs']
76 | audio_output = self.vocoder_model(mel_output)
77 | return audio_output.cpu(), mel_output.cpu()
78 |
79 | def get_params(self):
80 | tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6
81 | vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6
82 | return tts_param, vocoder_param
83 |
84 | if __name__ == '__main__':
85 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
86 | tts_model_path = './checkpoints/checkpoint_0.pt'
87 | vocoder_model_path = './vocoders/pretrained/vocos.pt'
88 |
89 | model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos')
90 | model.to(device)
91 |
92 | text = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……'
93 | audio = './audio_1.wav'
94 |
95 | audio_output, mel_output = model.inference(text, audio, 'chinese', 10, solver='dopri5', cfg=3)
96 | print(audio_output.shape)
97 | print(mel_output.shape)
98 |
99 | import torchaudio
100 | torchaudio.save('output.wav', audio_output, MelConfig().sample_rate)
101 |
102 |
103 |
--------------------------------------------------------------------------------
/checkpoints/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/checkpoints/.keep
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | @dataclass
4 | class MelConfig:
5 | sample_rate: int = 44100
6 | n_fft: int = 2048
7 | win_length: int = 2048
8 | hop_length: int = 512
9 | f_min: float = 0.0
10 | f_max: float = None
11 | pad: int = 0
12 | n_mels: int = 128
13 | center: bool = False
14 | pad_mode: str = "reflect"
15 | mel_scale: str = "slaney"
16 |
17 | def __post_init__(self):
18 | if self.pad == 0:
19 | self.pad = (self.n_fft - self.hop_length) // 2
20 |
21 | @dataclass
22 | class ModelConfig:
23 | hidden_channels: int = 256
24 | filter_channels: int = 1024
25 | n_heads: int = 4
26 | n_enc_layers: int = 3
27 | n_dec_layers: int = 6
28 | kernel_size: int = 3
29 |     p_dropout: float = 0.1
30 | gin_channels: int = 256
31 |
32 | @dataclass
33 | class TrainConfig:
34 | train_dataset_path: str = 'filelists/filelist.json'
35 | test_dataset_path: str = 'filelists/filelist.json' # not used
36 | batch_size: int = 32
37 | learning_rate: float = 1e-4
38 | num_epochs: int = 10000
39 | model_save_path: str = './checkpoints'
40 | log_dir: str = './runs'
41 | log_interval: int = 16
42 | save_interval: int = 1
43 | warmup_steps: int = 200
44 |
45 | @dataclass
46 | class VocosConfig:
47 | input_channels: int = 128
48 | dim: int = 512
49 | intermediate_dim: int = 1536
50 | num_layers: int = 8
--------------------------------------------------------------------------------
/datas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/datas/__init__.py
--------------------------------------------------------------------------------
/datas/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import json
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 | from text import cleaned_text_to_sequence
9 |
10 | def intersperse(lst: list, item: int):
11 | """
12 | putting a blank token between any two input tokens to improve pronunciation
13 | see https://github.com/jaywalnut310/glow-tts/issues/43 for more details
14 | """
15 | result = [item] * (len(lst) * 2 + 1)
16 | result[1::2] = lst
17 | return result
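# e.g. intersperse([5, 3, 7], item=0) -> [0, 5, 0, 3, 0, 7, 0]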
18 |
19 | class StableDataset(Dataset):
20 | def __init__(self, filelist_path, hop_length):
21 | self.filelist_path = filelist_path
22 | self.hop_length = hop_length
23 |
24 | self._load_filelist(filelist_path)
25 |
26 | def _load_filelist(self, filelist_path):
27 | filelist, lengths = [], []
28 | with open(filelist_path, 'r', encoding='utf-8') as f:
29 | for line in f:
30 | line = json.loads(line.strip())
31 | filelist.append((line['mel_path'], line['phone']))
32 | lengths.append(line['mel_length'])
33 |
34 | self.filelist = filelist
35 | self.lengths = lengths # length is used for DistributedBucketSampler
36 |
37 | def __len__(self):
38 | return len(self.filelist)
39 |
40 | def __getitem__(self, idx):
41 | mel_path, phone = self.filelist[idx]
42 | mel = torch.load(mel_path, map_location='cpu', weights_only=True)
43 | phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long)
44 | return mel, phone
45 |
46 | def collate_fn(batch):
47 | texts = [item[1] for item in batch]
48 | mels = [item[0] for item in batch]
49 | mels_sliced = [random_slice_tensor(mel) for mel in mels]
50 |
51 | text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long)
52 | mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long)
53 | mels_sliced_lengths = torch.tensor([mel_sliced.size(-1) for mel_sliced in mels_sliced], dtype=torch.long)
54 |
55 | # pad to the same length
56 | texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0)
57 | mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0)
58 | mels_sliced_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels_sliced), padding=0)
59 |
60 | return texts_padded, text_lengths, mels_padded, mel_lengths, mels_sliced_padded, mels_sliced_lengths
61 |
62 | # random slice mel for reference encoder to prevent overfitting
63 | def random_slice_tensor(x: torch.Tensor):
64 | length = x.size(-1)
65 | if length < 12:
66 | return x
67 |     segment_size = random.randint(length // 12, length // 3)
68 |     start = random.randint(0, length - segment_size)
69 |     return x[..., start : start + segment_size]
--------------------------------------------------------------------------------
/datas/sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py
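# Typical usage (sketch; boundary values are hypothetical, collate_fn comes from datas.dataset):
#   sampler = DistributedBucketSampler(dataset, batch_size=32, boundaries=[32, 300, 400, 500], num_replicas=1, rank=0)
#   loader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler, collate_fn=collate_fn)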
4 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
5 | """
6 | Maintain similar input lengths in a batch.
7 | Length groups are specified by boundaries.
8 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.
9 |
10 | It removes samples which are not included in the boundaries.
11 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.
12 | """
13 |
14 | def __init__(
15 | self,
16 | dataset,
17 | batch_size,
18 | boundaries,
19 | num_replicas=None,
20 | rank=None,
21 | shuffle=True,
22 | ):
23 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
24 | self.lengths = dataset.lengths
25 | self.batch_size = batch_size
26 | self.boundaries = boundaries
27 |
28 | self.buckets, self.num_samples_per_bucket = self._create_buckets()
29 | self.total_size = sum(self.num_samples_per_bucket)
30 | self.num_samples = self.total_size // self.num_replicas
31 |
32 | def _create_buckets(self):
33 | buckets = [[] for _ in range(len(self.boundaries) - 1)]
34 | for i in range(len(self.lengths)):
35 | length = self.lengths[i]
36 | idx_bucket = self._bisect(length)
37 | if idx_bucket != -1:
38 | buckets[idx_bucket].append(i)
39 |
40 | # from https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/data_utils.py
41 | # avoid "integer division or modulo by zero" error for very small dataset
42 | # see https://github.com/Plachtaa/VITS-fast-fine-tuning/pull/228 for more details
43 | try:
44 | for i in range(len(buckets) - 1, 0, -1):
45 | if len(buckets[i]) == 0:
46 | buckets.pop(i)
47 | self.boundaries.pop(i + 1)
48 | assert all(len(bucket) > 0 for bucket in buckets)
49 | # When one bucket is not traversed
50 | except Exception as e:
51 | print('Bucket warning ', e)
52 | for i in range(len(buckets) - 1, -1, -1):
53 | if len(buckets[i]) == 0:
54 | buckets.pop(i)
55 | self.boundaries.pop(i + 1)
56 |
57 | num_samples_per_bucket = []
58 | for i in range(len(buckets)):
59 | len_bucket = len(buckets[i])
60 | total_batch_size = self.num_replicas * self.batch_size
61 | rem = (
62 | total_batch_size - (len_bucket % total_batch_size)
63 | ) % total_batch_size
64 | num_samples_per_bucket.append(len_bucket + rem)
65 | return buckets, num_samples_per_bucket
66 |
67 | def __iter__(self):
68 | # deterministically shuffle based on epoch
69 | g = torch.Generator()
70 | g.manual_seed(self.epoch)
71 |
72 | indices = []
73 | if self.shuffle:
74 | for bucket in self.buckets:
75 | indices.append(torch.randperm(len(bucket), generator=g).tolist())
76 | else:
77 | for bucket in self.buckets:
78 | indices.append(list(range(len(bucket))))
79 |
80 | batches = []
81 | for i in range(len(self.buckets)):
82 | bucket = self.buckets[i]
83 | len_bucket = len(bucket)
84 | ids_bucket = indices[i]
85 | num_samples_bucket = self.num_samples_per_bucket[i]
86 |
87 | # add extra samples to make it evenly divisible
88 | rem = num_samples_bucket - len_bucket
89 | ids_bucket = (
90 | ids_bucket
91 | + ids_bucket * (rem // len_bucket)
92 | + ids_bucket[: (rem % len_bucket)]
93 | )
94 |
95 | # subsample
96 | ids_bucket = ids_bucket[self.rank :: self.num_replicas]
97 |
98 | # batching
99 | for j in range(len(ids_bucket) // self.batch_size):
100 | batch = [
101 | bucket[idx]
102 | for idx in ids_bucket[
103 | j * self.batch_size : (j + 1) * self.batch_size
104 | ]
105 | ]
106 | batches.append(batch)
107 |
108 | if self.shuffle:
109 | batch_ids = torch.randperm(len(batches), generator=g).tolist()
110 | batches = [batches[i] for i in batch_ids]
111 | self.batches = batches
112 |
113 | assert len(self.batches) * self.batch_size == self.num_samples
114 | return iter(self.batches)
115 |
116 | def _bisect(self, x, lo=0, hi=None):
117 | if hi is None:
118 | hi = len(self.boundaries) - 1
119 |
120 | if hi > lo:
121 | mid = (hi + lo) // 2
122 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
123 | return mid
124 | elif x <= self.boundaries[mid]:
125 | return self._bisect(x, lo, mid)
126 | else:
127 | return self._bisect(x, mid + 1, hi)
128 | else:
129 | return -1
130 |
131 | def __len__(self):
132 | return self.num_samples // self.batch_size
--------------------------------------------------------------------------------
/figures/structure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/figures/structure.jpg
--------------------------------------------------------------------------------
/filelists/example.txt:
--------------------------------------------------------------------------------
1 | ./audio1.wav|你好,世界。
2 | ./audio2.wav|Hello, world.
--------------------------------------------------------------------------------
/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "from IPython.display import Audio, display\n",
10 | "import torch\n",
11 | "\n",
12 | "from api import StableTTSAPI\n",
13 | "\n",
14 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
15 | "\n",
16 | "tts_model_path = './checkpoints/checkpoint_0.pt' # path to StableTTS checkpoint\n",
17 | "vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt' # path to vocoder checkpoint\n",
18 | "vocoder_type = 'ffgan' # ffgan or vocos\n",
19 | "\n",
20 | "# vocoder_model_path = './vocoders/pretrained/vocos.pt'\n",
21 | "# vocoder_type = 'vocos'\n",
22 | "\n",
23 | "model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type)\n",
24 | "model.to(device)\n",
25 | "\n",
26 | "tts_param, vocoder_param = model.get_params()\n",
27 | "print(f'tts_param: {tts_param}, vocoder_param: {vocoder_param}')"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": null,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "text = '你指尖跳动的电光,是我永恒不变的信仰。唯我超电磁炮永世长存!'\n",
37 | "ref_audio = './audio_1.wav'\n",
38 | "language = 'chinese' # support chinese, japanese and english\n",
39 | "solver = 'dopri5' # recommend using euler, midpoint or dopri5\n",
40 | "steps = 30\n",
41 | "cfg = 3 # recommend 1-4\n",
42 | "\n",
43 | "audio_output, mel_output = model.inference(text, ref_audio, language, steps, 1, 1, solver, cfg)\n",
44 | "\n",
45 | "display(Audio(ref_audio))\n",
46 | "display(Audio(audio_output, rate=model.mel_config.sample_rate))"
47 | ]
48 | }
49 | ],
50 | "metadata": {
51 | "kernelspec": {
52 | "display_name": "lxn_vits",
53 | "language": "python",
54 | "name": "python3"
55 | },
56 | "language_info": {
57 | "codemirror_mode": {
58 | "name": "ipython",
59 | "version": 3
60 | },
61 | "file_extension": ".py",
62 | "mimetype": "text/x-python",
63 | "name": "python",
64 | "nbconvert_exporter": "python",
65 | "pygments_lexer": "ipython3",
66 | "version": "3.11.8"
67 | }
68 | },
69 | "nbformat": 4,
70 | "nbformat_minor": 2
71 | }
72 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/models/__init__.py
--------------------------------------------------------------------------------
/models/diffusion_transformer.py:
--------------------------------------------------------------------------------
1 | # References:
2 | # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py
3 | # https://github.com/jaywalnut310/vits/blob/main/attentions.py
4 | # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | class FFN(nn.Module):
11 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
12 | super().__init__()
13 | self.in_channels = in_channels
14 | self.out_channels = out_channels
15 | self.filter_channels = filter_channels
16 | self.kernel_size = kernel_size
17 | self.p_dropout = p_dropout
18 | self.gin_channels = gin_channels
19 |
20 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
21 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2)
22 | self.drop = nn.Dropout(p_dropout)
23 | self.act1 = nn.SiLU(inplace=True)
24 |
25 | def forward(self, x, x_mask):
26 | x = self.conv_1(x * x_mask)
27 | x = self.act1(x)
28 | x = self.drop(x)
29 | x = self.conv_2(x * x_mask)
30 | return x * x_mask
31 |
32 | class MultiHeadAttention(nn.Module):
33 | def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
34 | super().__init__()
35 | assert channels % n_heads == 0
36 |
37 | self.channels = channels
38 | self.out_channels = out_channels
39 | self.n_heads = n_heads
40 | self.p_dropout = p_dropout
41 |
42 | self.k_channels = channels // n_heads
43 | self.conv_q = torch.nn.Conv1d(channels, channels, 1)
44 | self.conv_k = torch.nn.Conv1d(channels, channels, 1)
45 | self.conv_v = torch.nn.Conv1d(channels, channels, 1)
46 |
47 | # from https://nn.labml.ai/transformers/rope/index.html
48 | self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
49 | self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5)
50 |
51 | self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
52 | self.drop = torch.nn.Dropout(p_dropout)
53 |
54 | torch.nn.init.xavier_uniform_(self.conv_q.weight)
55 | torch.nn.init.xavier_uniform_(self.conv_k.weight)
56 | torch.nn.init.xavier_uniform_(self.conv_v.weight)
57 |
58 | def forward(self, x, attn_mask=None):
59 | q = self.conv_q(x)
60 | k = self.conv_k(x)
61 | v = self.conv_v(x)
62 |
63 | x = self.attention(q, k, v, mask=attn_mask)
64 |
65 | x = self.conv_o(x)
66 | return x
67 |
68 | def attention(self, query, key, value, mask=None):
69 | b, d, t_s, t_t = (*key.size(), query.size(2))
70 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
71 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
72 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
73 |
74 | query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head]
75 | key = self.key_rotary_pe(key)
76 |
77 | output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0)
78 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
79 | return output
80 |
81 | # modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
82 | class DiTConVBlock(nn.Module):
83 | """
84 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
85 | """
86 | def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
87 | super().__init__()
88 | self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False)
89 | self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout)
90 | self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False)
91 | self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
92 | self.adaLN_modulation = nn.Sequential(
93 | nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(),
94 | nn.SiLU(),
95 | nn.Linear(hidden_channels, 6 * hidden_channels, bias=True)
96 | )
97 |
98 | def forward(self, x, c, x_mask):
99 | """
100 | Args:
101 | x : [batch_size, channel, time]
102 | c : [batch_size, channel]
103 | x_mask : [batch_size, 1, time]
104 | return the same shape as x
105 | """
106 | x = x * x_mask
107 | attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time]
108 | attn_mask = torch.zeros_like(attn_mask).masked_fill(attn_mask == 0, -torch.finfo(x.dtype).max)
109 |
110 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1]
111 | x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask
112 | x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask)
113 |
114 | # no condition version
115 | # x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask)
116 | # x = x + self.mlp(self.norm2(x.transpose(1,2)).transpose(1,2), x_mask)
117 | return x
118 |
119 | @staticmethod
120 | def modulate(x, shift, scale):
121 | return x * (1 + scale) + shift
122 |
123 | class RotaryPositionalEmbeddings(nn.Module):
124 | """
125 | ## RoPE module
126 |
127 | Rotary encoding transforms pairs of features by rotating in the 2D plane.
128 | That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
129 | Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
130 | by an angle depending on the position of the token.
131 | """
132 |
133 | def __init__(self, d: int, base: int = 10_000):
134 | r"""
135 | * `d` is the number of features $d$
136 | * `base` is the constant used for calculating $\Theta$
137 | """
138 | super().__init__()
139 |
140 | self.base = base
141 | self.d = int(d)
142 | self.cos_cached = None
143 | self.sin_cached = None
144 |
145 | def _build_cache(self, x: torch.Tensor):
146 | r"""
147 | Cache $\cos$ and $\sin$ values
148 | """
149 | # Return if cache is already built
150 | if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]:
151 | return
152 |
153 | # Get sequence length
154 | seq_len = x.shape[0]
155 |
156 | # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
157 | theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
158 |
159 | # Create position indexes `[0, 1, ..., seq_len - 1]`
160 | seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device)
161 |
162 | # Calculate the product of position index and $\theta_i$
163 | idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
164 |
165 | # Concatenate so that for row $m$ we have
166 | # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
167 | idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
168 |
169 | # Cache them
170 | self.cos_cached = idx_theta2.cos()[:, None, None, :]
171 | self.sin_cached = idx_theta2.sin()[:, None, None, :]
172 |
173 | def _neg_half(self, x: torch.Tensor):
174 | # $\frac{d}{2}$
175 | d_2 = self.d // 2
176 |
177 | # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
178 | return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)
179 |
180 | def forward(self, x: torch.Tensor):
181 | """
182 | * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]`
183 | """
184 | # Cache $\cos$ and $\sin$ values
185 | x = x.permute(2, 0, 1, 3) # b h t d -> t b h d
186 |
187 | self._build_cache(x)
188 |
189 | # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
190 | x_rope, x_pass = x[..., : self.d], x[..., self.d :]
191 |
192 | # Calculate
193 | # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
194 | neg_half_x = self._neg_half(x_rope)
195 |
196 | x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]])
197 |
198 | return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d
199 |
200 | class Transpose(nn.Identity):
201 | """(N, T, D) -> (N, D, T)"""
202 |
203 | def forward(self, input: torch.Tensor) -> torch.Tensor:
204 | return input.transpose(1, 2)
205 |
206 |
--------------------------------------------------------------------------------
/models/duration_predictor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98
5 | class DurationPredictor(nn.Module):
6 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
7 | super().__init__()
8 |
9 | self.in_channels = in_channels
10 | self.filter_channels = filter_channels
11 | self.kernel_size = kernel_size
12 | self.p_dropout = p_dropout
13 | self.gin_channels = gin_channels
14 |
15 | self.drop = nn.Dropout(p_dropout)
16 | self.conv1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2)
17 | self.norm1 = nn.LayerNorm(filter_channels)
18 | self.conv2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2)
19 | self.norm2 = nn.LayerNorm(filter_channels)
20 | self.proj = nn.Conv1d(filter_channels, 1, 1)
21 |
22 | self.cond = nn.Conv1d(gin_channels, in_channels, 1)
23 |
24 | def forward(self, x, x_mask, g):
25 | x = x.detach()
26 | x = x + self.cond(g.unsqueeze(2).detach())
27 | x = self.conv1(x * x_mask)
28 | x = torch.relu(x)
29 | x = self.norm1(x.transpose(1,2)).transpose(1,2)
30 | x = self.drop(x)
31 | x = self.conv2(x * x_mask)
32 | x = torch.relu(x)
33 | x = self.norm2(x.transpose(1,2)).transpose(1,2)
34 | x = self.drop(x)
35 | x = self.proj(x * x_mask)
36 | return x * x_mask
37 |
38 | def duration_loss(logw, logw_, lengths):
39 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
40 | return loss
--------------------------------------------------------------------------------
/models/estimator.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from models.diffusion_transformer import DiTConVBlock
7 |
8 | class DitWrapper(nn.Module):
9 | """ add FiLM layer to condition time embedding to DiT """
10 | def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0):
11 | super().__init__()
12 | self.time_fusion = FiLMLayer(hidden_channels, time_channels)
13 | self.block = DiTConVBlock(hidden_channels, filter_channels, num_heads, kernel_size, p_dropout, gin_channels)
14 |
15 | def forward(self, x, c, t, x_mask):
16 | x = self.time_fusion(x, t) * x_mask
17 | x = self.block(x, c, x_mask)
18 | return x
19 |
20 | class FiLMLayer(nn.Module):
21 | """
22 | Feature-wise Linear Modulation (FiLM) layer
23 | Reference: https://arxiv.org/abs/1709.07871
24 | """
25 | def __init__(self, in_channels, cond_channels):
26 |
27 | super(FiLMLayer, self).__init__()
28 | self.in_channels = in_channels
29 | self.film = nn.Conv1d(cond_channels, in_channels * 2, 1)
30 |
31 | def forward(self, x, c):
32 | gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1)
33 | return gamma * x + beta
34 |
35 | class SinusoidalPosEmb(nn.Module):
36 | def __init__(self, dim):
37 | super().__init__()
38 | self.dim = dim
39 | assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
40 |
41 | def forward(self, x, scale=1000):
42 | if x.ndim < 1:
43 | x = x.unsqueeze(0)
44 | half_dim = self.dim // 2
45 | emb = math.log(10000) / (half_dim - 1)
46 | emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb)
47 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
48 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
49 | return emb
50 |
51 | class TimestepEmbedding(nn.Module):
52 | def __init__(self, in_channels, out_channels, filter_channels):
53 | super().__init__()
54 |
55 | self.layer = nn.Sequential(
56 | nn.Linear(in_channels, filter_channels),
57 | nn.SiLU(inplace=True),
58 | nn.Linear(filter_channels, out_channels)
59 | )
60 |
61 | def forward(self, x):
62 | return self.layer(x)
63 |
64 | # reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py
65 | class Decoder(nn.Module):
66 | def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, dropout=0.1, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0, use_lsc=True):
67 | super().__init__()
68 | self.noise_channels = noise_channels
69 | self.cond_channels = cond_channels
70 | self.hidden_channels = hidden_channels
71 | self.out_channels = out_channels
72 | self.filter_channels = filter_channels
73 | self.use_lsc = use_lsc # whether to use unet-like long skip connection
74 |
75 | self.time_embeddings = SinusoidalPosEmb(hidden_channels)
76 | self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels)
77 |
78 | self.in_proj = nn.Conv1d(hidden_channels + noise_channels, hidden_channels, 1) # cat noise and encoder output as input
79 | self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)])
80 | self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1)
81 |
82 | # prenet for encoder output
83 | self.cond_proj = nn.Sequential(
84 | nn.Conv1d(cond_channels, filter_channels, kernel_size, padding=kernel_size//2),
85 | nn.SiLU(inplace=True),
86 | nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2), # add about 3M params
87 | nn.SiLU(inplace=True),
88 | nn.Conv1d(filter_channels, hidden_channels, kernel_size, padding=kernel_size//2)
89 | )
90 |
91 | if use_lsc:
92 | assert n_layers % 2 == 0
93 | self.n_lsc_layers = n_layers // 2
94 | self.lsc_layers = nn.ModuleList([nn.Conv1d(hidden_channels + hidden_channels, hidden_channels, kernel_size, padding = kernel_size // 2) for _ in range(self.n_lsc_layers)])
95 |
96 | self.initialize_weights()
97 |
98 | def initialize_weights(self):
99 | for block in self.blocks:
100 | nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0)
101 | nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0)
102 |
103 | def forward(self, t, x, mask, mu, c):
104 | """Forward pass of the DiT model.
105 |
106 | Args:
107 | t (torch.Tensor): timestep, shape (batch_size)
108 | x (torch.Tensor): noise, shape (batch_size, in_channels, time)
109 | mask (torch.Tensor): shape (batch_size, 1, time)
110 | mu (torch.Tensor): output of encoder, shape (batch_size, in_channels, time)
111 | c (torch.Tensor): shape (batch_size, gin_channels)
112 |
113 | Returns:
114 |             torch.Tensor: the estimated vector field, shape (batch_size, out_channels, time)
115 | """
116 |
117 | t = self.time_mlp(self.time_embeddings(t))
118 | mu = self.cond_proj(mu)
119 |
120 | x = torch.cat((x, mu), dim=1)
121 | x = self.in_proj(x)
122 |
123 | lsc_outputs = [] if self.use_lsc else None
124 |
125 | for idx, block in enumerate(self.blocks):
126 | # add long skip connection, see https://arxiv.org/pdf/2209.12152 for more details
127 | if self.use_lsc:
128 | if idx < self.n_lsc_layers:
129 | lsc_outputs.append(x)
130 | else:
131 | x = torch.cat((x, lsc_outputs.pop()), dim=1)
132 | x = self.lsc_layers[idx - self.n_lsc_layers](x)
133 |
134 | x = block(x, c, t, mask)
135 |
136 | output = self.final_proj(x * mask)
137 |
138 | return output * mask
--------------------------------------------------------------------------------
/models/flow_matching.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import functools
6 | from torchdiffeq import odeint
7 |
8 | from models.estimator import Decoder
9 |
10 | # modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py
11 | class CFMDecoder(torch.nn.Module):
12 | def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
13 | super().__init__()
14 | self.noise_channels = noise_channels
15 | self.cond_channels = cond_channels
16 | self.hidden_channels = hidden_channels
17 | self.out_channels = out_channels
18 | self.filter_channels = filter_channels
19 | self.gin_channels = gin_channels
20 | self.sigma_min = 1e-4
21 |
22 | self.estimator = Decoder(noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels)
23 |
24 | @torch.inference_mode()
25 | def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None, solver=None, cfg_kwargs=None):
26 |         """Flow-matching inference: solve the ODE from noise to a mel-spectrogram
27 |
28 | Args:
29 | mu (torch.Tensor): output of encoder
30 | shape: (batch_size, n_feats, mel_timesteps)
31 | mask (torch.Tensor): output_mask
32 | shape: (batch_size, 1, mel_timesteps)
33 | n_timesteps (int): number of diffusion steps
34 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
35 | c (torch.Tensor, optional): speaker embedding
36 | shape: (batch_size, gin_channels)
37 | solver: see https://github.com/rtqichen/torchdiffeq for supported solvers
38 | cfg_kwargs: used for cfg inference
39 |
40 | Returns:
41 | sample: generated mel-spectrogram
42 | shape: (batch_size, n_feats, mel_timesteps)
43 | """
44 |
45 | z = torch.randn_like(mu) * temperature
46 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
47 |
48 | # cfg control
49 | if cfg_kwargs is None:
50 | estimator = functools.partial(self.estimator, mask=mask, mu=mu, c=c)
51 | else:
52 | estimator = functools.partial(self.cfg_wrapper, mask=mask, mu=mu, c=c, cfg_kwargs=cfg_kwargs)
53 |
54 | trajectory = odeint(estimator, z, t_span, method=solver, rtol=1e-5, atol=1e-5)
55 | return trajectory[-1]
56 |
57 | # cfg inference
58 | def cfg_wrapper(self, t, x, mask, mu, c, cfg_kwargs):
59 | fake_speaker = cfg_kwargs['fake_speaker'].repeat(x.size(0), 1)
60 | fake_content = cfg_kwargs['fake_content'].repeat(x.size(0), 1, x.size(-1))
61 | cfg_strength = cfg_kwargs['cfg_strength']
62 |
63 | cond_output = self.estimator(t, x, mask, mu, c)
64 | uncond_output = self.estimator(t, x, mask, fake_content, fake_speaker)
65 |
66 | output = uncond_output + cfg_strength * (cond_output - uncond_output)
67 | return output
68 |
69 | def compute_loss(self, x1, mask, mu, c):
70 |         """Computes the conditional flow matching loss
71 |
72 | Args:
73 | x1 (torch.Tensor): Target
74 | shape: (batch_size, n_feats, mel_timesteps)
75 | mask (torch.Tensor): target mask
76 | shape: (batch_size, 1, mel_timesteps)
77 | mu (torch.Tensor): output of encoder
78 | shape: (batch_size, n_feats, mel_timesteps)
79 | c (torch.Tensor, optional): speaker condition.
80 |
81 | Returns:
82 | loss: conditional flow matching loss
83 | y: conditional flow
84 | shape: (batch_size, n_feats, mel_timesteps)
85 | """
86 | b, _, t = mu.shape
87 |
88 | # random timestep
89 | # use cosine timestep scheduler from cosyvoice: https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/flow/flow_matching.py
90 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
91 | t = 1 - torch.cos(t * 0.5 * torch.pi)
92 |
93 | # sample noise p(x_0)
94 | z = torch.randn_like(x1)
95 |
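        # conditional flow: y interpolates from the noise z (t=0) to the target x1 (t=1);
        # its time derivative u = x1 - (1 - sigma_min) * z is the velocity the estimator is trained to predict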
96 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1
97 | u = x1 - (1 - self.sigma_min) * z
98 |
99 | loss = F.mse_loss(self.estimator(t.squeeze(), y, mask, mu, c), u, reduction="sum") / (torch.sum(mask) * u.size(1))
100 | return loss, y
101 |
--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 |
5 | import monotonic_align
6 | from models.text_encoder import TextEncoder
7 | from models.flow_matching import CFMDecoder
8 | from models.reference_encoder import MelStyleEncoder
9 | from models.duration_predictor import DurationPredictor, duration_loss
10 | from utils.mask import sequence_mask
11 |
12 | def convert_pad_shape(pad_shape):
13 | inverted_shape = pad_shape[::-1]
14 | pad_shape = [item for sublist in inverted_shape for item in sublist]
15 | return pad_shape
16 |
17 | def generate_path(duration, mask):
18 | b, t_x, t_y = mask.shape
19 | cum_duration = torch.cumsum(duration, 1)
20 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype, device=duration.device)
21 |
22 | cum_duration_flat = cum_duration.view(b * t_x)
23 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
24 | path = path.view(b, t_x, t_y)
25 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
26 | path = path * mask
27 | return path
28 |
29 | # modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
30 | class StableTTS(nn.Module):
31 | def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
32 | super().__init__()
33 |
34 | self.n_vocab = n_vocab
35 | self.mel_channels = mel_channels
36 |
37 | self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
38 | self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=5, dropout=0.25)
39 | self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, 0.5, gin_channels)
40 | self.decoder = CFMDecoder(mel_channels, mel_channels, hidden_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)
41 |
42 | # uncondition input for cfg
43 | self.fake_speaker = nn.Parameter(torch.zeros(1, gin_channels))
44 | self.fake_content = nn.Parameter(torch.zeros(1, mel_channels, 1))
45 |
46 | self.cfg_dropout = 0.2
47 |
48 | @torch.inference_mode()
49 | def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0, solver=None, cfg=1.0):
50 | """
51 | Generates mel-spectrogram from text. Returns:
52 | 1. encoder outputs
53 | 2. decoder outputs
54 | 3. generated alignment
55 |
56 | Args:
57 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
58 | shape: (batch_size, max_text_length)
59 | x_lengths (torch.Tensor): lengths of texts in batch.
60 | shape: (batch_size,)
61 | n_timesteps (int): number of steps to use for reverse diffusion in decoder.
62 | temperature (float, optional): controls variance of terminal distribution.
63 | y (torch.Tensor): mel spectrogram of reference audio
64 | shape: (batch_size, mel_channels, time)
65 | length_scale (float, optional): controls speech pace.
66 | Increase value to slow down generated speech and vice versa.
67 |
68 | Returns:
69 | dict: {
70 | "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
71 | # Average mel spectrogram generated by the encoder
72 | "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
73 | # Refined mel spectrogram improved by the CFM
74 | "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
75 | # Alignment map between text and mel spectrogram
76 | """
77 |
78 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
79 | c = self.ref_encoder(y, None)
80 | x, mu_x, x_mask = self.encoder(x, c, x_lengths)
81 | logw = self.dp(x, x_mask, c)
82 |
83 | w = torch.exp(logw) * x_mask
84 | w_ceil = torch.ceil(w) * length_scale
85 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
86 | y_max_length = y_lengths.max()
87 |
88 | # Using obtained durations `w` construct alignment map `attn`
89 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
90 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
91 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
92 |
93 | # Align encoded text and get mu_y
94 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
95 | mu_y = mu_y.transpose(1, 2)
96 | encoder_outputs = mu_y[:, :, :y_max_length]
97 |
98 | # Generate sample tracing the probability flow
99 | if cfg == 1.0:
100 | decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver)
101 | else:
102 | cfg_kwargs = {'fake_speaker': self.fake_speaker, 'fake_content': self.fake_content, 'cfg_strength': cfg}
103 | decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver, cfg_kwargs)
104 |
105 | decoder_outputs = decoder_outputs[:, :, :y_max_length]
106 |
107 |
108 | return {
109 | "encoder_outputs": encoder_outputs,
110 | "decoder_outputs": decoder_outputs,
111 | "attn": attn[:, :, :y_max_length],
112 | }
113 |
114 | def forward(self, x, x_lengths, y, y_lengths, z, z_lengths):
115 | """
116 | Computes 3 losses:
117 |         1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
118 | 2. prior loss: loss between mel-spectrogram and encoder outputs.
119 | 3. flow matching loss: loss between mel-spectrogram and decoder outputs.
120 |
121 | Args:
122 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
123 | shape: (batch_size, max_text_length)
124 | x_lengths (torch.Tensor): lengths of texts in batch.
125 | shape: (batch_size,)
126 | y (torch.Tensor): batch of corresponding mel-spectrograms.
127 | shape: (batch_size, n_feats, max_mel_length)
128 | y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
129 | shape: (batch_size,)
130 | z (torch.Tensor): batch of sliced mel-spectrograms (reference segments for the speaker encoder).
131 | shape: (batch_size, n_feats, max_mel_length)
132 | z_lengths (torch.Tensor): lengths of sliced mel-spectrograms in batch.
133 | shape: (batch_size,)
134 | """
135 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw`
136 | y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
137 | z_mask = sequence_mask(z_lengths, z.size(2)).unsqueeze(1).to(z.dtype)
138 | cfg_mask = torch.rand(y.size(0), 1, device=y.device) > self.cfg_dropout
139 |
140 | # compute global speaker embedding
141 | c = self.ref_encoder(z, z_mask) * cfg_mask + ~cfg_mask * self.fake_speaker.repeat(z.size(0), 1)
142 |
143 | x, mu_x, x_mask = self.encoder(x, c, x_lengths)
144 | logw = self.dp(x, x_mask, c)
145 |
146 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
147 |
148 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
149 | with torch.no_grad():
150 | s_p_sq_r = torch.ones_like(mu_x) # [b, d, t]
151 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi)- torch.zeros_like(mu_x), [1], keepdim=True)
152 | neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
153 | neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
154 | neg_cent4 = torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True)
155 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
156 |
157 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
158 | attn = (monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach())
159 |
160 | # Compute loss between predicted log-scaled durations and those obtained from MAS
161 | # referred to as the duration loss (the prior loss is computed separately below)
162 | logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
163 | dur_loss = duration_loss(logw, logw_, x_lengths)
164 |
165 | # Align encoded text with mel-spectrogram and get mu_y segment
166 | attn = attn.squeeze(1).transpose(1,2)
167 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
168 | mu_y = mu_y.transpose(1, 2)
169 |
170 | # Compute loss of the decoder
171 | cfg_mask = cfg_mask.unsqueeze(-1)
172 | mu_y_masked = mu_y * cfg_mask + ~cfg_mask * self.fake_content.repeat(mu_y.size(0), 1, mu_y.size(-1)) # mask content information for the CFG-dropout samples so the flow-matching decoder also learns an unconditional distribution
173 | diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y_masked, c)
174 |
175 | prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
176 | prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)
177 |
178 | return dur_loss, diff_loss, prior_loss, attn
--------------------------------------------------------------------------------
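The `fake_speaker` / `fake_content` parameters and `cfg_dropout` above implement classifier-free guidance: during training roughly 20% of samples have their speaker embedding and aligned content `mu_y` replaced by learned unconditional placeholders, and at inference `cfg_kwargs` hands the same placeholders plus `cfg_strength` to the CFM decoder. Below is a minimal sketch of how such a guidance strength is typically combined with the conditional and unconditional predictions; the estimator call signature here is an assumption for illustration, not the exact interface of flow_matching.py:

    # Illustrative sketch only: the estimator signature is assumed, not taken from flow_matching.py.
    def guided_velocity(estimator, x_t, t, mask, mu_y, c, fake_content, fake_speaker, cfg_strength):
        b = x_t.size(0)
        # conditional prediction: real aligned content and speaker embedding
        v_cond = estimator(x_t, mask, mu_y, t, c)
        # unconditional prediction: the learned placeholders seen during CFG dropout
        v_uncond = estimator(x_t, mask, fake_content.repeat(b, 1, mu_y.size(-1)), t, fake_speaker.repeat(b, 1))
        # cfg_strength == 1.0 recovers the purely conditional velocity
        return v_uncond + cfg_strength * (v_cond - v_uncond)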
/models/reference_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Conv1dGLU(nn.Module):
5 | """
6 | Conv1d + GLU(Gated Linear Unit) with residual connection.
7 | For GLU refer to https://arxiv.org/abs/1612.08083 paper.
8 | """
9 |
10 | def __init__(self, in_channels, out_channels, kernel_size, dropout):
11 | super(Conv1dGLU, self).__init__()
12 | self.out_channels = out_channels
13 | self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
14 | self.dropout = nn.Dropout(dropout)
15 |
16 | def forward(self, x):
17 | residual = x
18 | x = self.conv1(x)
19 | x1, x2 = torch.split(x, self.out_channels, dim=1)
20 | x = x1 * torch.sigmoid(x2)
21 | x = residual + self.dropout(x)
22 | return x
23 |
24 | # modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766
25 | class MelStyleEncoder(nn.Module):
26 | """MelStyleEncoder"""
27 |
28 | def __init__(
29 | self,
30 | n_mel_channels=80,
31 | style_hidden=128,
32 | style_vector_dim=256,
33 | style_kernel_size=5,
34 | style_head=2,
35 | dropout=0.1,
36 | ):
37 | super(MelStyleEncoder, self).__init__()
38 | self.in_dim = n_mel_channels
39 | self.hidden_dim = style_hidden
40 | self.out_dim = style_vector_dim
41 | self.kernel_size = style_kernel_size
42 | self.n_head = style_head
43 | self.dropout = dropout
44 |
45 | self.spectral = nn.Sequential(
46 | nn.Linear(self.in_dim, self.hidden_dim),
47 | nn.Mish(inplace=True),
48 | nn.Dropout(self.dropout),
49 | nn.Linear(self.hidden_dim, self.hidden_dim),
50 | nn.Mish(inplace=True),
51 | nn.Dropout(self.dropout),
52 | )
53 |
54 | self.temporal = nn.Sequential(
55 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
56 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
57 | )
58 |
59 | self.slf_attn = nn.MultiheadAttention(
60 | self.hidden_dim,
61 | self.n_head,
62 | self.dropout,
63 | batch_first=True
64 | )
65 |
66 | self.fc = nn.Linear(self.hidden_dim, self.out_dim)
67 |
68 | def temporal_avg_pool(self, x, mask=None):
69 | if mask is None:
70 | return torch.mean(x, dim=1)
71 | else:
72 | return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1)
73 |
74 | def forward(self, x, x_mask=None):
75 | x = x.transpose(1, 2)
76 |
77 | # spectral
78 | x = self.spectral(x)
79 | # temporal
80 | x = x.transpose(1, 2)
81 | x = self.temporal(x)
82 | x = x.transpose(1, 2)
83 | # self-attention
84 | if x_mask is not None:
85 | x_mask = ~x_mask.squeeze(1).to(torch.bool)
86 | x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask, need_weights=False)
87 | # fc
88 | x = self.fc(x)
89 | # temporal average pooling
90 | w = self.temporal_avg_pool(x, mask=x_mask)
91 |
92 | return w
93 |
94 | # Attention Pool version of MelStyleEncoder, not used
95 | class AttnMelStyleEncoder(nn.Module):
96 | """MelStyleEncoder"""
97 |
98 | def __init__(
99 | self,
100 | n_mel_channels=80,
101 | style_hidden=128,
102 | style_vector_dim=256,
103 | style_kernel_size=5,
104 | style_head=2,
105 | dropout=0.1,
106 | ):
107 | super().__init__()
108 | self.in_dim = n_mel_channels
109 | self.hidden_dim = style_hidden
110 | self.out_dim = style_vector_dim
111 | self.kernel_size = style_kernel_size
112 | self.n_head = style_head
113 | self.dropout = dropout
114 |
115 | self.spectral = nn.Sequential(
116 | nn.Linear(self.in_dim, self.hidden_dim),
117 | nn.Mish(inplace=True),
118 | nn.Dropout(self.dropout),
119 | nn.Linear(self.hidden_dim, self.hidden_dim),
120 | nn.Mish(inplace=True),
121 | nn.Dropout(self.dropout),
122 | )
123 |
124 | self.temporal = nn.Sequential(
125 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
126 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout),
127 | )
128 |
129 | self.slf_attn = nn.MultiheadAttention(
130 | self.hidden_dim,
131 | self.n_head,
132 | self.dropout,
133 | batch_first=True
134 | )
135 |
136 | self.fc = nn.Linear(self.hidden_dim, self.out_dim)
137 |
138 | def temporal_avg_pool(self, x, mask=None):
139 | if mask is None:
140 | return torch.mean(x, dim=1)
141 | else:
142 | return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1)
143 |
144 | def forward(self, x, x_mask=None):
145 | x = x.transpose(1, 2)
146 |
147 | # spectral
148 | x = self.spectral(x)
149 | # temporal
150 | x = x.transpose(1, 2)
151 | x = self.temporal(x)
152 | x = x.transpose(1, 2)
153 | # self-attention
154 | if x_mask is not None:
155 | x_mask = ~x_mask.squeeze(1).to(torch.bool)
156 | zeros = torch.zeros(x_mask.size(0), 1, device=x_mask.device, dtype=x_mask.dtype)
157 | x_attn_mask = torch.cat((zeros, x_mask), dim=1)
158 | else:
159 | x_attn_mask = None
160 |
161 | avg = self.temporal_avg_pool(x, x_mask).unsqueeze(1)
162 | x = torch.cat([avg, x], dim=1)
163 | x, _ = self.slf_attn(x, x, x, key_padding_mask=x_attn_mask, need_weights=False)
164 | x = x[:, 0, :]
165 | # fc
166 | x = self.fc(x)
167 |
168 | return x
--------------------------------------------------------------------------------
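MelStyleEncoder condenses a variable-length reference mel-spectrogram into one fixed-size style vector: per-frame linear layers, two Conv1dGLU blocks over time, self-attention, and masked average pooling. A minimal usage sketch; the shapes follow the forward pass above, and the hyperparameters shown are just the constructor defaults, not necessarily the values used in config.py:

    # Usage sketch; constructor defaults assumed, not the project's actual config values.
    import torch
    from models.reference_encoder import MelStyleEncoder

    encoder = MelStyleEncoder(n_mel_channels=80, style_vector_dim=256)
    mel = torch.randn(2, 80, 180)       # (batch, n_mel_channels, frames)
    style = encoder(mel, None)          # no mask -> plain mean pooling over time
    print(style.shape)                  # torch.Size([2, 256])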
/models/text_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.diffusion_transformer import DiTConVBlock
5 | from utils.mask import sequence_mask
6 |
7 | # modified from https://github.com/jaywalnut310/vits/blob/main/models.py
8 | class TextEncoder(nn.Module):
9 | def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels):
10 | super().__init__()
11 | self.n_vocab = n_vocab
12 | self.out_channels = out_channels
13 | self.hidden_channels = hidden_channels
14 | self.filter_channels = filter_channels
15 | self.n_heads = n_heads
16 | self.n_layers = n_layers
17 | self.kernel_size = kernel_size
18 | self.p_dropout = p_dropout
19 |
20 | self.scale = self.hidden_channels ** 0.5
21 |
22 | self.emb = nn.Embedding(n_vocab, hidden_channels)
23 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
24 |
25 | self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)])
26 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
27 |
28 | self.initialize_weights()
29 |
30 | def initialize_weights(self):
31 | for block in self.encoder:
32 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
33 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
34 |
35 | def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor):
36 | x = self.emb(x) * self.scale # [b, t, h]
37 | x = x.transpose(1, -1) # [b, h, t]
38 | x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype)
39 |
40 | for layer in self.encoder:
41 | x = layer(x, c, x_mask)
42 | mu_x = self.proj(x) * x_mask
43 |
44 | return x, mu_x, x_mask
45 |
--------------------------------------------------------------------------------
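TextEncoder embeds phoneme IDs, refines them with DiT-style blocks whose adaLN modulation is driven by the style vector c, and projects the result to out_channels to obtain the prior mu_x. A shape sketch with placeholder hyperparameters (these are not the values from config.py):

    # Shape sketch; all hyperparameter values below are placeholders.
    import torch
    from models.text_encoder import TextEncoder

    encoder = TextEncoder(n_vocab=100, out_channels=80, hidden_channels=192,
                          filter_channels=768, n_heads=2, n_layers=4,
                          kernel_size=3, p_dropout=0.1, gin_channels=256)
    tokens = torch.randint(0, 100, (2, 17))    # (batch, max_text_length) phoneme ids
    lengths = torch.tensor([17, 12])
    c = torch.randn(2, 256)                    # style vector, e.g. from MelStyleEncoder
    x, mu_x, x_mask = encoder(tokens, c, lengths)
    # x: (2, 192, 17), mu_x: (2, 80, 17), x_mask: (2, 1, 17)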
/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | from numpy import zeros, int32, float32
2 | from torch import from_numpy
3 |
4 | from .core import maximum_path_jit
5 |
6 |
7 | def maximum_path(neg_cent, mask):
8 | device = neg_cent.device
9 | dtype = neg_cent.dtype
10 | neg_cent = neg_cent.data.cpu().numpy().astype(float32)
11 | path = zeros(neg_cent.shape, dtype=int32)
12 |
13 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
14 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
15 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max)
16 | return from_numpy(path).to(device=device, dtype=dtype)
17 |
--------------------------------------------------------------------------------
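maximum_path runs the dynamic-programming core of Monotonic Alignment Search: given pairwise alignment scores and a validity mask, it returns a hard 0/1 path in which every mel frame is assigned to exactly one text token and the token index never decreases over time. A toy call, with shapes matching how model.py invokes it (batch, mel_frames, text_tokens):

    # Toy example; shapes follow the call in model.py: (batch, mel_frames, text_tokens).
    import torch
    from monotonic_align import maximum_path

    neg_cent = torch.randn(1, 6, 3)       # alignment scores (higher = more likely)
    mask = torch.ones(1, 6, 3)            # all positions valid
    path = maximum_path(neg_cent, mask)   # hard monotonic alignment, same shape
    print(path[0])                        # exactly one 1 per mel frame
    print(path.sum(1))                    # frames per token, i.e. the durations used for dur_loss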
/monotonic_align/core.py:
--------------------------------------------------------------------------------
1 | import numba
2 |
3 |
4 | @numba.jit(
5 | numba.void(
6 | numba.int32[:, :, ::1],
7 | numba.float32[:, :, ::1],
8 | numba.int32[::1],
9 | numba.int32[::1],
10 | ),
11 | nopython=True,
12 | nogil=True,
13 | )
14 | def maximum_path_jit(paths, values, t_ys, t_xs):
15 | b = paths.shape[0]
16 | max_neg_val = -1e9
17 | for i in range(int(b)):
18 | path = paths[i]
19 | value = values[i]
20 | t_y = t_ys[i]
21 | t_x = t_xs[i]
22 |
23 | v_prev = v_cur = 0.0
24 | index = t_x - 1
25 |
26 | for y in range(t_y):
27 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
28 | if x == y:
29 | v_cur = max_neg_val
30 | else:
31 | v_cur = value[y - 1, x]
32 | if x == 0:
33 | if y == 0:
34 | v_prev = 0.0
35 | else:
36 | v_prev = max_neg_val
37 | else:
38 | v_prev = value[y - 1, x - 1]
39 | value[y, x] += max(v_prev, v_cur)
40 |
41 | for y in range(t_y - 1, -1, -1):
42 | path[y, index] = 1
43 | if index != 0 and (
44 | index == y or value[y - 1, index] < value[y - 1, index - 1]
45 | ):
46 | index = index - 1
47 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from tqdm import tqdm
4 | from dataclasses import dataclass, asdict
5 |
6 | import torch
7 | from torch.multiprocessing import Pool, set_start_method
8 | import torchaudio
9 |
10 | from config import MelConfig, TrainConfig
11 | from utils.audio import LogMelSpectrogram, load_and_resample_audio
12 |
13 | from text.mandarin import chinese_to_cnm3
14 | from text.english import english_to_ipa2
15 | from text.japanese import japanese_to_ipa2
16 |
17 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
18 |
19 | @dataclass
20 | class DataConfig:
21 | input_filelist_path = './filelists/filelist.txt' # a filelist whose lines are 'audio_path|text'
22 | output_filelist_path = './filelists/filelist.json' # path to save filelist
23 | output_feature_path = './stableTTS_datasets' # path to save resampled audios and mel features
24 | language = 'english' # chinese, japanese or english
25 | resample = False # waveforms are not used in training, so saving the resampled audio is not necessary.
26 |
27 | g2p_mapping = {
28 | 'chinese': chinese_to_cnm3,
29 | 'japanese': japanese_to_ipa2,
30 | 'english': english_to_ipa2,
31 | }
32 |
33 | data_config = DataConfig()
34 | train_config = TrainConfig()
35 | mel_config = MelConfig()
36 |
37 | input_filelist_path = data_config.input_filelist_path
38 | output_filelist_path = data_config.output_filelist_path
39 | output_feature_path = data_config.output_feature_path
40 |
41 | # Ensure output directories exist
42 | output_mel_dir = os.path.join(output_feature_path, 'mels')
43 | os.makedirs(output_mel_dir, exist_ok=True)
44 | os.makedirs(os.path.dirname(output_filelist_path), exist_ok=True)
45 |
46 | if data_config.resample:
47 | output_wav_dir = os.path.join(output_feature_path, 'waves')
48 | os.makedirs(output_wav_dir, exist_ok=True)
49 |
50 | mel_extractor = LogMelSpectrogram(**asdict(mel_config)).to(device)
51 |
52 | g2p = g2p_mapping.get(data_config.language)
53 |
54 | def load_filelist(path) -> list:
55 | file_list = []
56 | with open(path, 'r', encoding='utf-8') as f:
57 | for idx, line in enumerate(f):
58 | audio_path, text = line.strip().split('|', maxsplit=1)
59 | file_list.append((str(idx), audio_path, text))
60 | return file_list
61 |
62 | @torch.inference_mode()
63 | def process_filelist(line) -> str:
64 | idx, audio_path, text = line
65 | audio = load_and_resample_audio(audio_path, mel_config.sample_rate, device=device) # shape: [1, time]
66 | if audio is not None:
67 | # get output path
68 | audio_name, _ = os.path.splitext(os.path.basename(audio_path))
69 |
70 | try:
71 | phone = g2p(text)
72 | if len(phone) > 0:
73 | mel = mel_extractor(audio.to(device)).cpu().squeeze(0) # shape: [n_mels, time // hop_length]
74 | output_mel_path = os.path.join(output_mel_dir, f'{idx}_{audio_name}.pt')
75 | torch.save(mel, output_mel_path)
76 |
77 | if data_config.resample:
78 | audio_path = os.path.join(output_wav_dir, f'{idx}_{audio_name}.wav')
79 | torchaudio.save(audio_path, audio.cpu(), mel_config.sample_rate)
80 | return json.dumps({'mel_path': output_mel_path, 'phone': phone, 'audio_path': audio_path, 'text': text, 'mel_length': mel.size(-1)}, ensure_ascii=False, allow_nan=False)
81 | except Exception as e:
82 | print(f'Error processing {audio_path}: {str(e)}')
83 |
84 |
85 | def main():
86 | set_start_method('spawn') # CUDA must use spawn method
87 | input_filelist = load_filelist(input_filelist_path)
88 | results = []
89 |
90 | with Pool(processes=2) as pool:
91 | for result in tqdm(pool.imap(process_filelist, input_filelist), total=len(input_filelist)):
92 | if result is not None:
93 | results.append(f'{result}\n')
94 |
95 | # save filelist
96 | with open(output_filelist_path, 'w', encoding='utf-8') as f:
97 | f.writelines(results)
98 | print(f"filelist file has been saved to {output_filelist_path}")
99 |
100 | # faster and uses much less CPU
101 | torch.set_num_threads(1)
102 | torch.set_num_interop_threads(1)
103 |
104 | if __name__ == '__main__':
105 | main()
--------------------------------------------------------------------------------
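preprocess.py reads one `audio_path|text` pair per line from the input filelist and writes one JSON object per line to the output filelist. A hypothetical input line and the record it produces (path, phone list and mel_length are made up for illustration):

    /data/wavs/spk1_0001.wav|Hello there, this is an example sentence.

becomes one line of filelist.json roughly like:

    {"mel_path": "./stableTTS_datasets/mels/0_spk1_0001.pt", "phone": ["h", "ə", "..."], "audio_path": "/data/wavs/spk1_0001.wav", "text": "Hello there, this is an example sentence.", "mel_length": 312}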
/recipes/AiSHELL3.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import re
4 | from dataclasses import dataclass
5 | import concurrent.futures
6 |
7 | from tqdm.auto import tqdm
8 |
9 | # download_link: https://www.openslr.org/93/
10 | @dataclass
11 | class DataConfig:
12 | dataset_path = './raw_datasets/Aishell3/train/wav'
13 | txt_path = './raw_datasets/Aishell3/train/content.txt'
14 | output_filelist_path = './filelists/aishell3.txt'
15 |
16 | data_config = DataConfig()
17 |
18 | def process_filelist(line):
19 | dir_name, audio_path, text = line
20 | input_audio_path = os.path.abspath(os.path.join(data_config.dataset_path, dir_name, audio_path))
21 | if os.path.exists(input_audio_path):
22 | return f'{input_audio_path}|{text}\n'
23 |
24 | if __name__ == '__main__':
25 | filelist = []
26 | results = []
27 |
28 | with open(data_config.txt_path, 'r', encoding='utf-8') as f:
29 | for idx, line in enumerate(f):
30 | audio_path, text = line.strip().split(maxsplit=1)
31 | dir_name = audio_path[:7]
32 | text = re.sub(r'[a-zA-Z0-9\s]', '', text) # remove pinyin and tone
33 | filelist.append((dir_name, audio_path, text))
34 |
35 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
36 | futures = [executor.submit(process_filelist, line) for line in filelist]
37 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(filelist)):
38 | result = future.result()
39 | if result is not None:
40 | results.append(result)
41 |
42 | # make sure the parent dir exists; failing at the very last step would waste all the work
43 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
44 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
45 | f.writelines(results)
--------------------------------------------------------------------------------
/recipes/BZNSYP_标贝女声.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import re
4 | from dataclasses import dataclass
5 | import concurrent.futures
6 |
7 | from tqdm.auto import tqdm
8 |
9 | # submit the form on: https://www.data-baker.com/data/index/TNtts/
10 | # then you will get the download link
11 | @dataclass
12 | class DataConfig:
13 | dataset_path = './raw_datasets/BZNSYP/Wave'
14 | txt_path = './raw_datasets/BZNSYP/ProsodyLabeling/000001-010000.txt'
15 | output_filelist_path = './filelists/bznsyp.txt'
16 |
17 | data_config = DataConfig()
18 |
19 | def process_filelist(line):
20 | audio_name, text = line.split('\t')
21 | text = re.sub('[#\d]+', '', text) # remove '#' and numbers
22 | input_audio_path = os.path.abspath(os.path.join(data_config.dataset_path, f'{audio_name}.wav'))
23 | if os.path.exists(input_audio_path):
24 | return f'{input_audio_path}|{text}\n'
25 |
26 | if __name__ == '__main__':
27 | filelist = []
28 | results = []
29 |
30 | with open(data_config.txt_path, 'r', encoding='utf-8') as f:
31 | for idx, line in enumerate(f):
32 | if idx % 2 == 0:
33 | filelist.append(line.strip())
34 |
35 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
36 | futures = [executor.submit(process_filelist, line) for line in filelist]
37 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(filelist)):
38 | result = future.result()
39 | if result is not None:
40 | results.append(result)
41 |
42 | # make sure the parent dir exists; failing at the very last step would waste all the work
43 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
44 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
45 | f.writelines(results)
--------------------------------------------------------------------------------
/recipes/VCTK_huggingface.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | from pathlib import Path
4 | from dataclasses import dataclass
5 | import concurrent.futures
6 |
7 | from tqdm.auto import tqdm
8 | import pandas as pd
9 | import torchaudio
10 |
11 | # download_link: https://huggingface.co/datasets/CSTR-Edinburgh/vctk/tree/063f48e28abda80b2fdc4d4433af8a99e71bfe16
12 | # other huggingface TTS parquet datasets could use the same script
13 | @dataclass
14 | class DataConfig:
15 | dataset_path = './raw_datasets/VCTK'
16 | output_filelist_path = './filelists/VCTK.txt'
17 | output_audio_path = './raw_datasets/VCTK_audios' # to extract audios from parquet files
18 |
19 | data_config = DataConfig()
20 |
21 | def process_parquet(parquet_path: Path):
22 | df = pd.read_parquet(parquet_path)
23 | filelist = []
24 | for idx, data in tqdm(df.iterrows(), total=len(df)):
25 | audio = io.BytesIO(data['audio']['bytes'])
26 | audio, sample_rate = torchaudio.load(audio)
27 | text = data['text']
28 |
29 | path = os.path.abspath(os.path.join(data_config.output_audio_path, data['audio']['path']))
30 | torchaudio.save(path, audio, sample_rate)
31 |
32 | filelist.append(f'{path}|{text}\n')
33 |
34 | return filelist
35 |
36 | if __name__ == '__main__':
37 | filelist = []
38 | results = []
39 |
40 | dataset_path = Path(data_config.dataset_path)
41 | parquets = list(dataset_path.rglob('*.parquet'))
42 |
43 | with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
44 | futures = [executor.submit(process_parquet, parquet_path) for parquet_path in parquets]
45 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(parquets)):
46 | result = future.result()
47 | if result is not None:
48 | results.extend(result)
49 |
50 | # make sure the parent dir exists; failing at the very last step would waste all the work
51 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
52 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
53 | f.writelines(results)
54 |
--------------------------------------------------------------------------------
/recipes/genshin_en_小虫哥ver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from dataclasses import dataclass
4 | import concurrent.futures
5 |
6 | from tqdm.auto import tqdm
7 | import openpyxl # used to open Excel files; install with: pip install openpyxl
8 |
9 | # download_link: https://www.bilibili.com/read/cv23965717
10 | @dataclass
11 | class DataConfig:
12 | dataset_path = './raw_datasets/Genshin_chinese4.5/原神语音包4.5(英)'
13 | excel_path = './raw_datasets/Genshin_chinese4.5/原神4.5语音包对应文本(英).xlsx'
14 | output_filelist_path = './filelists/genshin_en.txt'
15 |
16 | # if the text contains any of these strings, it almost never matches the audio
17 | FORBIDDEN_TEXTS = ["……", "{NICKNAME}", "#", "(", ")", "♪", "test", "{0}", "█", "*", "█", "+", "Gohus"]
18 | REPLACEMENTS = {"$UNRELEASED": ""}
19 | escaped_forbidden_texts = [re.escape(text) for text in FORBIDDEN_TEXTS]
20 | pattern = re.compile("|".join(escaped_forbidden_texts))
21 |
22 | data_config = DataConfig()
23 |
24 | def clean_text(text):
25 | cleaned_text = text
26 | if pattern.search(cleaned_text):
27 | return None
28 | for old, new in REPLACEMENTS.items():
29 | cleaned_text = cleaned_text.replace(old, new)
30 | return cleaned_text
31 |
32 | def read_excel(excel):
33 | wb = openpyxl.load_workbook(excel)
34 | sheet_names = wb.sheetnames
35 | main_sheet = wb[sheet_names[0]]
36 | npc_names = [cell.value for cell in main_sheet['B'] if cell.value][1:]
37 | npc_audio_number = [cell.value for cell in main_sheet['C'] if cell.value][1:]
38 | return wb, npc_names, npc_audio_number
39 |
40 | def process_filelist(data):
41 | audio_path, text, npc_path = data
42 | input_audio_path = os.path.abspath(os.path.join(npc_path, audio_path))
43 | if os.path.exists(input_audio_path):
44 | text = clean_text(text)
45 | if text is not None:
46 | return f'{input_audio_path}|{text}\n'
47 |
48 | if __name__ == '__main__':
49 | wb, npc_names, npc_audio_number = read_excel(data_config.excel_path)
50 | datas_list = []
51 | results = []
52 |
53 | for index, npc_name in enumerate(tqdm(npc_names)):
54 | sheet = wb[npc_name]
55 | audio_names = [cell.value for cell in sheet['C'] if cell.value][1:]
56 | texts = [cell.value for cell in sheet['D'] if cell.value][1:]
57 | npc_path = os.path.join(data_config.dataset_path, npc_name)
58 | datas_list.extend([(audio_name, text, npc_path) for audio_name, text in zip(audio_names, texts)])
59 |
60 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
61 | futures = [executor.submit(process_filelist, data) for data in datas_list]
62 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(datas_list)):
63 | result = future.result()
64 | if result is not None:
65 | results.append(result)
66 |
67 | # make sure the parent dir exists; failing at the very last step would waste all the work
68 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
69 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
70 | f.writelines(results)
--------------------------------------------------------------------------------
/recipes/genshin_zh_小虫哥ver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from dataclasses import dataclass
4 | import concurrent.futures
5 |
6 | from tqdm.auto import tqdm
7 | import openpyxl # used to open Excel files; install with: pip install openpyxl
8 |
9 | # download_link: https://www.bilibili.com/read/cv23965717
10 | @dataclass
11 | class DataConfig:
12 | dataset_path = './raw_datasets/Genshin_chinese4.5/原神语音包4.5(中)'
13 | excel_path = './raw_datasets/Genshin_chinese4.5/原神4.5语音包对应文本(中).xlsx'
14 | output_filelist_path = './filelists/genshin_zh.txt'
15 |
16 | # if the text contains any of these strings, it almost never matches the audio
17 | FORBIDDEN_TEXTS = ["……", "{NICKNAME}", "#", "(", ")", "♪", "test", "{0}", "█", "*", "█", "+", "Gohus"]
18 | REPLACEMENTS = {"$UNRELEASED": ""}
19 | escaped_forbidden_texts = [re.escape(text) for text in FORBIDDEN_TEXTS]
20 | pattern = re.compile("|".join(escaped_forbidden_texts))
21 |
22 | data_config = DataConfig()
23 |
24 | def clean_text(text):
25 | cleaned_text = text
26 | # drop any line that contains English letters or digits
27 | if re.search(r'[A-Za-z0-9]', cleaned_text):
28 | return None
29 | if pattern.search(cleaned_text):
30 | return None
31 | for old, new in REPLACEMENTS.items():
32 | cleaned_text = cleaned_text.replace(old, new)
33 | return cleaned_text
34 |
35 | def read_excel(excel):
36 | wb = openpyxl.load_workbook(excel)
37 | sheet_names = wb.sheetnames
38 | main_sheet = wb[sheet_names[0]]
39 | npc_names = [cell.value for cell in main_sheet['B'] if cell.value][1:]
40 | npc_audio_number = [cell.value for cell in main_sheet['C'] if cell.value][1:]
41 | return wb, npc_names, npc_audio_number
42 |
43 | def process_filelist(data):
44 | audio_path, text, npc_path = data
45 | input_audio_path = os.path.abspath(os.path.join(npc_path, audio_path))
46 | if os.path.exists(input_audio_path):
47 | text = clean_text(text)
48 | if text is not None:
49 | return f'{input_audio_path}|{text}\n'
50 |
51 | if __name__ == '__main__':
52 | wb, npc_names, npc_audio_number = read_excel(data_config.excel_path)
53 | datas_list = []
54 | results = []
55 |
56 | for index, npc_name in enumerate(tqdm(npc_names)):
57 | sheet = wb[npc_name]
58 | audio_names = [cell.value for cell in sheet['C'] if cell.value][1:]
59 | texts = [cell.value for cell in sheet['D'] if cell.value][1:]
60 | npc_path = os.path.join(data_config.dataset_path, npc_name)
61 | datas_list.extend([(audio_name, text, npc_path) for audio_name, text in zip(audio_names, texts)])
62 |
63 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor:
64 | futures = [executor.submit(process_filelist, data) for data in datas_list]
65 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(datas_list)):
66 | result = future.result()
67 | if result is not None:
68 | results.append(result)
69 |
70 | # make sure the parent dir exists; failing at the very last step would waste all the work
71 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
72 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
73 | f.writelines(results)
--------------------------------------------------------------------------------
/recipes/hifi_tts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from pathlib import Path
4 | from dataclasses import dataclass
5 | import concurrent.futures
6 |
7 | from tqdm.auto import tqdm
8 |
9 | # download_link: https://www.openslr.org/109/
10 | @dataclass
11 | class DataConfig:
12 | dataset_path = './raw_datasets/hi_fi_tts_v0'
13 | output_filelist_path = './filelists/hifi_tts.txt'
14 |
15 | data_config = DataConfig()
16 |
17 | def process_filelist(speaker):
18 | filelist = []
19 | with open(speaker, 'r', encoding='utf-8') as f:
20 | for line in f:
21 | line = json.loads(line.strip())
22 | audio_path = os.path.abspath(os.path.join(data_config.dataset_path, line['audio_filepath']))
23 | text = line['text_normalized']
24 | if os.path.exists(audio_path):
25 | filelist.append(f'{audio_path}|{text}\n')
26 | return filelist
27 |
28 | if __name__ == '__main__':
29 | filelist = []
30 | results = []
31 |
32 | dataset_path = Path(data_config.dataset_path)
33 | speakers = list(dataset_path.rglob('*.json'))
34 |
35 | with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
36 | futures = [executor.submit(process_filelist, speaker) for speaker in speakers]
37 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(speakers)):
38 | result = future.result()
39 | if result is not None:
40 | results.extend(result)
41 |
42 | # make sure the parent dir exists; failing at the very last step would waste all the work
43 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
44 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
45 | f.writelines(results)
--------------------------------------------------------------------------------
/recipes/libriTTS.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from dataclasses import dataclass
4 | import concurrent.futures
5 |
6 | from tqdm.auto import tqdm
7 |
8 | # download_link: https://openslr.org/60/
9 | @dataclass
10 | class DataConfig:
11 | dataset_path = './raw_datasets/LibriTTS/train-other-500'
12 | output_filelist_path = './filelists/libri_tts.txt'
13 |
14 | data_config = DataConfig()
15 |
16 | def process_filelist(wav_path: Path):
17 | text_path = wav_path.with_suffix('.normalized.txt')
18 | if text_path.exists():
19 | with open(text_path, 'r', encoding='utf-8') as f:
20 | text = f.read().strip()
21 | return f'{wav_path.as_posix()}|{text}\n'
22 |
23 | if __name__ == '__main__':
24 | filelist = []
25 | results = []
26 |
27 | dataset_path = Path(data_config.dataset_path)
28 | waves = list(dataset_path.rglob('*.wav'))
29 |
30 | with concurrent.futures.ProcessPoolExecutor(max_workers=8) as executor:
31 | futures = [executor.submit(process_filelist, wav_path) for wav_path in waves]
32 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(waves)):
33 | result = future.result()
34 | if result is not None:
35 | results.append(result)
36 |
37 | # make sure the parent dir exists; failing at the very last step would waste all the work
38 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True)
39 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f:
40 | f.writelines(results)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchaudio
3 |
4 | tqdm
5 | numpy
6 | soundfile # to make sure that torchaudio has at least one valid backend
7 |
8 | tensorboard
9 |
10 | # for monotonic_align
11 | numba
12 |
13 | # ODE-solver
14 | torchdiffeq
15 |
16 | # for g2p
17 | # chinese
18 | pypinyin
19 | jieba
20 | # english
21 | eng_to_ipa
22 | unidecode
23 | inflect
24 | # japanese
25 | # if pyopenjtalk fails to download open_jtalk_dic_utf_8-1.11.tar.gz, manually download and unzip the file below
26 | # https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz
27 | # and set os.environ['OPEN_JTALK_DICT_DIR'] to the folder path
28 | pyopenjtalk-prebuilt # if using python >= 3.12, install pyopenjtalk instead
29 |
30 | # for webui
31 | # gradio
32 | # matplotlib
33 |
34 |
--------------------------------------------------------------------------------
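If pyopenjtalk cannot download the dictionary, the note above suggests pointing it at a manually extracted copy via an environment variable. A sketch of that workaround; the dictionary path is a placeholder and must be set before pyopenjtalk is imported:

    # Workaround sketch for the pyopenjtalk note above; the path is a placeholder.
    import os
    os.environ['OPEN_JTALK_DICT_DIR'] = '/path/to/open_jtalk_dic_utf_8-1.11'

    import pyopenjtalk  # now uses the local dictionary instead of downloading it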
/text/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2017 Keith Ito
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from text import cleaners
3 | from text.symbols import symbols
4 |
5 |
6 | # Mappings from symbol to numeric ID and vice versa:
7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9 |
10 |
11 | def text_to_sequence(text, symbols, cleaner_names):
12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13 | Args:
14 | text: string to convert to a sequence
15 | cleaner_names: names of the cleaner functions to run the text through
16 | Returns:
17 | List of integers corresponding to the symbols in the text
18 | '''
19 | sequence = []
20 | symbol_to_id = {s: i for i, s in enumerate(symbols)}
21 | clean_text = _clean_text(text, cleaner_names)
22 | print(clean_text)
23 | print(f" length:{len(clean_text)}")
24 | for symbol in clean_text:
25 | if symbol not in symbol_to_id.keys():
26 | continue
27 | symbol_id = symbol_to_id[symbol]
28 | sequence += [symbol_id]
29 | print(f" length:{len(sequence)}")
30 | return sequence
31 |
32 |
33 | def cleaned_text_to_sequence(cleaned_text):
34 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
35 | Args:
36 | text: string to convert to a sequence
37 | Returns:
38 | List of integers corresponding to the symbols in the text
39 | '''
40 | # symbol_to_id = {s: i for i, s in enumerate(symbols)}
41 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()]
42 | return sequence
43 |
44 | def cleaned_text_to_sequence_chinese(cleaned_text):
45 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
46 | Args:
47 | text: string to convert to a sequence
48 | Returns:
49 | List of integers corresponding to the symbols in the text
50 | '''
51 | # symbol_to_id = {s: i for i, s in enumerate(symbols)}
52 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id.keys()]
53 | return sequence
54 |
55 |
56 | def sequence_to_text(sequence):
57 | '''Converts a sequence of IDs back to a string'''
58 | result = ''
59 | for symbol_id in sequence:
60 | s = _id_to_symbol[symbol_id]
61 | result += s
62 | return result
63 |
64 |
65 | def _clean_text(text, cleaner_names):
66 | for name in cleaner_names:
67 | cleaner = getattr(cleaners, name)
68 | if not cleaner:
69 | raise Exception('Unknown cleaner: %s' % name)
70 | text = cleaner(text)
71 | return text
72 |
--------------------------------------------------------------------------------
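cleaned_text_to_sequence maps already-cleaned text to symbol IDs character by character, while cleaned_text_to_sequence_chinese expects space-separated phoneme tokens (the CNM3 format produced by chinese_to_cnm3); unknown symbols are silently skipped. A small sketch; the example strings are placeholders and which tokens survive depends entirely on text/symbols.py:

    # Sketch only: example inputs are placeholders; the resulting IDs depend on text/symbols.py.
    from text import cleaned_text_to_sequence, cleaned_text_to_sequence_chinese, sequence_to_text

    ids = cleaned_text_to_sequence('həloʊ.')                 # per-character lookup
    zh_ids = cleaned_text_to_sequence_chinese('n i2 h ao3')  # space-separated phoneme tokens
    print(sequence_to_text(ids))                             # round-trips the known symbols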
/text/cleaners.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from text.english import english_to_ipa2
4 | from text.mandarin import chinese_to_cnm3
5 | from text.japanese import japanese_to_ipa2
6 |
7 | language_module_map = {"PAD":0, "ZH": 1, "EN": 2, "JA": 3}
8 |
9 | # pre-compiled regular expressions
10 | ZH_PATTERN = re.compile(r'[\u3400-\u4DBF\u4e00-\u9FFF\uF900-\uFAFF\u3000-\u303F]')
11 | EN_PATTERN = re.compile(r'[a-zA-Z.,!?\'"(){}[\]<>:;@#$%^&*-_+=/\\|~`]+')
12 | JP_PATTERN = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF\u31F0-\u31FF\uFF00-\uFFEF\u3000-\u303F]')
13 | CLEANER_PATTERN = re.compile(r'\[(ZH|EN|JA)\]')
14 |
15 | def detect_language(text: str, prev_lang=None):
16 | """
17 | Detect the language of the given text.
18 |
19 | :param text: input text
20 | :param prev_lang: previously detected language
21 | :return: 'ZH' for Chinese, 'EN' for English, 'JA' for Japanese, or prev_lang for spaces
22 | """
23 | if ZH_PATTERN.search(text): return 'ZH'
24 | if EN_PATTERN.search(text): return 'EN'
25 | if JP_PATTERN.search(text): return 'JA'
26 | if text.isspace(): return prev_lang # for whitespace, return the previous language
27 | return None
28 |
29 | # auto detect language using re
30 | def cjke_cleaners4(text: str):
31 | """
32 | Automatically detect the language of each text span and convert it to IPA.
33 |
34 | :param text: input text
35 | :return: text converted to IPA phonemes
36 | """
37 | text = CLEANER_PATTERN.sub('', text)
38 | pointer = 0
39 | output = ''
40 | current_language = detect_language(text[pointer])
41 |
42 | while pointer < len(text):
43 | temp_text = ''
44 | while pointer < len(text) and detect_language(text[pointer], current_language) == current_language:
45 | temp_text += text[pointer]
46 | pointer += 1
47 | if current_language == 'ZH':
48 | output += chinese_to_cnm3(temp_text)
49 | elif current_language == 'JA':
50 | output += japanese_to_ipa2(temp_text)
51 | elif current_language == 'EN':
52 | output += english_to_ipa2(temp_text)
53 | if pointer < len(text):
54 | current_language = detect_language(text[pointer])
55 |
56 | output = re.sub(r'\s+$', '', output)
57 | output = re.sub(r'([^\.,!\?\-…~])$', r'\1.', output)
58 | return output
59 |
--------------------------------------------------------------------------------
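cjke_cleaners4 walks through the text, groups maximal runs of a single detected language, converts each run with the matching g2p function, and finally appends a '.' if the result does not already end in punctuation. A usage sketch; the exact phoneme output depends on the g2p modules and their dictionaries, so it is not shown literally:

    # Sketch: exact phoneme output depends on the g2p backends (jieba/pypinyin, eng_to_ipa, pyopenjtalk).
    from text.cleaners import cjke_cleaners4, detect_language

    print(detect_language('你'))              # 'ZH'
    print(detect_language('a'))               # 'EN'
    phonemes = cjke_cleaners4('你好, hello')  # Chinese run -> chinese_to_cnm3, English run -> english_to_ipa2
    print(phonemes)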
/text/cn2an/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.5.22"
2 |
3 | from .cn2an import Cn2An
4 | from .an2cn import An2Cn
5 | from .transform import Transform
6 |
7 | cn2an = Cn2An().cn2an
8 | an2cn = An2Cn().an2cn
9 | transform = Transform().transform
10 |
11 | __all__ = [
12 | "__version__",
13 | "cn2an",
14 | "an2cn",
15 | "transform"
16 | ]
--------------------------------------------------------------------------------
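The vendored cn2an package exposes three callables: cn2an (Chinese numerals to Arabic), an2cn (Arabic to Chinese in several modes) and transform (substitution of numbers inside a whole sentence, used for text normalization). A few illustrative calls; the expected outputs follow upstream cn2an behaviour and have not been re-verified against this vendored copy:

    # Expected outputs per upstream cn2an behaviour; treat as illustrative.
    from text.cn2an import cn2an, an2cn

    print(an2cn("123"))           # 一百二十三
    print(an2cn("123", "rmb"))    # 壹佰贰拾叁元整
    print(cn2an("一百二十三"))     # 123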
/text/cn2an/an2cn.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from warnings import warn
3 |
4 | # from proces import preprocess
5 |
6 | from .conf import NUMBER_LOW_AN2CN, NUMBER_UP_AN2CN, UNIT_LOW_ORDER_AN2CN, UNIT_UP_ORDER_AN2CN
7 |
8 |
9 | class An2Cn(object):
10 | def __init__(self) -> None:
11 | self.all_num = "0123456789"
12 | self.number_low = NUMBER_LOW_AN2CN
13 | self.number_up = NUMBER_UP_AN2CN
14 | self.mode_list = ["low", "up", "rmb", "direct"]
15 |
16 | def an2cn(self, inputs: Union[str, int, float] = None, mode: str = "low") -> str:
17 | """阿拉伯数字转中文数字
18 |
19 | :param inputs: 阿拉伯数字
20 | :param mode: low 小写数字,up 大写数字,rmb 人民币大写,direct 直接转化
21 | :return: 中文数字
22 | """
23 | if inputs is not None and inputs != "":
24 | if mode not in self.mode_list:
25 | raise ValueError(f"mode 仅支持 {str(self.mode_list)} !")
26 |
27 | # convert the number to a string; Python performs some conversions automatically here:
28 | # 1. -> 1.0   1.00 -> 1.0   -0 -> 0
29 | if not isinstance(inputs, str):
30 | inputs = self.__number_to_string(inputs)
31 |
32 | # data preprocessing:
33 | # 1. traditional to simplified Chinese
34 | # 2. full-width to half-width characters
35 | # inputs = preprocess(inputs, pipelines=[
36 | # "traditional_to_simplified",
37 | # "full_angle_to_half_angle"
38 | # ])
39 |
40 | # check that the input is valid
41 | self.__check_inputs_is_valid(inputs)
42 |
43 | # determine the sign
44 | if inputs[0] == "-":
45 | sign = "负"
46 | inputs = inputs[1:]
47 | else:
48 | sign = ""
49 |
50 | if mode == "direct":
51 | output = self.__direct_convert(inputs)
52 | else:
53 | # split into integer and decimal parts
54 | split_result = inputs.split(".")
55 | len_split_result = len(split_result)
56 | if len_split_result == 1:
57 | # input without a decimal part
58 | integer_data = split_result[0]
59 | if mode == "rmb":
60 | output = self.__integer_convert(integer_data, "up") + "元整"
61 | else:
62 | output = self.__integer_convert(integer_data, mode)
63 | elif len_split_result == 2:
64 | # input with a decimal part
65 | integer_data, decimal_data = split_result
66 | if mode == "rmb":
67 | int_data = self.__integer_convert(integer_data, "up")
68 | dec_data = self.__decimal_convert(decimal_data, "up")
69 | len_dec_data = len(dec_data)
70 |
71 | if len_dec_data == 0:
72 | output = int_data + "元整"
73 | elif len_dec_data == 1:
74 | raise ValueError(f"异常输出:{dec_data}")
75 | elif len_dec_data == 2:
76 | if dec_data[1] != "零":
77 | if int_data == "零":
78 | output = dec_data[1] + "角"
79 | else:
80 | output = int_data + "元" + dec_data[1] + "角"
81 | else:
82 | output = int_data + "元整"
83 | else:
84 | if dec_data[1] != "零":
85 | if dec_data[2] != "零":
86 | if int_data == "零":
87 | output = dec_data[1] + "角" + dec_data[2] + "分"
88 | else:
89 | output = int_data + "元" + dec_data[1] + "角" + dec_data[2] + "分"
90 | else:
91 | if int_data == "零":
92 | output = dec_data[1] + "角"
93 | else:
94 | output = int_data + "元" + dec_data[1] + "角"
95 | else:
96 | if dec_data[2] != "零":
97 | if int_data == "零":
98 | output = dec_data[2] + "分"
99 | else:
100 | output = int_data + "元" + "零" + dec_data[2] + "分"
101 | else:
102 | output = int_data + "元整"
103 | else:
104 | output = self.__integer_convert(integer_data, mode) + self.__decimal_convert(decimal_data, mode)
105 | else:
106 | raise ValueError(f"输入格式错误:{inputs}!")
107 | else:
108 | raise ValueError("输入数据为空!")
109 |
110 | return sign + output
111 |
112 | def __direct_convert(self, inputs: str) -> str:
113 | _output = ""
114 | for d in inputs:
115 | if d == ".":
116 | _output += "点"
117 | else:
118 | _output += self.number_low[int(d)]
119 | return _output
120 |
121 | @staticmethod
122 | def __number_to_string(number_data: Union[int, float]) -> str:
123 | # decimal handling: Python automatically turns 0.00005 into 5e-05, so str(0.00005) != "0.00005"
124 | string_data = str(number_data)
125 | if "e" in string_data:
126 | string_data_list = string_data.split("e")
127 | string_key = string_data_list[0]
128 | string_value = string_data_list[1]
129 | if string_value[0] == "-":
130 | string_data = "0." + "0" * (int(string_value[1:]) - 1) + string_key
131 | else:
132 | string_data = string_key + "0" * int(string_value)
133 | return string_data
134 |
135 | def __check_inputs_is_valid(self, check_data: str) -> None:
136 | # check that every character of the input is in the allowed set
137 | all_check_keys = self.all_num + ".-"
138 | for data in check_data:
139 | if data not in all_check_keys:
140 | raise ValueError(f"输入的数据不在转化范围内:{data}!")
141 |
142 | def __integer_convert(self, integer_data: str, mode: str) -> str:
143 | if mode == "low":
144 | numeral_list = NUMBER_LOW_AN2CN
145 | unit_list = UNIT_LOW_ORDER_AN2CN
146 | elif mode == "up":
147 | numeral_list = NUMBER_UP_AN2CN
148 | unit_list = UNIT_UP_ORDER_AN2CN
149 | else:
150 | raise ValueError(f"error mode: {mode}")
151 |
152 | # strip leading zeros, e.g. 007 => 7
153 | integer_data = str(int(integer_data))
154 |
155 | len_integer_data = len(integer_data)
156 | if len_integer_data > len(unit_list):
157 | raise ValueError(f"超出数据范围,最长支持 {len(unit_list)} 位")
158 |
159 | output_an = ""
160 | for i, d in enumerate(integer_data):
161 | if int(d):
162 | output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
163 | else:
164 | if not (len_integer_data - i - 1) % 4:
165 | output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1]
166 |
167 | if i > 0 and not output_an[-1] == "零":
168 | output_an += numeral_list[int(d)]
169 |
170 | output_an = output_an.replace("零零", "零").replace("零万", "万").replace("零亿", "亿").replace("亿万", "亿") \
171 | .strip("零")
172 |
173 | # fix the "一十X" (one-ten-something) reading issue
174 | if output_an[:2] in ["一十"]:
175 | output_an = output_an[1:]
176 |
177 | # decimals between 0 and 1
178 | if not output_an:
179 | output_an = "零"
180 |
181 | return output_an
182 |
183 | def __decimal_convert(self, decimal_data: str, o_mode: str) -> str:
184 | len_decimal_data = len(decimal_data)
185 |
186 | if len_decimal_data > 16:
187 | warn(f"注意:小数部分长度为 {len_decimal_data} ,将自动截取前 16 位有效精度!")
188 | decimal_data = decimal_data[:16]
189 |
190 | if len_decimal_data:
191 | output_an = "点"
192 | else:
193 | output_an = ""
194 |
195 | if o_mode == "low":
196 | numeral_list = NUMBER_LOW_AN2CN
197 | elif o_mode == "up":
198 | numeral_list = NUMBER_UP_AN2CN
199 | else:
200 | raise ValueError(f"error mode: {o_mode}")
201 |
202 | for data in decimal_data:
203 | output_an += numeral_list[int(data)]
204 | return output_an
--------------------------------------------------------------------------------
/text/cn2an/conf.py:
--------------------------------------------------------------------------------
1 | NUMBER_CN2AN = {
2 | "零": 0,
3 | "〇": 0,
4 | "一": 1,
5 | "壹": 1,
6 | "幺": 1,
7 | "二": 2,
8 | "贰": 2,
9 | "两": 2,
10 | "三": 3,
11 | "叁": 3,
12 | "四": 4,
13 | "肆": 4,
14 | "五": 5,
15 | "伍": 5,
16 | "六": 6,
17 | "陆": 6,
18 | "七": 7,
19 | "柒": 7,
20 | "八": 8,
21 | "捌": 8,
22 | "九": 9,
23 | "玖": 9,
24 | }
25 | UNIT_CN2AN = {
26 | "十": 10,
27 | "拾": 10,
28 | "百": 100,
29 | "佰": 100,
30 | "千": 1000,
31 | "仟": 1000,
32 | "万": 10000,
33 | "亿": 100000000,
34 | }
35 | UNIT_LOW_AN2CN = {
36 | 10: "十",
37 | 100: "百",
38 | 1000: "千",
39 | 10000: "万",
40 | 100000000: "亿",
41 | }
42 | NUMBER_LOW_AN2CN = {
43 | 0: "零",
44 | 1: "一",
45 | 2: "二",
46 | 3: "三",
47 | 4: "四",
48 | 5: "五",
49 | 6: "六",
50 | 7: "七",
51 | 8: "八",
52 | 9: "九",
53 | }
54 | NUMBER_UP_AN2CN = {
55 | 0: "零",
56 | 1: "壹",
57 | 2: "贰",
58 | 3: "叁",
59 | 4: "肆",
60 | 5: "伍",
61 | 6: "陆",
62 | 7: "柒",
63 | 8: "捌",
64 | 9: "玖",
65 | }
66 | UNIT_LOW_ORDER_AN2CN = [
67 | "",
68 | "十",
69 | "百",
70 | "千",
71 | "万",
72 | "十",
73 | "百",
74 | "千",
75 | "亿",
76 | "十",
77 | "百",
78 | "千",
79 | "万",
80 | "十",
81 | "百",
82 | "千",
83 | ]
84 | UNIT_UP_ORDER_AN2CN = [
85 | "",
86 | "拾",
87 | "佰",
88 | "仟",
89 | "万",
90 | "拾",
91 | "佰",
92 | "仟",
93 | "亿",
94 | "拾",
95 | "佰",
96 | "仟",
97 | "万",
98 | "拾",
99 | "佰",
100 | "仟",
101 | ]
102 | STRICT_CN_NUMBER = {
103 | "零": "零",
104 | "一": "一壹",
105 | "二": "二贰",
106 | "三": "三叁",
107 | "四": "四肆",
108 | "五": "五伍",
109 | "六": "六陆",
110 | "七": "七柒",
111 | "八": "八捌",
112 | "九": "九玖",
113 | "十": "十拾",
114 | "百": "百佰",
115 | "千": "千仟",
116 | "万": "万",
117 | "亿": "亿",
118 | }
119 | NORMAL_CN_NUMBER = {
120 | "零": "零〇",
121 | "一": "一壹幺",
122 | "二": "二贰两",
123 | "三": "三叁仨",
124 | "四": "四肆",
125 | "五": "五伍",
126 | "六": "六陆",
127 | "七": "七柒",
128 | "八": "八捌",
129 | "九": "九玖",
130 | "十": "十拾",
131 | "百": "百佰",
132 | "千": "千仟",
133 | "万": "万",
134 | "亿": "亿",
135 | }
--------------------------------------------------------------------------------
/text/cn2an/transform.py:
--------------------------------------------------------------------------------
1 | import re
2 | from warnings import warn
3 |
4 | from .cn2an import Cn2An
5 | from .an2cn import An2Cn
6 | from .conf import UNIT_CN2AN
7 |
8 |
9 | class Transform(object):
10 | def __init__(self) -> None:
11 | self.all_num = "零一二三四五六七八九"
12 | self.all_unit = "".join(list(UNIT_CN2AN.keys()))
13 | self.cn2an = Cn2An().cn2an
14 | self.an2cn = An2Cn().an2cn
15 | self.cn_pattern = f"负?([{self.all_num}{self.all_unit}]+点)?[{self.all_num}{self.all_unit}]+"
16 | self.smart_cn_pattern = f"-?([0-9]+.)?[0-9]+[{self.all_unit}]+"
17 |
18 | def transform(self, inputs: str, method: str = "cn2an") -> str:
19 | if method == "cn2an":
20 | inputs = inputs.replace("廿", "二十").replace("半", "0.5").replace("两", "2")
21 | # date
22 | inputs = re.sub(
23 | fr"((({self.smart_cn_pattern})|({self.cn_pattern}))年)?([{self.all_num}十]+月)?([{self.all_num}十]+日)?",
24 | lambda x: self.__sub_util(x.group(), "cn2an", "date"), inputs)
25 | # fraction
26 | inputs = re.sub(fr"{self.cn_pattern}分之{self.cn_pattern}",
27 | lambda x: self.__sub_util(x.group(), "cn2an", "fraction"), inputs)
28 | # percent
29 | inputs = re.sub(fr"百分之{self.cn_pattern}",
30 | lambda x: self.__sub_util(x.group(), "cn2an", "percent"), inputs)
31 | # celsius
32 | inputs = re.sub(fr"{self.cn_pattern}摄氏度",
33 | lambda x: self.__sub_util(x.group(), "cn2an", "celsius"), inputs)
34 | # number
35 | output = re.sub(self.cn_pattern,
36 | lambda x: self.__sub_util(x.group(), "cn2an", "number"), inputs)
37 |
38 | elif method == "an2cn":
39 | # date
40 | inputs = re.sub(r"(\d{2,4}年)?(\d{1,2}月)?(\d{1,2}日)?",
41 | lambda x: self.__sub_util(x.group(), "an2cn", "date"), inputs)
42 | # fraction
43 | inputs = re.sub(r"\d+/\d+",
44 | lambda x: self.__sub_util(x.group(), "an2cn", "fraction"), inputs)
45 | # percent
46 | inputs = re.sub(r"-?(\d+\.)?\d+%",
47 | lambda x: self.__sub_util(x.group(), "an2cn", "percent"), inputs)
48 | # celsius
49 | inputs = re.sub(r"\d+℃",
50 | lambda x: self.__sub_util(x.group(), "an2cn", "celsius"), inputs)
51 | # number
52 | output = re.sub(r"-?(\d+\.)?\d+",
53 | lambda x: self.__sub_util(x.group(), "an2cn", "number"), inputs)
54 | else:
55 | raise ValueError(f"error method: {method}, only support 'cn2an' and 'an2cn'!")
56 |
57 | return output
58 |
59 | def __sub_util(self, inputs, method: str = "cn2an", sub_mode: str = "number") -> str:
60 | try:
61 | if inputs:
62 | if method == "cn2an":
63 | if sub_mode == "date":
64 | return re.sub(fr"(({self.smart_cn_pattern})|({self.cn_pattern}))",
65 | lambda x: str(self.cn2an(x.group(), "smart")), inputs)
66 | elif sub_mode == "fraction":
67 | if inputs[0] != "百":
68 | frac_result = re.sub(self.cn_pattern,
69 | lambda x: str(self.cn2an(x.group(), "smart")), inputs)
70 | numerator, denominator = frac_result.split("分之")
71 | return f"{denominator}/{numerator}"
72 | else:
73 | return inputs
74 | elif sub_mode == "percent":
75 | return re.sub(f"(?<=百分之){self.cn_pattern}",
76 | lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("百分之", "") + "%"
77 | elif sub_mode == "celsius":
78 | return re.sub(f"{self.cn_pattern}(?=摄氏度)",
79 | lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("摄氏度", "℃")
80 | elif sub_mode == "number":
81 | return str(self.cn2an(inputs, "smart"))
82 | else:
83 | raise Exception(f"error sub_mode: {sub_mode} !")
84 | else:
85 | if sub_mode == "date":
86 | inputs = re.sub(r"\d+(?=年)",
87 | lambda x: self.an2cn(x.group(), "direct"), inputs)
88 | return re.sub(r"\d+",
89 | lambda x: self.an2cn(x.group(), "low"), inputs)
90 | elif sub_mode == "fraction":
91 | frac_result = re.sub(r"\d+", lambda x: self.an2cn(x.group(), "low"), inputs)
92 | numerator, denominator = frac_result.split("/")
93 | return f"{denominator}分之{numerator}"
94 | elif sub_mode == "celsius":
95 | return self.an2cn(inputs[:-1], "low") + "摄氏度"
96 | elif sub_mode == "percent":
97 | return "百分之" + self.an2cn(inputs[:-1], "low")
98 | elif sub_mode == "number":
99 | return self.an2cn(inputs, "low")
100 | else:
101 | raise Exception(f"error sub_mode: {sub_mode} !")
102 | except Exception as e:
103 | warn(str(e))
104 | return inputs
--------------------------------------------------------------------------------
/text/custom_pypinyin_dict/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/text/custom_pypinyin_dict/cc_cedict_3.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import unicode_literals
3 |
4 | # Warning: Auto-generated file, don't edit.
5 | phrases_dict = {
6 | '𰻝𰻝面': [['biáng'], ['biáng'], ['miàn']],
7 | }
8 |
9 |
10 | from pypinyin import load_phrases_dict
11 |
12 |
13 | def load():
14 | load_phrases_dict(phrases_dict)
15 |
--------------------------------------------------------------------------------
/text/custom_pypinyin_dict/genshin.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import unicode_literals
3 |
4 | phrases_dict = {
5 | '㐖毒': [['xié'], ['dú']],
6 | '若陀': [['rě'], ['tuó']],
7 | '平藏': [['píng'], ['zàng']],
8 | '派蒙': [['pài'], ['méng']],
9 | '安柏': [['ān'], ['bó']],
10 | '一斗': [['yī'], ['dǒu']]
11 | }
--------------------------------------------------------------------------------
/text/custom_pypinyin_dict/phrase_pinyin_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from __future__ import unicode_literals
3 |
4 | from pypinyin import load_phrases_dict
5 |
6 | from text.custom_pypinyin_dict import cc_cedict_0
7 | from text.custom_pypinyin_dict import cc_cedict_1
8 | from text.custom_pypinyin_dict import cc_cedict_2
9 | from text.custom_pypinyin_dict import cc_cedict_3
10 | from text.custom_pypinyin_dict import genshin
11 |
12 | phrases_dict = {}
13 | phrases_dict.update(cc_cedict_0.phrases_dict)
14 | phrases_dict.update(cc_cedict_1.phrases_dict)
15 | phrases_dict.update(cc_cedict_2.phrases_dict)
16 | phrases_dict.update(cc_cedict_3.phrases_dict)
17 | phrases_dict.update(genshin.phrases_dict)
18 |
19 | def load():
20 | load_phrases_dict(phrases_dict)
21 | print("加载自定义词典成功")
22 |
23 | if __name__ == '__main__':
24 | print(phrases_dict)
--------------------------------------------------------------------------------
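phrase_pinyin_data.load() merges the CC-CEDICT and Genshin phrase dictionaries above and registers them with pypinyin, so ambiguous names get the intended readings. A usage sketch; the expected output simply reflects the '安柏' entry in genshin.py:

    # Sketch: the expected reading comes from the '安柏' entry in genshin.py.
    from pypinyin import pinyin
    from text.custom_pypinyin_dict import phrase_pinyin_data

    phrase_pinyin_data.load()   # register the merged phrase dictionary
    print(pinyin('安柏'))        # [['ān'], ['bó']] rather than the default reading of 柏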
/text/english.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | '''
4 | Cleaners are transformations that run over the input text at both training and eval time.
5 |
6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8 | 1. "english_cleaners" for English text
9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12 | the symbols in symbols.py to match your data).
13 | '''
14 |
15 |
16 | # Regular expression matching whitespace:
17 |
18 |
19 | import re
20 | import inflect
21 | from unidecode import unidecode
22 | import eng_to_ipa as ipa
23 | _inflect = inflect.engine()
24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
29 | _number_re = re.compile(r'[0-9]+')
30 |
31 | # List of (regular expression, replacement) pairs for abbreviations:
32 | _abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
33 | ('mrs', 'misess'),
34 | ('mr', 'mister'),
35 | ('dr', 'doctor'),
36 | ('st', 'saint'),
37 | ('co', 'company'),
38 | ('jr', 'junior'),
39 | ('maj', 'major'),
40 | ('gen', 'general'),
41 | ('drs', 'doctors'),
42 | ('rev', 'reverend'),
43 | ('lt', 'lieutenant'),
44 | ('hon', 'honorable'),
45 | ('sgt', 'sergeant'),
46 | ('capt', 'captain'),
47 | ('esq', 'esquire'),
48 | ('ltd', 'limited'),
49 | ('col', 'colonel'),
50 | ('ft', 'fort'),
51 | ]]
52 |
53 |
54 | # List of (ipa, lazy ipa) pairs:
55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
56 | ('r', 'ɹ'),
57 | ('æ', 'e'),
58 | ('ɑ', 'a'),
59 | ('ɔ', 'o'),
60 | ('ð', 'z'),
61 | ('θ', 's'),
62 | ('ɛ', 'e'),
63 | ('ɪ', 'i'),
64 | ('ʊ', 'u'),
65 | ('ʒ', 'ʥ'),
66 | ('ʤ', 'ʥ'),
67 | ('ˈ', '↓'),
68 | ]]
69 |
70 | # List of (ipa, lazy ipa2) pairs:
71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
72 | ('r', 'ɹ'),
73 | ('ð', 'z'),
74 | ('θ', 's'),
75 | ('ʒ', 'ʑ'),
76 | ('ʤ', 'dʑ'),
77 | ('ˈ', '↓'),
78 | ]]
79 |
80 | # List of (ipa, ipa2) pairs
81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
82 | ('r', 'ɹ'),
83 | ('ʤ', 'dʒ'),
84 | ('ʧ', 'tʃ')
85 | ]]
86 |
87 |
88 | def expand_abbreviations(text):
89 | for regex, replacement in _abbreviations:
90 | text = re.sub(regex, replacement, text)
91 | return text
92 |
93 |
94 | def collapse_whitespace(text):
95 | return re.sub(r'\s+', ' ', text)
96 |
97 |
98 | def _remove_commas(m):
99 | return m.group(1).replace(',', '')
100 |
101 |
102 | def _expand_decimal_point(m):
103 | return m.group(1).replace('.', ' point ')
104 |
105 |
106 | def _expand_dollars(m):
107 | match = m.group(1)
108 | parts = match.split('.')
109 | if len(parts) > 2:
110 | return match + ' dollars' # Unexpected format
111 | dollars = int(parts[0]) if parts[0] else 0
112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
113 | if dollars and cents:
114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
115 | cent_unit = 'cent' if cents == 1 else 'cents'
116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
117 | elif dollars:
118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars'
119 | return '%s %s' % (dollars, dollar_unit)
120 | elif cents:
121 | cent_unit = 'cent' if cents == 1 else 'cents'
122 | return '%s %s' % (cents, cent_unit)
123 | else:
124 | return 'zero dollars'
125 |
126 |
127 | def _expand_ordinal(m):
128 | return _inflect.number_to_words(m.group(0))
129 |
130 |
131 | def _expand_number(m):
132 | num = int(m.group(0))
133 | if num > 1000 and num < 3000:
134 | if num == 2000:
135 | return 'two thousand'
136 | elif num > 2000 and num < 2010:
137 | return 'two thousand ' + _inflect.number_to_words(num % 100)
138 | elif num % 100 == 0:
139 | return _inflect.number_to_words(num // 100) + ' hundred'
140 | else:
141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
142 | else:
143 | return _inflect.number_to_words(num, andword='')
144 |
145 |
146 | def normalize_numbers(text):
147 | text = re.sub(_comma_number_re, _remove_commas, text)
148 | text = re.sub(_pounds_re, r'\1 pounds', text)
149 | text = re.sub(_dollars_re, _expand_dollars, text)
150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
151 | text = re.sub(_ordinal_re, _expand_ordinal, text)
152 | text = re.sub(_number_re, _expand_number, text)
153 | return text
154 |
155 |
156 | def mark_dark_l(text):
157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text)
158 |
159 |
160 | def english_to_ipa(text):
161 | text = unidecode(text).lower()
162 | text = expand_abbreviations(text)
163 | text = normalize_numbers(text)
164 | phonemes = ipa.convert(text)
165 | phonemes = collapse_whitespace(phonemes)
166 | return phonemes
167 |
168 |
169 | def english_to_ipa2(text):
170 | text = english_to_ipa(text)
171 | text = mark_dark_l(text)
172 | for regex, replacement in _ipa_to_ipa2:
173 | text = re.sub(regex, replacement, text)
174 | return list(text.replace('...', '…'))
175 |
176 |
--------------------------------------------------------------------------------
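A minimal usage sketch of the helpers above, assuming `inflect`, `unidecode`, and `eng_to_ipa` are installed and the repository root is on the Python path:

```python
from text.english import english_to_ipa2, normalize_numbers

# Currency amounts and ordinals are spelled out before phonemization.
print(normalize_numbers('It costs $2.50 on the 3rd floor.'))
# -> 'It costs 2 dollars, 50 cents on the third floor.'

# english_to_ipa2 returns a list of IPA symbols for the cleaned text.
print(english_to_ipa2('Hello world'))
```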
/text/japanese.py:
--------------------------------------------------------------------------------
1 | import re
2 | from unidecode import unidecode
3 | import pyopenjtalk
4 |
5 |
6 | # Regular expression matching Japanese without punctuation marks:
7 | _japanese_characters = re.compile(
8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
9 |
10 | # Regular expression matching non-Japanese characters or punctuation marks:
11 | _japanese_marks = re.compile(
12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
13 |
14 | # List of (symbol, Japanese) pairs for marks:
15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
16 | ('%', 'パーセント')
17 | ]]
18 |
19 | # List of (romaji, ipa) pairs for marks:
20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
21 | ('ts', 'ʦ'),
22 | ('u', 'ɯ'),
23 | ('j', 'ʥ'),
24 | ('y', 'j'),
25 | ('ni', 'n^i'),
26 | ('nj', 'n^'),
27 | ('hi', 'çi'),
28 | ('hj', 'ç'),
29 | ('f', 'ɸ'),
30 | ('I', 'i*'),
31 | ('U', 'ɯ*'),
32 | ('r', 'ɾ')
33 | ]]
34 |
35 | # List of (romaji, ipa2) pairs for marks:
36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
37 | ('u', 'ɯ'),
38 | ('ʧ', 'tʃ'),
39 | ('j', 'dʑ'),
40 | ('y', 'j'),
41 | ('ni', 'n^i'),
42 | ('nj', 'n^'),
43 | ('hi', 'çi'),
44 | ('hj', 'ç'),
45 | ('f', 'ɸ'),
46 | ('I', 'i*'),
47 | ('U', 'ɯ*'),
48 | ('r', 'ɾ')
49 | ]]
50 |
51 | # List of (consonant, sokuon) pairs:
52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
53 | (r'Q([↑↓]*[kg])', r'k#\1'),
54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'),
55 | (r'Q([↑↓]*[sʃ])', r's\1'),
56 | (r'Q([↑↓]*[pb])', r'p#\1')
57 | ]]
58 |
59 | # List of (consonant, hatsuon) pairs:
60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
61 | (r'N([↑↓]*[pbm])', r'm\1'),
62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'),
63 | (r'N([↑↓]*[tdn])', r'n\1'),
64 | (r'N([↑↓]*[kg])', r'ŋ\1')
65 | ]]
66 |
67 |
68 | def symbols_to_japanese(text):
69 | for regex, replacement in _symbols_to_japanese:
70 | text = re.sub(regex, replacement, text)
71 | return text
72 |
73 |
74 | def japanese_to_romaji_with_accent(text):
75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html'''
76 | text = symbols_to_japanese(text)
77 | sentences = re.split(_japanese_marks, text)
78 | marks = re.findall(_japanese_marks, text)
79 | text = ''
80 | for i, sentence in enumerate(sentences):
81 | if re.match(_japanese_characters, sentence):
82 | if text != '':
83 | text += ' '
84 | labels = pyopenjtalk.extract_fullcontext(sentence)
85 | for n, label in enumerate(labels):
86 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
87 | if phoneme not in ['sil', 'pau']:
88 | text += phoneme.replace('ch', 'ʧ').replace('sh',
89 | 'ʃ').replace('cl', 'Q')
90 | else:
91 | continue
92 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
93 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
94 | a2 = int(re.search(r"\+(\d+)\+", label).group(1))
95 | a3 = int(re.search(r"\+(\d+)/", label).group(1))
96 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
97 | a2_next = -1
98 | else:
99 | a2_next = int(
100 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
101 | # Accent phrase boundary
102 | if a3 == 1 and a2_next == 1:
103 | text += ' '
104 | # Falling
105 | elif a1 == 0 and a2_next == a2 + 1:
106 | text += '↓'
107 | # Rising
108 | elif a2 == 1 and a2_next == 2:
109 | text += '↑'
110 | if i < len(marks):
111 | text += unidecode(marks[i]).replace(' ', '')
112 | return text
113 |
114 |
115 | def get_real_sokuon(text):
116 | for regex, replacement in _real_sokuon:
117 | text = re.sub(regex, replacement, text)
118 | return text
119 |
120 |
121 | def get_real_hatsuon(text):
122 | for regex, replacement in _real_hatsuon:
123 | text = re.sub(regex, replacement, text)
124 | return text
125 |
126 |
127 | def japanese_to_ipa(text):
128 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
129 | text = re.sub(
130 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
131 | text = get_real_sokuon(text)
132 | text = get_real_hatsuon(text)
133 | for regex, replacement in _romaji_to_ipa:
134 | text = re.sub(regex, replacement, text)
135 | return text
136 |
137 |
138 | def japanese_to_ipa2(text):
139 | text = japanese_to_romaji_with_accent(text).replace('...', '…')
140 | text = get_real_sokuon(text)
141 | text = get_real_hatsuon(text)
142 | for regex, replacement in _romaji_to_ipa2:
143 | text = re.sub(regex, replacement, text)
144 | return list(text)
145 |
146 |
147 | def japanese_to_ipa3(text):
148 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
149 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
150 | text = re.sub(
151 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text)
152 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
153 | return text
154 |
155 | if __name__ == '__main__':
156 | a = japanese_to_romaji_with_accent('こんにちは!はい、元気です。あなたは?')
157 | print(a)
158 |
--------------------------------------------------------------------------------
/text/mandarin.py:
--------------------------------------------------------------------------------
1 | import re
2 | from typing import Dict, List
3 | from pypinyin import lazy_pinyin, Style
4 | from .custom_pypinyin_dict import phrase_pinyin_data
5 | import jieba
6 | from .cn2an import an2cn
7 |
8 | # Load the custom pypinyin phrase dictionary data
9 | phrase_pinyin_data.load()
10 |
11 | # Punctuation normalization map
12 | PUNC_MAP: Dict[str, str] = {
13 | ":": ",",
14 | ";": ",",
15 | ",": ",",
16 | "。": ".",
17 | "!": "!",
18 | "?": "?",
19 | "\n": ".",
20 | "·": ",",
21 | "、": ",",
22 | "$": ".",
23 | "/": ",",
24 | "“": "'",
25 | "”": "'",
26 | '"': "'",
27 | "‘": "'",
28 | "’": "'",
29 | "(": "'",
30 | ")": "'",
31 | "(": "'",
32 | ")": "'",
33 | "《": "'",
34 | "》": "'",
35 | "【": "'",
36 | "】": "'",
37 | "[": "'",
38 | "]": "'",
39 | "—": "-",
40 | "~": "~",
41 | "「": "'",
42 | "」": "'",
43 | "『": "'",
44 | "』": "'",
45 | }
46 |
47 | # from GPT_SoVITS.text.zh_normalization.text_normlization
48 | PUNC_MAP.update({
49 | '/': '每',
50 | '①': '一',
51 | '②': '二',
52 | '③': '三',
53 | '④': '四',
54 | '⑤': '五',
55 | '⑥': '六',
56 | '⑦': '七',
57 | '⑧': '八',
58 | '⑨': '九',
59 | '⑩': '十',
60 | 'α': '阿尔法',
61 | 'β': '贝塔',
62 | 'γ': '伽玛',
63 | 'Γ': '伽玛',
64 | 'δ': '德尔塔',
65 | 'Δ': '德尔塔',
66 | 'ε': '艾普西龙',
67 | 'ζ': '捷塔',
68 | 'η': '依塔',
69 | 'θ': '西塔',
70 | 'Θ': '西塔',
71 | 'ι': '艾欧塔',
72 | 'κ': '喀帕',
73 | 'λ': '拉姆达',
74 | 'Λ': '拉姆达',
75 | 'μ': '缪',
76 | 'ν': '拗',
77 | 'ξ': '克西',
78 | 'Ξ': '克西',
79 | 'ο': '欧米克伦',
80 | 'π': '派',
81 | 'Π': '派',
82 | 'ρ': '肉',
83 | 'ς': '西格玛',
84 | 'σ': '西格玛',
85 | 'Σ': '西格玛',
86 | 'τ': '套',
87 | 'υ': '宇普西龙',
88 | 'φ': '服艾',
89 | 'Φ': '服艾',
90 | 'χ': '器',
91 | 'ψ': '普赛',
92 | 'Ψ': '普赛',
93 | 'ω': '欧米伽',
94 | 'Ω': '欧米伽',
95 | '+': '加',
96 | '-': '减',
97 | '×': '乘',
98 | '÷': '除',
99 | '=': '等',
100 |
101 | "嗯": "恩",
102 | "呣": "母"
103 | })
104 |
105 | PUNC_TABLE = str.maketrans(PUNC_MAP)
106 |
107 | # Number normalization pattern
108 | NUMBER_PATTERN: re.Pattern = re.compile(r'\d+(?:\.?\d+)?')
109 |
110 | # Convert Arabic numerals to Chinese characters
111 | def replace_number(match: re.Match) -> str:
112 | return an2cn(match.group())
113 |
114 | def normalize_number(text: str) -> str:
115 | return NUMBER_PATTERN.sub(replace_number, text)
116 |
117 | # Build the phone symbol inventory from a pinyin table (not used)
118 | def load_pinyin_symbols(path):
119 | pinyin_dict={}
120 | temp = []
121 | with open(path, "r", encoding='utf-8') as f:
122 | content = f.readlines()
123 | for line in content:
124 | cuts = line.strip().split(',')
125 | pinyin = cuts[0]
126 | phones = cuts[1].split(' ')
127 | pinyin_dict[pinyin] = phones
128 | temp.extend(phones)
129 | temp = list(set(temp))
130 | tone = []
131 | for phone in temp:
132 | for i in range(1, 6):
133 | phone2 = phone + str(i)
134 | tone.append(phone2)
135 | print(sorted(tone, key=lambda x: len(x)))
136 | return pinyin_dict
137 |
138 | def load_pinyin_dict(path: str) -> Dict[str, List[str]]:
139 | pinyin_dict = {}
140 | with open(path, "r", encoding='utf-8') as f:
141 | for line in f:
142 | key, value = line.strip().split(',', 1)
143 | pinyin_dict[key] = value.split()
144 | return pinyin_dict
145 |
146 | import os
147 | pinyin_dict = load_pinyin_dict(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cnm3', 'ds_CNM3.txt'))
148 | # pinyin_dict = load_pinyin_dict('text/cnm3/ds_CNM3.txt')
149 |
150 | def chinese_to_cnm3(text: str) -> List[str]:
151 | # Normalize punctuation and numbers
152 | text = text.translate(PUNC_TABLE)
153 | text = normalize_number(text)
154 | # Filter out special characters
155 | text = re.sub(r'[#&@“”^_|\\]', '', text)
156 |
157 | words = jieba.lcut(text, cut_all=False)
158 |
159 | phones = []
160 | for word in words:
161 | pinyin_list: List[str] = lazy_pinyin(word, style=Style.TONE3, neutral_tone_with_five=True)
162 | for pinyin in pinyin_list:
163 | if pinyin[-1].isdigit():
164 | tone = pinyin[-1]
165 | syllable = pinyin[:-1]
166 | phone = pinyin_dict[syllable]
167 | phones.extend([ph + tone for ph in phone])
168 | elif pinyin[-1].isalpha():
169 | pass
170 | else:
171 | phones.extend(pinyin)
172 |
173 | return phones
--------------------------------------------------------------------------------
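A minimal usage sketch of `chinese_to_cnm3`, assuming `jieba` and `pypinyin` are installed; importing `text.mandarin` also registers the custom pypinyin phrase dictionary via `phrase_pinyin_data.load()`:

```python
from text.mandarin import chinese_to_cnm3

# Numbers are rewritten as Chinese characters and punctuation is normalized
# before each syllable is split into CNM3 phones with its tone digit appended.
phones = chinese_to_cnm3('你好,今天是2024年。')
print(phones)  # e.g. ['n3', 'i3', 'h3', 'ao3', ',', ...] (exact split depends on ds_CNM3.txt)
```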
/text/symbols.py:
--------------------------------------------------------------------------------
1 | '''
2 | Defines the set of symbols used in text input to the model.
3 | '''
4 |
5 | # japanese_cleaners
6 | # _pad = '_'
7 | # _punctuation = ',.!?-'
8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ '
9 |
10 |
11 | '''# japanese_cleaners2
12 | _pad = '_'
13 | _punctuation = ',.!?-~…'
14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ '
15 | '''
16 |
17 |
18 | '''# korean_cleaners
19 | _pad = '_'
20 | _punctuation = ',.!?…~'
21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ '
22 | '''
23 |
24 | '''# chinese_cleaners
25 | _pad = '_'
26 | _punctuation = ',。!?—…'
27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '
28 | '''
29 |
30 | # # zh_ja_mixture_cleaners
31 | # _pad = '_'
32 | # _punctuation = ',.!?-~…'
33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ '
34 |
35 |
36 | '''# sanskrit_cleaners
37 | _pad = '_'
38 | _punctuation = '।'
39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ '
40 | '''
41 |
42 | '''# cjks_cleaners
43 | _pad = '_'
44 | _punctuation = ',.!?-~…'
45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ '
46 | '''
47 |
48 | '''# thai_cleaners
49 | _pad = '_'
50 | _punctuation = '.!? '
51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์'
52 | '''
53 |
54 | # # cjke_cleaners2
55 | _pad = '_'
56 | _punctuation = ',.!?-~…' + "'"
57 | _IPA_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ '
58 | _CNM3_letters = ['y1', 'y2', 'y3', 'y4', 'y5', 'n1', 'n2', 'n3', 'n4', 'n5', 'p1', 'p2', 'p3', 'p4', 'p5', 'x1', 'x2', 'x3', 'x4', 'x5', 'k1', 'k2', 'k3', 'k4', 'k5', 'l1', 'l2', 'l3', 'l4', 'l5', 'q1', 'q2', 'q3', 'q4', 'q5', 'w1', 'w2', 'w3', 'w4', 'w5', 'E1', 'E2', 'E3', 'E4', 'E5', 'b1', 'b2', 'b3', 'b4', 'b5', 'c1', 'c2', 'c3', 'c4', 'c5', 'z1', 'z2', 'z3', 'z4', 'z5', 'e1', 'e2', 'e3', 'e4', 'e5', 'f1', 'f2', 'f3', 'f4', 'f5', 's1', 's2', 's3', 's4', 's5', 'j1', 'j2', 'j3', 'j4', 'j5', 'o1', 'o2', 'o3', 'o4', 'o5', 'i1', 'i2', 'i3', 'i4', 'i5', 'd1', 'd2', 'd3', 'd4', 'd5', 'm1', 'm2', 'm3', 'm4', 'm5', 't1', 't2', 't3', 't4', 't5', 'h1', 'h2', 'h3', 'h4', 'h5', 'g1', 'g2', 'g3', 'g4', 'g5', 'v1', 'v2', 'v3', 'v4', 'v5', 'r1', 'r2', 'r3', 'r4', 'r5', 'a1', 'a2', 'a3', 'a4', 'a5', 'u1', 'u2', 'u3', 'u4', 'u5', 'I01', 'I02', 'I03', 'I04', 'I05', 'i01', 'i02', 'i03', 'i04', 'i05', 'uo1', 'uo2', 'uo3', 'uo4', 'uo5', 'o01', 'o02', 'o03', 'o04', 'o05', 'U01', 'U02', 'U03', 'U04', 'U05', 'v01', 'v02', 'v03', 'v04', 'v05', 'er1', 'er2', 'er3', 'er4', 'er5', 'A01', 'A02', 'A03', 'A04', 'A05', 'ai1', 'ai2', 'ai3', 'ai4', 'ai5', 'e01', 'e02', 'e03', 'e04', 'e05', 'sh1', 'sh2', 'sh3', 'sh4', 'sh5', 'an1', 'an2', 'an3', 'an4', 'an5', 'ou1', 'ou2', 'ou3', 'ou4', 'ou5', 'ch1', 'ch2', 'ch3', 'ch4', 'ch5', 'a01', 'a02', 'a03', 'a04', 'a05', 'N01', 'N02', 'N03', 'N04', 'N05', 'ao1', 'ao2', 'ao3', 'ao4', 'ao5', 've1', 've2', 've3', 've4', 've5', 'ir1', 'ir2', 'ir3', 'ir4', 'ir5', 'ng1', 'ng2', 'ng3', 'ng4', 'ng5', 'ua1', 'ua2', 'ua3', 'ua4', 'ua5', 'zh1', 'zh2', 'zh3', 'zh4', 'zh5', 'O01', 'O02', 'O03', 'O04', 'O05', 'ie1', 'ie2', 'ie3', 'ie4', 'ie5', 'E01', 'E02', 'E03', 'E04', 'E05', 'ia1', 'ia2', 'ia3', 'ia4', 'ia5', 'iE01', 'iE02', 'iE03', 'iE04', 'iE05', 'ang1', 'ang2', 'ang3', 'ang4', 'ang5', 'ng01', 'ng02', 'ng03', 'ng04', 'ng05', 'io01', 'io02', 'io03', 'io04', 'io05', 'iA01', 'iA02', 'iA03', 'iA04', 'iA05', 'uA01', 'uA02', 'uA03', 'uA04', 'uA05', 'ong1', 'ong2', 'ong3', 'ong4', 'ong5', 'oo01', 'oo02', 'oo03', 'oo04', 'oo05', 'uE01', 'uE02', 'uE03', 'uE04', 'uE05', 'vE01', 'vE02', 'vE03', 'vE04', 'vE05', 'ue01', 'ue02', 'ue03', 'ue04', 'ue05', 'ua01', 'ua02', 'ua03', 'ua04', 'ua05', 'iO01', 'iO02', 'iO03', 'iO04', 'iO05']
59 | _additional = ['', '']
60 | # _CNM3_letters = []
61 |
62 |
63 | '''# shanghainese_cleaners
64 | _pad = '_'
65 | _punctuation = ',.!?…'
66 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 '
67 | '''
68 |
69 | '''# chinese_dialect_cleaners
70 | _pad = '_'
71 | _punctuation = ',.!?~…─'
72 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ '
73 | '''
74 |
75 | # Export all symbols:
76 | symbols = [_pad] + list(_punctuation) + list(_IPA_letters) + _CNM3_letters + _additional
77 |
78 | # Special symbol ids
79 | SPACE_ID = symbols.index(" ")
80 |
--------------------------------------------------------------------------------
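The exported `symbols` list is the model's input vocabulary; `train.py` below uses `len(symbols)` to size the text embedding. A small sketch of looking symbols up by index (the `symbol_to_id` dict here is illustrative, not part of the repository):

```python
from text.symbols import SPACE_ID, symbols

print(len(symbols))             # vocabulary size passed to StableTTS in train.py
print(repr(symbols[SPACE_ID]))  # ' '

# Illustrative symbol-to-index lookup, as a text front end would build it.
symbol_to_id = {s: i for i, s in enumerate(symbols)}
print(symbol_to_id['a1'])
```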
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
3 |
4 | import torch
5 | import torch.optim as optim
6 | import torch.distributed as dist
7 | from torch.nn.parallel import DistributedDataParallel as DDP
8 | from torch.utils.data import DataLoader
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from tqdm import tqdm
12 | from dataclasses import asdict
13 |
14 | from datas.dataset import StableDataset, collate_fn
15 | from datas.sampler import DistributedBucketSampler
16 | from text import symbols
17 | from config import MelConfig, ModelConfig, TrainConfig
18 | from models.model import StableTTS
19 |
20 | from utils.scheduler import get_cosine_schedule_with_warmup
21 | from utils.load import continue_training
22 |
23 | torch.backends.cudnn.benchmark = True
24 |
25 | def setup(rank, world_size):
26 | os.environ['MASTER_ADDR'] = 'localhost'
27 | os.environ['MASTER_PORT'] = '12345'
28 | dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size)
29 |
30 | def cleanup():
31 | dist.destroy_process_group()
32 |
33 | def _init_config(model_config: ModelConfig, mel_config: MelConfig, train_config: TrainConfig):
34 |
35 | if not os.path.exists(train_config.model_save_path):
36 | print(f'Creating {train_config.model_save_path}')
37 | os.makedirs(train_config.model_save_path, exist_ok=True)
38 |
39 | def train(rank, world_size):
40 | setup(rank, world_size)
41 | torch.cuda.set_device(rank)
42 |
43 | model_config = ModelConfig()
44 | mel_config = MelConfig()
45 | train_config = TrainConfig()
46 |
47 | _init_config(model_config, mel_config, train_config)
48 |
49 | model = StableTTS(len(symbols), mel_config.n_mels, **asdict(model_config)).to(rank)
50 |
51 | model = DDP(model, device_ids=[rank])
52 |
53 | train_dataset = StableDataset(train_config.train_dataset_path, mel_config.hop_length)
54 | train_sampler = DistributedBucketSampler(train_dataset, train_config.batch_size, [32,300,400,500,600,700,800,900,1000], num_replicas=world_size, rank=rank)
55 | train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=4, pin_memory=True, collate_fn=collate_fn, persistent_workers=True)
56 |
57 | if rank == 0:
58 | writer = SummaryWriter(train_config.log_dir)
59 |
60 | optimizer = optim.AdamW(model.parameters(), lr=train_config.learning_rate)
61 | scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
62 |
63 | # load latest checkpoints if possible
64 | current_epoch = continue_training(train_config.model_save_path, model, optimizer)
65 |
66 | model.train()
67 | for epoch in range(current_epoch, train_config.num_epochs): # loop over the train_dataset multiple times
68 | train_dataloader.batch_sampler.set_epoch(epoch)
69 | if rank == 0:
70 | dataloader = tqdm(train_dataloader)
71 | else:
72 | dataloader = train_dataloader
73 |
74 | for batch_idx, datas in enumerate(dataloader):
75 | datas = [data.to(rank, non_blocking=True) for data in datas]
76 | x, x_lengths, y, y_lengths, z, z_lengths = datas
77 | optimizer.zero_grad()
78 | dur_loss, diff_loss, prior_loss, _ = model(x, x_lengths, y, y_lengths, z, z_lengths)
79 | loss = dur_loss + diff_loss + prior_loss
80 | loss.backward()
81 | optimizer.step()
82 | scheduler.step()
83 |
84 | if rank == 0 and batch_idx % train_config.log_interval == 0:
85 | steps = epoch * len(dataloader) + batch_idx
86 | writer.add_scalar("training/diff_loss", diff_loss.item(), steps)
87 | writer.add_scalar("training/dur_loss", dur_loss.item(), steps)
88 | writer.add_scalar("training/prior_loss", prior_loss.item(), steps)
89 | writer.add_scalar("learning_rate/learning_rate", scheduler.get_last_lr()[0], steps)
90 |
91 | if rank == 0 and epoch % train_config.save_interval == 0:
92 | torch.save(model.module.state_dict(), os.path.join(train_config.model_save_path, f'checkpoint_{epoch}.pt'))
93 | torch.save(optimizer.state_dict(), os.path.join(train_config.model_save_path, f'optimizer_{epoch}.pt'))
94 | print(f"Rank {rank}, Epoch {epoch}, Loss {loss.item()}")
95 |
96 | cleanup()
97 |
98 | torch.set_num_threads(1)
99 | torch.set_num_interop_threads(1)
100 |
101 | if __name__ == "__main__":
102 | world_size = torch.cuda.device_count()
103 | torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/utils/__init__.py
--------------------------------------------------------------------------------
/utils/audio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | import torchaudio
5 |
6 | class LinearSpectrogram(nn.Module):
7 | def __init__(self, n_fft, win_length, hop_length, pad, center, pad_mode):
8 | super().__init__()
9 |
10 | self.n_fft = n_fft
11 | self.win_length = win_length
12 | self.hop_length = hop_length
13 | self.pad = pad
14 | self.center = center
15 | self.pad_mode = pad_mode
16 |
17 | self.register_buffer("window", torch.hann_window(win_length))
18 |
19 | def forward(self, waveform: Tensor) -> Tensor:
20 | if waveform.ndim == 3:
21 | waveform = waveform.squeeze(1)
22 | waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (self.pad, self.pad), self.pad_mode).squeeze(1)
23 | spec = torch.stft(waveform, self.n_fft, self.hop_length, self.win_length, self.window, self.center, self.pad_mode, False, True, True)
24 | spec = torch.view_as_real(spec)
25 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
26 | return spec
27 |
28 |
29 | class LogMelSpectrogram(nn.Module):
30 | def __init__(self, sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, center, pad_mode, mel_scale):
31 | super().__init__()
32 | self.sample_rate = sample_rate
33 | self.n_fft = n_fft
34 | self.win_length = win_length
35 | self.hop_length = hop_length
36 | self.f_min = f_min
37 | self.f_max = f_max
38 | self.pad = pad
39 | self.n_mels = n_mels
40 | self.center = center
41 | self.pad_mode = pad_mode
42 | self.mel_scale = mel_scale
43 |
44 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, pad, center, pad_mode)
45 | self.mel_scale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min, f_max, (n_fft//2)+1, mel_scale, mel_scale)
46 |
47 | def compress(self, x: Tensor) -> Tensor:
48 | return torch.log(torch.clamp(x, min=1e-5))
49 |
50 | def decompress(self, x: Tensor) -> Tensor:
51 | return torch.exp(x)
52 |
53 | def forward(self, x: Tensor) -> Tensor:
54 | linear_spec = self.spectrogram(x)
55 | x = self.mel_scale(linear_spec)
56 | x = self.compress(x)
57 | return x
58 |
59 | def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Tensor:
60 | try:
61 | y, sr = torchaudio.load(audio_path)
62 | except Exception as e:
63 | print(str(e))
64 | return None
65 |
66 | y = y.to(device)  # .to() is not in-place; keep the returned tensor
67 | # Convert to mono
68 | if y.size(0) > 1:
69 | y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]
70 |
71 | # resample audio to target sample_rate
72 | if sr != target_sr:
73 | y = torchaudio.functional.resample(y, sr, target_sr)
74 | return y
--------------------------------------------------------------------------------
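A minimal usage sketch of `LogMelSpectrogram`, assuming the fields of the root `MelConfig` in `config.py` match the constructor arguments (as they do for the Vocos `MelConfig` further below) and that `example.wav` is a placeholder path:

```python
from dataclasses import asdict

from config import MelConfig
from utils.audio import LogMelSpectrogram, load_and_resample_audio

mel_config = MelConfig()
mel_extractor = LogMelSpectrogram(**asdict(mel_config))

audio = load_and_resample_audio('example.wav', mel_config.sample_rate)  # placeholder path
if audio is not None:
    mel = mel_extractor(audio)  # shape: [1, n_mels, frames]
    print(mel.shape)
```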
/utils/load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.nn.parallel import DistributedDataParallel as DDP
6 |
7 | def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int:
8 | """load the latest checkpoints and optimizers"""
9 | model_dict = {}
10 | optimizer_dict = {}
11 |
12 | # glob all the checkpoints in the directory
13 | for file in os.listdir(checkpoint_path):
14 | if file.endswith(".pt") and '_' in file:
15 | name, epoch_str = file.rsplit('_', 1)
16 | epoch = int(epoch_str.split('.')[0])
17 |
18 | if name.startswith("checkpoint"):
19 | model_dict[epoch] = file
20 | elif name.startswith("optimizer"):
21 | optimizer_dict[epoch] = file
22 |
23 | # get the largest epoch
24 | common_epochs = set(model_dict.keys()) & set(optimizer_dict.keys())
25 | if common_epochs:
26 | max_epoch = max(common_epochs)
27 | model_path = os.path.join(checkpoint_path, model_dict[max_epoch])
28 | optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch])
29 |
30 | # load model and optimizer
31 | model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
32 | optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu'))
33 |
34 | print(f'Resuming model and optimizer from epoch {max_epoch}')
35 | return max_epoch + 1
36 |
37 | else:
38 | # no matching optimizer state found: load only the latest model checkpoint
39 | if model_dict:
40 | model_path = os.path.join(checkpoint_path, model_dict[max(model_dict.keys())])
41 | model.module.load_state_dict(torch.load(model_path, map_location='cpu'))
42 |
43 | return 0
--------------------------------------------------------------------------------
/utils/mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121
4 | def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor:
5 | if max_length is None:
6 | max_length = length.max()
7 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
8 | return x.unsqueeze(0) < length.unsqueeze(1)
--------------------------------------------------------------------------------
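A minimal sketch of what `sequence_mask` returns:

```python
import torch

from utils.mask import sequence_mask

lengths = torch.tensor([2, 4])
print(sequence_mask(lengths))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])
```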
/vocoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/__init__.py
--------------------------------------------------------------------------------
/vocoders/ffgan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/ffgan/__init__.py
--------------------------------------------------------------------------------
/vocoders/ffgan/backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 |
5 |
6 | # DropPath copied from timm library
7 | def drop_path(
8 | x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
9 | ):
10 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
11 |
12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
16 | 'survival rate' as the argument.
17 |
18 | """ # noqa: E501
19 |
20 | if drop_prob == 0.0 or not training:
21 | return x
22 | keep_prob = 1 - drop_prob
23 | shape = (x.shape[0],) + (1,) * (
24 | x.ndim - 1
25 | ) # work with diff dim tensors, not just 2D ConvNets
26 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
27 | if keep_prob > 0.0 and scale_by_keep:
28 | random_tensor.div_(keep_prob)
29 | return x * random_tensor
30 |
31 |
32 | class DropPath(nn.Module):
33 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
34 |
35 | def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
36 | super(DropPath, self).__init__()
37 | self.drop_prob = drop_prob
38 | self.scale_by_keep = scale_by_keep
39 |
40 | def forward(self, x):
41 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
42 |
43 | def extra_repr(self):
44 | return f"drop_prob={round(self.drop_prob,3):0.3f}"
45 |
46 |
47 | class LayerNorm(nn.Module):
48 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
49 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
50 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
51 | with shape (batch_size, channels, height, width).
52 | """ # noqa: E501
53 |
54 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
55 | super().__init__()
56 | self.weight = nn.Parameter(torch.ones(normalized_shape))
57 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
58 | self.eps = eps
59 | self.data_format = data_format
60 | if self.data_format not in ["channels_last", "channels_first"]:
61 | raise NotImplementedError
62 | self.normalized_shape = (normalized_shape,)
63 |
64 | def forward(self, x):
65 | if self.data_format == "channels_last":
66 | return F.layer_norm(
67 | x, self.normalized_shape, self.weight, self.bias, self.eps
68 | )
69 | elif self.data_format == "channels_first":
70 | u = x.mean(1, keepdim=True)
71 | s = (x - u).pow(2).mean(1, keepdim=True)
72 | x = (x - u) / torch.sqrt(s + self.eps)
73 | x = self.weight[:, None] * x + self.bias[:, None]
74 | return x
75 |
76 |
77 | # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py
78 | class ConvNeXtBlock(nn.Module):
79 | r"""ConvNeXt Block. There are two equivalent implementations:
80 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
81 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
82 | We use (2) as we find it slightly faster in PyTorch
83 |
84 | Args:
85 | dim (int): Number of input channels.
86 | drop_path (float): Stochastic depth rate. Default: 0.0
87 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
88 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
89 | kernel_size (int): Kernel size for depthwise conv. Default: 7.
90 | dilation (int): Dilation for depthwise conv. Default: 1.
91 | """ # noqa: E501
92 |
93 | def __init__(
94 | self,
95 | dim: int,
96 | drop_path: float = 0.0,
97 | layer_scale_init_value: float = 1e-6,
98 | mlp_ratio: float = 4.0,
99 | kernel_size: int = 7,
100 | dilation: int = 1,
101 | ):
102 | super().__init__()
103 |
104 | self.dwconv = nn.Conv1d(
105 | dim,
106 | dim,
107 | kernel_size=kernel_size,
108 | padding=int(dilation * (kernel_size - 1) / 2),
109 | groups=dim,
110 | ) # depthwise conv
111 | self.norm = LayerNorm(dim, eps=1e-6)
112 | self.pwconv1 = nn.Linear(
113 | dim, int(mlp_ratio * dim)
114 | ) # pointwise/1x1 convs, implemented with linear layers
115 | self.act = nn.GELU()
116 | self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim)
117 | self.gamma = (
118 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
119 | if layer_scale_init_value > 0
120 | else None
121 | )
122 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
123 |
124 | def forward(self, x, apply_residual: bool = True):
125 | input = x
126 |
127 | x = self.dwconv(x)
128 | x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
129 | x = self.norm(x)
130 | x = self.pwconv1(x)
131 | x = self.act(x)
132 | x = self.pwconv2(x)
133 |
134 | if self.gamma is not None:
135 | x = self.gamma * x
136 |
137 | x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
138 | x = self.drop_path(x)
139 |
140 | if apply_residual:
141 | x = input + x
142 |
143 | return x
144 |
145 |
146 | class ConvNeXtEncoder(nn.Module):
147 | def __init__(
148 | self,
149 | input_channels: int = 3,
150 | depths: list[int] = [3, 3, 9, 3],
151 | dims: list[int] = [96, 192, 384, 768],
152 | drop_path_rate: float = 0.0,
153 | layer_scale_init_value: float = 1e-6,
154 | kernel_size: int = 7,
155 | ):
156 | super().__init__()
157 | assert len(depths) == len(dims)
158 |
159 | self.downsample_layers = nn.ModuleList()
160 | stem = nn.Sequential(
161 | nn.Conv1d(
162 | input_channels,
163 | dims[0],
164 | kernel_size=kernel_size,
165 | padding=kernel_size // 2,
166 | padding_mode="zeros",
167 | ),
168 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
169 | )
170 | self.downsample_layers.append(stem)
171 |
172 | for i in range(len(depths) - 1):
173 | mid_layer = nn.Sequential(
174 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
175 | nn.Conv1d(dims[i], dims[i + 1], kernel_size=1),
176 | )
177 | self.downsample_layers.append(mid_layer)
178 |
179 | self.stages = nn.ModuleList()
180 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
181 |
182 | cur = 0
183 | for i in range(len(depths)):
184 | stage = nn.Sequential(
185 | *[
186 | ConvNeXtBlock(
187 | dim=dims[i],
188 | drop_path=dp_rates[cur + j],
189 | layer_scale_init_value=layer_scale_init_value,
190 | kernel_size=kernel_size,
191 | )
192 | for j in range(depths[i])
193 | ]
194 | )
195 | self.stages.append(stage)
196 | cur += depths[i]
197 |
198 | self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
199 | self.apply(self._init_weights)
200 |
201 | def _init_weights(self, m):
202 | if isinstance(m, (nn.Conv1d, nn.Linear)):
203 | nn.init.trunc_normal_(m.weight, std=0.02)
204 | nn.init.constant_(m.bias, 0)
205 |
206 | def forward(
207 | self,
208 | x: torch.Tensor,
209 | ) -> torch.Tensor:
210 | for i in range(len(self.downsample_layers)):
211 | x = self.downsample_layers[i](x)
212 | x = self.stages[i](x)
213 |
214 | return self.norm(x)
215 |
--------------------------------------------------------------------------------
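A minimal shape-check sketch for `ConvNeXtEncoder` using the backbone settings from `model.py` below; the encoder keeps the frame count and only changes the channel dimension:

```python
import torch

from vocoders.ffgan.backbone import ConvNeXtEncoder

encoder = ConvNeXtEncoder(input_channels=128, depths=[3, 3, 9, 3], dims=[128, 256, 384, 512])
mel = torch.randn(1, 128, 100)  # (batch, n_mels, frames)
features = encoder(mel)
print(features.shape)           # torch.Size([1, 512, 100])
```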
/vocoders/ffgan/head.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from math import prod
3 | from typing import Callable
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn import Conv1d
10 | from torch.nn.utils.parametrizations import weight_norm
11 | from torch.nn.utils.parametrize import remove_parametrizations
12 | from torch.utils.checkpoint import checkpoint
13 |
14 |
15 | def init_weights(m, mean=0.0, std=0.01):
16 | classname = m.__class__.__name__
17 | if classname.find("Conv") != -1:
18 | m.weight.data.normal_(mean, std)
19 |
20 |
21 | def get_padding(kernel_size, dilation=1):
22 | return (kernel_size * dilation - dilation) // 2
23 |
24 |
25 | class ResBlock1(torch.nn.Module):
26 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
27 | super().__init__()
28 |
29 | self.convs1 = nn.ModuleList(
30 | [
31 | weight_norm(
32 | Conv1d(
33 | channels,
34 | channels,
35 | kernel_size,
36 | 1,
37 | dilation=dilation[0],
38 | padding=get_padding(kernel_size, dilation[0]),
39 | )
40 | ),
41 | weight_norm(
42 | Conv1d(
43 | channels,
44 | channels,
45 | kernel_size,
46 | 1,
47 | dilation=dilation[1],
48 | padding=get_padding(kernel_size, dilation[1]),
49 | )
50 | ),
51 | weight_norm(
52 | Conv1d(
53 | channels,
54 | channels,
55 | kernel_size,
56 | 1,
57 | dilation=dilation[2],
58 | padding=get_padding(kernel_size, dilation[2]),
59 | )
60 | ),
61 | ]
62 | )
63 | self.convs1.apply(init_weights)
64 |
65 | self.convs2 = nn.ModuleList(
66 | [
67 | weight_norm(
68 | Conv1d(
69 | channels,
70 | channels,
71 | kernel_size,
72 | 1,
73 | dilation=1,
74 | padding=get_padding(kernel_size, 1),
75 | )
76 | ),
77 | weight_norm(
78 | Conv1d(
79 | channels,
80 | channels,
81 | kernel_size,
82 | 1,
83 | dilation=1,
84 | padding=get_padding(kernel_size, 1),
85 | )
86 | ),
87 | weight_norm(
88 | Conv1d(
89 | channels,
90 | channels,
91 | kernel_size,
92 | 1,
93 | dilation=1,
94 | padding=get_padding(kernel_size, 1),
95 | )
96 | ),
97 | ]
98 | )
99 | self.convs2.apply(init_weights)
100 |
101 | def forward(self, x):
102 | for c1, c2 in zip(self.convs1, self.convs2):
103 | xt = F.silu(x)
104 | xt = c1(xt)
105 | xt = F.silu(xt)
106 | xt = c2(xt)
107 | x = xt + x
108 | return x
109 |
110 | def remove_parametrizations(self):
111 | for conv in self.convs1:
112 | remove_parametrizations(conv)
113 | for conv in self.convs2:
114 | remove_parametrizations(conv)
115 |
116 |
117 | class ParralelBlock(nn.Module):
118 | def __init__(
119 | self,
120 | channels: int,
121 | kernel_sizes: tuple[int] = (3, 7, 11),
122 | dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
123 | ):
124 | super().__init__()
125 |
126 | assert len(kernel_sizes) == len(dilation_sizes)
127 |
128 | self.blocks = nn.ModuleList()
129 | for k, d in zip(kernel_sizes, dilation_sizes):
130 | self.blocks.append(ResBlock1(channels, k, d))
131 |
132 | def forward(self, x):
133 | return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)
134 |
135 |
136 | class HiFiGANGenerator(nn.Module):
137 | def __init__(
138 | self,
139 | *,
140 | hop_length: int = 512,
141 | upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
142 | upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
143 | resblock_kernel_sizes: tuple[int] = (3, 7, 11),
144 | resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
145 | num_mels: int = 128,
146 | upsample_initial_channel: int = 512,
147 | use_template: bool = True,
148 | pre_conv_kernel_size: int = 7,
149 | post_conv_kernel_size: int = 7,
150 | post_activation: Callable = partial(nn.SiLU, inplace=True),
151 | ):
152 | super().__init__()
153 |
154 | assert (
155 | prod(upsample_rates) == hop_length
156 | ), f"hop_length must be {prod(upsample_rates)}"
157 |
158 | self.conv_pre = weight_norm(
159 | nn.Conv1d(
160 | num_mels,
161 | upsample_initial_channel,
162 | pre_conv_kernel_size,
163 | 1,
164 | padding=get_padding(pre_conv_kernel_size),
165 | )
166 | )
167 |
168 | self.num_upsamples = len(upsample_rates)
169 | self.num_kernels = len(resblock_kernel_sizes)
170 |
171 | self.noise_convs = nn.ModuleList()
172 | self.use_template = use_template
173 | self.ups = nn.ModuleList()
174 |
175 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
176 | c_cur = upsample_initial_channel // (2 ** (i + 1))
177 | self.ups.append(
178 | weight_norm(
179 | nn.ConvTranspose1d(
180 | upsample_initial_channel // (2**i),
181 | upsample_initial_channel // (2 ** (i + 1)),
182 | k,
183 | u,
184 | padding=(k - u) // 2,
185 | )
186 | )
187 | )
188 |
189 | if not use_template:
190 | continue
191 |
192 | if i + 1 < len(upsample_rates):
193 | stride_f0 = np.prod(upsample_rates[i + 1 :])
194 | self.noise_convs.append(
195 | Conv1d(
196 | 1,
197 | c_cur,
198 | kernel_size=stride_f0 * 2,
199 | stride=stride_f0,
200 | padding=stride_f0 // 2,
201 | )
202 | )
203 | else:
204 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
205 |
206 | self.resblocks = nn.ModuleList()
207 | for i in range(len(self.ups)):
208 | ch = upsample_initial_channel // (2 ** (i + 1))
209 | self.resblocks.append(
210 | ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
211 | )
212 |
213 | self.activation_post = post_activation()
214 | self.conv_post = weight_norm(
215 | nn.Conv1d(
216 | ch,
217 | 1,
218 | post_conv_kernel_size,
219 | 1,
220 | padding=get_padding(post_conv_kernel_size),
221 | )
222 | )
223 | self.ups.apply(init_weights)
224 | self.conv_post.apply(init_weights)
225 |
226 | def forward(self, x, template=None):
227 | x = self.conv_pre(x)
228 |
229 | for i in range(self.num_upsamples):
230 | x = F.silu(x, inplace=True)
231 | x = self.ups[i](x)
232 |
233 | if self.use_template:
234 | x = x + self.noise_convs[i](template)
235 |
236 | if self.training and getattr(self, "checkpointing", False):  # checkpointing is set externally (see model.py)
237 | x = checkpoint(
238 | self.resblocks[i],
239 | x,
240 | use_reentrant=False,
241 | )
242 | else:
243 | x = self.resblocks[i](x)
244 |
245 | x = self.activation_post(x)
246 | x = self.conv_post(x)
247 | x = torch.tanh(x)
248 |
249 | return x
250 |
251 | def remove_parametrizations(self):
252 | for up in self.ups:
253 | remove_parametrizations(up)
254 | for block in self.resblocks:
255 | block.remove_parametrizations()
256 | remove_parametrizations(self.conv_pre)
257 | remove_parametrizations(self.conv_post)
258 |
--------------------------------------------------------------------------------
/vocoders/ffgan/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .backbone import ConvNeXtEncoder
5 | from .head import HiFiGANGenerator
6 |
7 | config_dict = {
8 | "backbone": {
9 | # "input_channels": "${model.num_mels}",
10 | "input_channels": 128,
11 | "depths": [3, 3, 9, 3],
12 | "dims": [128, 256, 384, 512],
13 | "drop_path_rate": 0.2,
14 | "kernel_size": 7,
15 | },
16 | "head": {
17 | # "hop_length": "${model.hop_length}",
18 | "hop_length": 512,
19 | "upsample_rates": [8, 8, 2, 2, 2],
20 | "upsample_kernel_sizes": [16, 16, 4, 4, 4],
21 | "resblock_kernel_sizes": [3, 7, 11],
22 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
23 | "num_mels": 512, # consistent with the output of the backbone
24 | "upsample_initial_channel": 512,
25 | "use_template": False,
26 | "pre_conv_kernel_size": 13,
27 | "post_conv_kernel_size": 13,
28 | }
29 | }
30 |
31 | # download_link: https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt
32 | class FireflyGANBaseWrapper(nn.Module):
33 | def __init__(self, model_path):
34 | super().__init__()
35 | self.model = FireflyGANBase()
36 | self.model.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu'))
37 |
38 | self.model.eval()
39 |
40 | @ torch.inference_mode()
41 | def forward(self, x: torch.Tensor) -> torch.Tensor:
42 | return self.model(x)
43 |
44 | class FireflyGANBase(nn.Module):
45 | def __init__(self):
46 | super().__init__()
47 | self.backbone = ConvNeXtEncoder(**config_dict["backbone"])
48 | self.head = HiFiGANGenerator(**config_dict["head"])
49 |
50 | self.head.checkpointing = False
51 |
52 | @ torch.inference_mode()
53 | def forward(self, x: torch.Tensor) -> torch.Tensor:
54 | x = self.backbone(x)
55 | x = self.head(x)
56 |
57 | return x.squeeze(1)
--------------------------------------------------------------------------------
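A minimal inference sketch for the wrapper above, assuming the `firefly-gan-base-generator.ckpt` from the download link has been placed under `vocoders/pretrained/` (the exact path is an assumption, not enforced by the code):

```python
import torch

from vocoders.ffgan.model import FireflyGANBaseWrapper

vocoder = FireflyGANBaseWrapper('vocoders/pretrained/firefly-gan-base-generator.ckpt')

mel = torch.randn(1, 128, 100)  # (batch, n_mels, frames), matching the backbone config above
audio = vocoder(mel)            # (batch, frames * hop_length) waveform samples
print(audio.shape)              # torch.Size([1, 51200])
```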
/vocoders/ffgan/unify.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class UnifyGenerator(nn.Module):
6 | def __init__(
7 | self,
8 | backbone: nn.Module,
9 | head: nn.Module,
10 | vq: nn.Module | None = None,
11 | ):
12 | super().__init__()
13 |
14 | self.backbone = backbone
15 | self.head = head
16 | self.vq = vq
17 |
18 | def forward(self, x: torch.Tensor, template=None) -> torch.Tensor:
19 | x = self.backbone(x)
20 |
21 | if self.vq is not None:
22 | vq_result = self.vq(x)
23 | x = vq_result.z
24 |
25 | x = self.head(x, template=template)
26 |
27 | if x.ndim == 2:
28 | x = x[:, None, :]
29 |
30 | if self.vq is not None:
31 | return x, vq_result
32 |
33 | return x
34 |
35 | def encode(self, x: torch.Tensor) -> torch.Tensor:
36 | if self.vq is None:
37 | raise ValueError("VQ module is not present in the model.")
38 |
39 | x = self.backbone(x)
40 | vq_result = self.vq(x)
41 | return vq_result.codes
42 |
43 | def decode(self, codes: torch.Tensor, template=None) -> torch.Tensor:
44 | if self.vq is None:
45 | raise ValueError("VQ module is not present in the model.")
46 |
47 | x = self.vq.from_codes(codes)[0]
48 | x = self.head(x, template=template)
49 |
50 | if x.ndim == 2:
51 | x = x[:, None, :]
52 |
53 | return x
54 |
55 | def remove_parametrizations(self):
56 | if hasattr(self.backbone, "remove_parametrizations"):
57 | self.backbone.remove_parametrizations()
58 |
59 | if hasattr(self.head, "remove_parametrizations"):
60 | self.head.remove_parametrizations()
--------------------------------------------------------------------------------
/vocoders/pretrained/.keep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/pretrained/.keep
--------------------------------------------------------------------------------
/vocoders/vocos/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Vocos for StableTTS
4 |
5 | Modified from the official implementation of [Vocos](https://github.com/gemelo-ai/vocos/tree/main).
6 |
7 |
8 |
9 | ## Introduction
10 |
11 | Vocos is a fast neural vocoder designed to synthesize audio waveforms from acoustic features. Trained using a Generative Adversarial Network (GAN) objective, Vocos can generate waveforms in a single forward pass. Unlike other typical GAN-based vocoders, Vocos does not model audio samples in the time domain. Instead, it generates spectral coefficients, facilitating rapid audio reconstruction through inverse Fourier transform.
12 |
13 |
14 | ## Inference
15 |
16 | For detailed inference instructions, please refer to `inference.ipynb`
17 |
18 | ## Training
19 |
20 | Setting up and training your model with Vocos is straightforward. Follow these steps to get started:
21 |
22 | ### Preparing Your Data
23 |
24 | 1. **Configure Data Settings**: Update the `DataConfig` in `preprocess.py`. Specifically, adjust `audio_dir` to point to your collection of audio files.
25 |
26 | 2. **Run Preprocessing**: Run `preprocess.py`. This script will search (glob) for all audio files in the specified directory, resample them to the target `sample_rate` (modifiable in `config.py`), and generate a file list for training.
27 |
28 | ### Start training
29 |
30 | 1. **Adjust Training Configuration**: Edit `TrainConfig` in `config.py` to specify the file list path and tweak training hyperparameters to your needs.
31 |
32 | 2. **Start the Training Process**: Launch `train.py` to begin training your model.
33 |
34 | ### Experiment with Configurations
35 |
36 | Feel free to explore and adjust the settings in `config.py` to tune the hyperparameters of Vocos.
37 |
38 |
39 | ## References
40 |
41 | [Vocos](https://github.com/gemelo-ai/vocos/tree/main)
--------------------------------------------------------------------------------
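The introduction above describes the core idea: instead of predicting waveform samples directly, the model predicts spectral coefficients and reconstructs audio with an inverse STFT. A standalone sketch of that reconstruction step, with random placeholder tensors rather than real model outputs:

```python
import torch

n_fft, hop_length = 2048, 512
frames = 40

magnitude = torch.randn(1, n_fft // 2 + 1, frames).exp()  # positive magnitudes
phase = torch.randn(1, n_fft // 2 + 1, frames)             # phases in radians
spec = torch.polar(magnitude, phase)                        # complex spectral coefficients

# A single inverse STFT (overlap-add) turns the coefficients back into a waveform.
audio = torch.istft(spec, n_fft, hop_length, window=torch.hann_window(n_fft))
print(audio.shape)
```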
/vocoders/vocos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/vocos/__init__.py
--------------------------------------------------------------------------------
/vocoders/vocos/config.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 | @dataclass
4 | class MelConfig:
5 | sample_rate: int = 44100
6 | n_fft: int = 2048
7 | win_length: int = 2048
8 | hop_length: int = 512
9 | f_min: float = 0.0
10 | f_max: float | None = None
11 | pad: int = 0
12 | n_mels: int = 128
13 | center: bool = False
14 | pad_mode: str = "reflect"
15 | mel_scale: str = "slaney"
16 |
17 | def __post_init__(self):
18 | if self.pad == 0:
19 | self.pad = (self.n_fft - self.hop_length) // 2
20 |
21 | @dataclass
22 | class VocosConfig:
23 | input_channels: int = 128
24 | dim: int = 768
25 | intermediate_dim: int = 2048
26 | num_layers: int = 12
27 |
28 | @dataclass
29 | class TrainConfig:
30 | train_dataset_path: str = './filelists/filelist.txt'
31 | test_dataset_path: str = './filelists/filelist.txt'
32 | batch_size: int = 32
33 | learning_rate: float = 1e-4
34 | num_epochs: int = 10000
35 | model_save_path: str = './checkpoints'
36 | log_dir: str = './runs'
37 | log_interval: int = 64
38 | warmup_steps: int = 200
39 |
40 | segment_size = 20480
41 | mel_loss_factor = 15
--------------------------------------------------------------------------------
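A small sketch of the "Adjust Training Configuration" step from the README, assuming it is run from the `vocoders/vocos` directory; the dataclasses above can be edited in place or overridden at construction time (the values shown are arbitrary examples, not recommendations):

```python
from config import MelConfig, TrainConfig, VocosConfig

mel_config = MelConfig()                  # pad is derived in __post_init__
vocos_config = VocosConfig(num_layers=8)  # example override
train_config = TrainConfig(
    train_dataset_path='./filelists/filelist.txt',
    batch_size=16,                        # example override
)
print(mel_config.pad, vocos_config.num_layers, train_config.batch_size)
```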
/vocoders/vocos/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import asdict
3 | import torch
4 | import torchaudio
5 | from torch.utils.data import Dataset
6 |
7 | from utils.audio import LogMelSpectrogram
8 | from config import MelConfig
9 |
10 | class VocosDataset(Dataset):
11 | def __init__(self, filelist_path, segment_size: int, mel_config: MelConfig):
12 | self.filelist_path = filelist_path
13 | self.segment_size = segment_size
14 | self.sample_rate = mel_config.sample_rate
15 | self.mel_extractor = LogMelSpectrogram(**asdict(mel_config))
16 |
17 | self.filelist = self._load_filelist(filelist_path)
18 |
19 | def _load_filelist(self, filelist_path):
20 | if os.path.isdir(filelist_path):
21 | print('scanning dir to get audio files')
22 | filelist = find_audio_files(filelist_path)
23 | else:
24 | with open(filelist_path, 'r', encoding='utf-8') as f:
25 | filelist = [line.strip() for line in f if os.path.exists(line.strip())]
26 | return filelist
27 |
28 | def __len__(self):
29 | return len(self.filelist)
30 |
31 | def __getitem__(self, idx):
32 | audio = load_and_pad_audio(self.filelist[idx], self.sample_rate, self.segment_size)
33 | start_index = torch.randint(0, audio.size(-1) - self.segment_size + 1, (1,)).item()
34 | audio = audio[:, start_index:start_index + self.segment_size] # shape: [1, segment_size]
35 | mel = self.mel_extractor(audio).squeeze(0) # shape: [n_mels, segment_size // hop_length]
36 | return audio, mel
37 |
38 | def load_and_pad_audio(audio_path, target_sr, segment_size):
39 | y, sr = torchaudio.load(audio_path)
40 | if y.size(0) > 1:
41 | y = y[0, :].unsqueeze(0)
42 | if sr != target_sr:
43 | y = torchaudio.functional.resample(y, sr, target_sr)
44 | if y.size(-1) < segment_size:
45 | y = torch.nn.functional.pad(y, (0, segment_size - y.size(-1)), "constant", 0)
46 | return y
47 |
48 | def find_audio_files(directory):
49 | audio_files = []
50 | valid_extensions = ('.wav', '.ogg', '.opus', '.mp3', '.flac')
51 |
52 | for root, dirs, files in os.walk(directory):
53 | for file in files:
54 | if file.endswith(valid_extensions):
55 | audio_files.append(os.path.join(root, file))
56 |
57 | return audio_files
--------------------------------------------------------------------------------
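A minimal usage sketch of `VocosDataset`, assuming it is run from the `vocoders/vocos` directory and that the filelist (or an audio directory) referenced by `TrainConfig` exists:

```python
from config import MelConfig, TrainConfig
from dataset import VocosDataset

mel_config = MelConfig()
train_config = TrainConfig()

dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config)
audio, mel = dataset[0]
print(audio.shape)  # [1, segment_size]
print(mel.shape)    # [n_mels, segment_size // hop_length]
```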
/vocoders/vocos/inference.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import torch\n",
10 | "import torchaudio\n",
11 | "from IPython.display import Audio, display\n",
12 | "\n",
13 | "from models.model import Vocos\n",
14 | "from utils.audio import LogMelSpectrogram\n",
15 | "from config import MelConfig, VocosConfig\n",
16 | "\n",
17 | "from pathlib import Path\n",
18 | "from dataclasses import asdict\n",
19 | "import random\n",
20 | "\n",
21 | "def load_and_resample_audio(audio_path, target_sr):\n",
22 | " y, sr = torchaudio.load(audio_path)\n",
23 | " if y.size(0) > 1:\n",
24 | " y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]\n",
25 | " if sr != target_sr:\n",
26 | " y = torchaudio.functional.resample(y, sr, target_sr)\n",
27 | " return y\n",
28 | "\n",
29 | "device = 'cpu'\n",
30 | "\n",
31 | "mel_config = MelConfig()\n",
32 | "vocos_config = VocosConfig()\n",
33 | "\n",
34 | "mel_extractor = LogMelSpectrogram(**asdict(mel_config))\n",
35 | "model = Vocos(vocos_config, mel_config).to(device)\n",
36 | "model.load_state_dict(torch.load('./checkpoints/generator_0.pt', map_location='cpu'))\n",
37 | "model.eval()\n",
38 | "\n",
39 | "audio_paths = list(Path('./audios').rglob('*.wav'))"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "audio_path = random.choice(audio_paths)\n",
49 | "with torch.inference_mode():\n",
50 | " audio = load_and_resample_audio(audio_path, mel_config.sample_rate).to(device)\n",
51 | " mel = mel_extractor(audio)\n",
52 | " recon_audio = model(mel)\n",
53 | "display(Audio(audio, rate=mel_config.sample_rate))\n",
54 | "display(Audio(recon_audio, rate=mel_config.sample_rate))"
55 | ]
56 | }
57 | ],
58 | "metadata": {
59 | "kernelspec": {
60 | "display_name": "lxn_vits",
61 | "language": "python",
62 | "name": "python3"
63 | },
64 | "language_info": {
65 | "codemirror_mode": {
66 | "name": "ipython",
67 | "version": 3
68 | },
69 | "file_extension": ".py",
70 | "mimetype": "text/x-python",
71 | "name": "python",
72 | "nbconvert_exporter": "python",
73 | "pygments_lexer": "ipython3",
74 | "version": "3.12.4"
75 | }
76 | },
77 | "nbformat": 4,
78 | "nbformat_minor": 2
79 | }
80 |
--------------------------------------------------------------------------------
/vocoders/vocos/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/vocos/models/__init__.py
--------------------------------------------------------------------------------
/vocoders/vocos/models/backbone.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from .module import ConvNeXtBlock
7 |
8 | class VocosBackbone(nn.Module):
9 | """
10 | Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
11 |
12 | Args:
13 | input_channels (int): Number of input features channels.
14 | dim (int): Hidden dimension of the model.
15 | intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
16 | num_layers (int): Number of ConvNeXtBlock layers.
17 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
18 | """
19 |
20 | def __init__(
21 | self,
22 | input_channels: int,
23 | dim: int,
24 | intermediate_dim: int,
25 | num_layers: int,
26 | layer_scale_init_value: Optional[float] = None,
27 | ):
28 | super().__init__()
29 | self.input_channels = input_channels
30 | self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
31 | self.norm = nn.LayerNorm(dim, eps=1e-6)
32 | layer_scale_init_value = layer_scale_init_value or 1 / num_layers
33 | self.convnext = nn.ModuleList(
34 | [
35 | ConvNeXtBlock(
36 | dim=dim,
37 | intermediate_dim=intermediate_dim,
38 | layer_scale_init_value=layer_scale_init_value,
39 | )
40 | for _ in range(num_layers)
41 | ]
42 | )
43 | self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
44 | self.apply(self._init_weights)
45 |
46 | def _init_weights(self, m):
47 | if isinstance(m, (nn.Conv1d, nn.Linear)):
48 | nn.init.trunc_normal_(m.weight, std=0.02)
49 | nn.init.constant_(m.bias, 0)
50 |
51 | def forward(self, x: torch.Tensor) -> torch.Tensor:
52 | x = self.embed(x)
53 | x = self.norm(x.transpose(1, 2)).transpose(1, 2)
54 | for conv_block in self.convnext:
55 | x = conv_block(x)
56 | x = self.final_layer_norm(x.transpose(1, 2))
57 | return x
--------------------------------------------------------------------------------
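A minimal shape-check sketch for `VocosBackbone` with the `VocosConfig` defaults, run from the `vocoders/vocos` directory and assuming the `ConvNeXtBlock` in `module.py` preserves the frame count as in upstream Vocos; note that the output is transposed to (batch, frames, dim) for the downstream head:

```python
import torch

from models.backbone import VocosBackbone

backbone = VocosBackbone(input_channels=128, dim=768, intermediate_dim=2048, num_layers=12)
mel = torch.randn(1, 128, 40)  # (batch, n_mels, frames)
features = backbone(mel)
print(features.shape)          # expected: torch.Size([1, 40, 768])
```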
/vocoders/vocos/models/discriminator.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple
2 |
3 | import torch
4 | from torch import nn
5 | from torch import Tensor
6 | from torch.nn import Conv2d
7 | from torch.nn.utils.parametrizations import weight_norm
8 | from torchaudio.transforms import Spectrogram
9 |
10 |
11 | class MultiPeriodDiscriminator(nn.Module):
12 | def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11)):
13 | super().__init__()
14 | self.discriminators = nn.ModuleList([DiscriminatorP(period=p) for p in periods])
15 |
16 | def forward(self, y: Tensor, y_hat: Tensor):
17 | y_d_rs = []
18 | y_d_gs = []
19 | fmap_rs = []
20 | fmap_gs = []
21 | for d in self.discriminators:
22 | y_d_r, fmap_r = d(y)
23 | y_d_g, fmap_g = d(y_hat)
24 | y_d_rs.append(y_d_r)
25 | fmap_rs.append(fmap_r)
26 | y_d_gs.append(y_d_g)
27 | fmap_gs.append(fmap_g)
28 |
29 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
30 |
31 |
32 | class DiscriminatorP(nn.Module):
33 | def __init__(
34 | self,
35 | period: int,
36 | in_channels: int = 1,
37 | kernel_size: int = 5,
38 | stride: int = 3,
39 | lrelu_slope: float = 0.1,
40 | ):
41 | super().__init__()
42 | self.period = period
43 | self.convs = nn.ModuleList(
44 | [
45 | weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
46 | weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
47 | weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
48 | weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))),
49 | weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))),
50 | ]
51 | )
52 |
53 | self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
54 | self.lrelu_slope = lrelu_slope
55 |
56 | def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]:
57 | fmap = []
58 | # 1d to 2d
59 | b, c, t = x.shape
60 | if t % self.period != 0: # pad first
61 | n_pad = self.period - (t % self.period)
62 | x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
63 | t = t + n_pad
64 | x = x.view(b, c, t // self.period, self.period)
65 |
66 | for i, l in enumerate(self.convs):
67 | x = l(x)
68 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope, inplace=True)
69 | if i > 0:
70 | fmap.append(x)
71 | x = self.conv_post(x)
72 | fmap.append(x)
73 | x = torch.flatten(x, 1, -1)
74 |
75 | return x, fmap
76 |
77 |
78 | class MultiResolutionDiscriminator(nn.Module):
79 | def __init__(
80 | self,
81 | fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
82 | ):
83 | """
84 | Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
85 |
86 | Args:
87 | fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
88 | """
89 |
90 | super().__init__()
91 | self.discriminators = nn.ModuleList(
92 | [DiscriminatorR(window_length=w) for w in fft_sizes]
93 | )
94 |
95 | def forward(self, y: Tensor, y_hat: Tensor) -> Tuple[List[Tensor], List[Tensor], List[List[Tensor]], List[List[Tensor]]]:
96 | y_d_rs = []
97 | y_d_gs = []
98 | fmap_rs = []
99 | fmap_gs = []
100 |
101 | for d in self.discriminators:
102 | y_d_r, fmap_r = d(x=y)
103 | y_d_g, fmap_g = d(x=y_hat)
104 | y_d_rs.append(y_d_r)
105 | fmap_rs.append(fmap_r)
106 | y_d_gs.append(y_d_g)
107 | fmap_gs.append(fmap_g)
108 |
109 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs
110 |
111 |
112 | class DiscriminatorR(nn.Module):
113 | def __init__(
114 | self,
115 | window_length: int,
116 | channels: int = 32,
117 | hop_factor: float = 0.25,
118 | bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
119 | ):
120 | super().__init__()
121 | self.window_length = window_length
122 | self.hop_factor = hop_factor
123 | self.spec_fn = Spectrogram(
124 | n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
125 | )
126 | n_fft = window_length // 2 + 1
127 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
128 | self.bands = bands
129 | convs = lambda: nn.ModuleList(
130 | [
131 | weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
132 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
133 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
134 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
135 | weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
136 | ]
137 | )
138 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
139 |
140 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
141 |
142 | def spectrogram(self, x):
143 | x = x.squeeze(1)
144 |
145 | # x = x - x.mean(dim=-1, keepdims=True)
146 | # # Peak normalize the volume of input audio
147 | # x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
148 |
149 | x = self.spec_fn(x)
150 | x = torch.view_as_real(x)
151 | x = x.permute(0, 3, 2, 1) # b f t c -> b c t f
152 | # Split into bands
153 | x_bands = [x[..., b[0] : b[1]] for b in self.bands]
154 | return x_bands
155 |
156 | def forward(self, x: Tensor):
157 | x_bands = self.spectrogram(x)
158 | fmap = []
159 | x = []
160 | for band, stack in zip(x_bands, self.band_convs):
161 | for i, layer in enumerate(stack):
162 | band = layer(band)
163 | band = torch.nn.functional.leaky_relu(band, 0.1, inplace=True)
164 | if i > 0:
165 | fmap.append(band)
166 | x.append(band)
167 | x = torch.cat(x, dim=-1)
168 | x = self.conv_post(x)
169 | fmap.append(x)
170 |
171 | return x, fmap
--------------------------------------------------------------------------------
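
Both discriminators consume a real and a generated waveform of shape (batch, 1, samples) and return one logit tensor plus one list of feature maps per sub-discriminator, which is exactly the structure the loss functions in loss.py expect. A small shape-check sketch (batch size and sample count are arbitrary; run from vocoders/vocos/ so the models package is importable):

import torch

from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator

mpd = MultiPeriodDiscriminator()        # periods (2, 3, 5, 7, 11)
mrd = MultiResolutionDiscriminator()    # FFT sizes (2048, 1024, 512)

real = torch.randn(2, 1, 8192)          # (batch, 1, samples)
fake = torch.randn(2, 1, 8192)

y_d_rs, y_d_gs, fmap_rs, fmap_gs = mpd(real, fake)
print(len(y_d_rs), len(fmap_rs))        # 5 5 -> one entry per period

y_d_rs, y_d_gs, fmap_rs, fmap_gs = mrd(real, fake)
print(len(y_d_rs), len(fmap_rs))        # 3 3 -> one entry per FFT size
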
/vocoders/vocos/models/head.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class ISTFT(nn.Module):
6 | """
7 | Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
8 | windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
9 | See issue: https://github.com/pytorch/pytorch/issues/62323
10 | Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
11 | The NOLA constraint is met as we trim padded samples anyway.
12 |
13 | Args:
14 | n_fft (int): Size of Fourier transform.
15 | hop_length (int): The distance between neighboring sliding window frames.
16 | win_length (int): The size of window frame and STFT filter.
17 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
18 | """
19 |
20 | def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"):
21 | super().__init__()
22 | if padding not in ["center", "same"]:
23 | raise ValueError("Padding must be 'center' or 'same'.")
24 | self.padding = padding
25 | self.n_fft = n_fft
26 | self.hop_length = hop_length
27 | self.win_length = win_length
28 | window = torch.hann_window(win_length)
29 | self.register_buffer("window", window)
30 |
31 | def forward(self, spec: torch.Tensor) -> torch.Tensor:
32 | """
33 | Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
34 |
35 | Args:
36 | spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
37 | N is the number of frequency bins, and T is the number of time frames.
38 |
39 | Returns:
40 | Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
41 | """
42 | if self.padding == "center":
43 | # Fallback to pytorch native implementation
44 | return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True)
45 | elif self.padding == "same":
46 | pad = (self.win_length - self.hop_length) // 2
47 | else:
48 | raise ValueError("Padding must be 'center' or 'same'.")
49 |
50 | assert spec.dim() == 3, "Expected a 3D tensor as input"
51 | B, N, T = spec.shape
52 |
53 | # Inverse FFT
54 | ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
55 | ifft = ifft * self.window[None, :, None]
56 |
57 | # Overlap and Add
58 | output_size = (T - 1) * self.hop_length + self.win_length
59 | y = torch.nn.functional.fold(
60 | ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
61 | )[:, 0, 0, pad:-pad]
62 |
63 | # Window envelope
64 | window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
65 | window_envelope = torch.nn.functional.fold(
66 | window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length),
67 | ).squeeze()[pad:-pad]
68 |
69 | # Normalize
70 | assert (window_envelope > 1e-11).all()
71 | y = y / window_envelope
72 |
73 | return y
74 |
75 | class ISTFTHead(nn.Module):
76 | """
77 | ISTFT Head module for predicting STFT complex coefficients.
78 |
79 | Args:
80 | dim (int): Hidden dimension of the model.
81 | n_fft (int): Size of Fourier transform.
82 | hop_length (int): The distance between neighboring sliding window frames, which should align with
83 | the resolution of the input features.
84 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
85 | """
86 |
87 | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
88 | super().__init__()
89 | out_dim = n_fft + 2
90 | self.out = torch.nn.Linear(dim, out_dim)
91 | self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding)
92 |
93 | def forward(self, x: torch.Tensor) -> torch.Tensor:
94 | """
95 | Forward pass of the ISTFTHead module.
96 |
97 | Args:
98 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
99 | L is the sequence length, and H denotes the model dimension.
100 |
101 | Returns:
102 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
103 | """
104 | x = self.out(x).transpose(1, 2)
105 | mag, p = x.chunk(2, dim=1)
106 | mag = torch.exp(mag)
107 | mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes
108 |         # phase wrapping happens here: these two lines produce the real and imaginary parts
109 | x = torch.cos(p)
110 | y = torch.sin(p)
111 | # recalculating phase here does not produce anything new
112 | # only costs time
113 | # phase = torch.atan2(y, x)
114 | # S = mag * torch.exp(phase * 1j)
115 | # better directly produce the complex value
116 | S = mag * (x + 1j * y)
117 | audio = self.istft(S)
118 | return audio
--------------------------------------------------------------------------------
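
ISTFTHead projects backbone features to log-magnitude and phase, assembles a complex spectrogram, and inverts it with the custom "same"-padded ISTFT, so the output length is frames * hop_length. A minimal sketch with illustrative sizes (run from vocoders/vocos/); in the repo, dim, n_fft and hop_length come from VocosConfig and MelConfig:

import torch

from models.head import ISTFTHead

dim, n_fft, hop_length = 512, 1024, 256   # illustrative values
head = ISTFTHead(dim=dim, n_fft=n_fft, hop_length=hop_length)

features = torch.randn(2, 100, dim)       # (batch, frames, dim), e.g. VocosBackbone output
audio = head(features)                    # (batch, frames * hop_length) -> torch.Size([2, 25600])
print(audio.shape)
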
/vocoders/vocos/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from typing import List
4 | from dataclasses import asdict
5 |
6 | from utils.audio import LogMelSpectrogram
7 | from config import MelConfig
8 |
9 | # Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
10 | class MultiScaleMelSpectrogramLoss(nn.Module):
11 | def __init__(self, n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048]):
12 | super().__init__()
13 | assert len(n_mels) == len(window_lengths), "n_mels and window_lengths must have the same length"
14 | self.mel_transforms = nn.ModuleList(self._get_transforms(n_mels, window_lengths))
15 | self.loss_fn = nn.L1Loss()
16 |
17 | def _get_transforms(self, n_mels, window_lengths):
18 | transforms = []
19 | for n_mel, win_length in zip(n_mels, window_lengths):
20 | transform = LogMelSpectrogram(**asdict(MelConfig(n_mels=n_mel, n_fft=win_length, win_length=win_length, hop_length=win_length//4)))
21 | transforms.append(transform)
22 | return transforms
23 |
24 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
25 | return sum(self.loss_fn(mel_transform(x), mel_transform(y)) for mel_transform in self.mel_transforms)
26 |
27 | class SingleScaleMelSpectrogramLoss(nn.Module):
28 | def __init__(self):
29 | super().__init__()
30 | self.mel_transform = LogMelSpectrogram(**asdict(MelConfig()))
31 | self.loss_fn = nn.L1Loss()
32 | print('using single mel loss')
33 |
34 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
35 | return self.loss_fn(self.mel_transform(x), self.mel_transform(y))
36 |
37 | def feature_loss(fmap_r, fmap_g):
38 | loss = 0
39 | for dr, dg in zip(fmap_r, fmap_g):
40 | for rl, gl in zip(dr, dg):
41 | loss += torch.mean(torch.abs(rl - gl))
42 |
43 | return loss*2
44 |
45 | def discriminator_loss(disc_real_outputs, disc_generated_outputs):
46 | loss = 0
47 | r_losses = []
48 | g_losses = []
49 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
50 | r_loss = torch.mean((1-dr)**2)
51 | g_loss = torch.mean(dg**2)
52 | loss += (r_loss + g_loss)
53 | r_losses.append(r_loss.item())
54 | g_losses.append(g_loss.item())
55 |
56 | return loss, r_losses, g_losses
57 |
58 | def generator_loss(disc_outputs):
59 | loss = 0
60 | gen_losses = []
61 | for dg in disc_outputs:
62 | l = torch.mean((1-dg)**2)
63 | gen_losses.append(l)
64 | loss += l
65 |
66 | return loss, gen_losses
--------------------------------------------------------------------------------
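
feature_loss, discriminator_loss and generator_loss implement the usual least-squares GAN objective over lists of per-discriminator outputs, matching the return structure of MultiPeriodDiscriminator and MultiResolutionDiscriminator. A self-contained sketch with dummy tensors standing in for discriminator outputs (shapes are arbitrary):

import torch

from models.loss import feature_loss, generator_loss, discriminator_loss

# one logit tensor and one feature-map list per sub-discriminator (dummy data)
real_logits = [torch.randn(2, 10) for _ in range(3)]
fake_logits = [torch.randn(2, 10) for _ in range(3)]
fmaps_real = [[torch.randn(2, 8, 4) for _ in range(4)] for _ in range(3)]
fmaps_fake = [[torch.randn(2, 8, 4) for _ in range(4)] for _ in range(3)]

d_loss, r_losses, g_losses = discriminator_loss(real_logits, fake_logits)  # mean((1 - D(y))^2) + mean(D(G(x))^2)
g_adv, per_disc_losses = generator_loss(fake_logits)                       # mean((1 - D(G(x)))^2)
fm_loss = feature_loss(fmaps_real, fmaps_fake)                             # L1 over feature maps, scaled by 2
print(d_loss.item(), g_adv.item(), fm_loss.item())
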
/vocoders/vocos/models/model.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, asdict
2 |
3 | import torch
4 | from torch import nn
5 | from torch import Tensor
6 |
7 | from .head import ISTFTHead
8 | from .backbone import VocosBackbone
9 | from config import MelConfig, VocosConfig
10 |
11 | class Vocos(nn.Module):
12 | def __init__(self, vocos_config: VocosConfig, mel_config: MelConfig):
13 | super().__init__()
14 | self.backbone = VocosBackbone(**asdict(vocos_config))
15 | self.head = ISTFTHead(vocos_config.dim, mel_config.n_fft, mel_config.hop_length)
16 |
17 | def forward(self, x: Tensor) -> Tensor:
18 | x = self.backbone(x)
19 | x = self.head(x)
20 | return x
21 |
--------------------------------------------------------------------------------
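
Vocos simply chains the backbone and the ISTFT head, so a mel spectrogram goes in and a waveform comes out in one call. A minimal inference sketch, assuming it is run from vocoders/vocos/ and that the default VocosConfig and MelConfig are mutually consistent (train.py enforces input_channels == n_mels):

import torch

from models.model import Vocos
from config import MelConfig, VocosConfig

vocos_config, mel_config = VocosConfig(), MelConfig()
vocoder = Vocos(vocos_config, mel_config).eval()

# random stand-in for a log-mel spectrogram: (batch, input_channels, frames)
mel = torch.randn(1, vocos_config.input_channels, 200)
with torch.inference_mode():
    audio = vocoder(mel)   # (batch, frames * hop_length)
print(audio.shape)
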
/vocoders/vocos/models/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class ConvNeXtBlock(nn.Module):
6 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
7 |
8 | Args:
9 | dim (int): Number of input channels.
10 | intermediate_dim (int): Dimensionality of the intermediate layer.
11 |         layer_scale_init_value (float): Initial value for the layer scale. Values <= 0 disable
12 |             the learnable per-channel scaling (gamma is left as None).
13 | """
14 |
15 | def __init__(
16 | self,
17 | dim: int,
18 | intermediate_dim: int,
19 | layer_scale_init_value: float,
20 | ):
21 | super().__init__()
22 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
23 | self.norm = nn.LayerNorm(dim, eps=1e-6)
24 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
25 | self.act = nn.GELU()
26 | self.pwconv2 = nn.Linear(intermediate_dim, dim)
27 | self.gamma = (
28 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
29 | if layer_scale_init_value > 0
30 | else None
31 | )
32 |
33 | def forward(self, x: torch.Tensor) -> torch.Tensor:
34 | residual = x
35 | x = self.dwconv(x)
36 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
37 | x = self.norm(x)
38 | x = self.pwconv1(x)
39 | x = self.act(x)
40 | x = self.pwconv2(x)
41 | if self.gamma is not None:
42 | x = self.gamma * x
43 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
44 |
45 | x = residual + x
46 | return x
47 |
--------------------------------------------------------------------------------
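
ConvNeXtBlock is a residual block over (batch, channels, frames) tensors: a depthwise conv, LayerNorm, a pointwise MLP, optional layer scaling via gamma, and the residual add, so the input shape is preserved. A quick sketch with illustrative sizes (run from vocoders/vocos/):

import torch

from models.module import ConvNeXtBlock

block = ConvNeXtBlock(dim=512, intermediate_dim=1536, layer_scale_init_value=1 / 8)

x = torch.randn(2, 512, 100)        # (batch, channels, frames)
y = block(x)                        # same shape as the input
print(y.shape, block.gamma.shape)   # torch.Size([2, 512, 100]) torch.Size([512])
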
/vocoders/vocos/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import concurrent.futures
3 |
4 | from tqdm import tqdm
5 | from dataclasses import dataclass
6 |
7 | @dataclass
8 | class DataConfig:
9 |     audio_dirs = ['./datasets'] # directories to search for audio files
10 | filelist_path = './filelists/filelist.txt' # path to save filelist
11 | audio_formats = ('.wav', '.ogg', '.opus', '.mp3', '.flac')
12 |
13 | data_config = DataConfig()
14 |
15 | filelist_path = data_config.filelist_path
16 |
17 | os.makedirs(os.path.dirname(filelist_path), exist_ok=True)
18 |
19 | def find_audio_files(directory) -> list:
20 | audio_files = []
21 | valid_extensions = data_config.audio_formats
22 |
23 | for root, dirs, files in tqdm(os.walk(directory)):
24 | audio_files.extend(os.path.join(root, file) for file in files if file.endswith(valid_extensions))
25 |
26 | return audio_files
27 |
28 |
29 | def main():
30 | results = []
31 |
32 | with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor:
33 | futures = [executor.submit(find_audio_files, audio_dir) for audio_dir in data_config.audio_dirs]
34 |
35 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
36 | results.extend(future.result())
37 |
38 | # save filelist
39 | with open(filelist_path, 'w', encoding='utf-8') as f:
40 | f.writelines(f"{result}\n" for result in results)
41 |
42 | print(f"filelist has been saved to {filelist_path}")
43 |
44 | if __name__ == '__main__':
45 | main()
--------------------------------------------------------------------------------
/vocoders/vocos/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorboard
2 | tqdm
--------------------------------------------------------------------------------
/vocoders/vocos/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.distributed as dist
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 | from torch.utils.data import DataLoader
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | from tqdm import tqdm
14 | import itertools
15 | from dataclasses import asdict
16 |
17 | from models.model import Vocos
18 | from dataset import VocosDataset
19 | from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator
20 | from models.loss import feature_loss, generator_loss, discriminator_loss, MultiScaleMelSpectrogramLoss, SingleScaleMelSpectrogramLoss
21 | from config import MelConfig, VocosConfig, TrainConfig
22 | from utils.scheduler import get_cosine_schedule_with_warmup
23 | from utils.load import continue_training
24 |
25 | torch.backends.cudnn.benchmark = True
26 |
27 | def setup(rank, world_size):
28 | os.environ['MASTER_ADDR'] = 'localhost'
29 | os.environ['MASTER_PORT'] = '12345'
30 | dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size)
31 |
32 | def cleanup():
33 | dist.destroy_process_group()
34 |
35 | def _init_config(vocos_config: VocosConfig, mel_config: MelConfig, train_config: TrainConfig):
36 | if vocos_config.input_channels != mel_config.n_mels:
37 | raise ValueError("input_channels and n_mels must be equal.")
38 |
39 | if not os.path.exists(train_config.model_save_path):
40 | print(f'Creating {train_config.model_save_path}')
41 | os.makedirs(train_config.model_save_path, exist_ok=True)
42 |
43 | def train(rank, world_size):
44 | setup(rank, world_size)
45 | torch.cuda.set_device(rank)
46 |
47 | vocos_config = VocosConfig()
48 | mel_config = MelConfig()
49 | train_config = TrainConfig()
50 |
51 | _init_config(vocos_config, mel_config, train_config)
52 |
53 | generator = Vocos(vocos_config, mel_config).to(rank)
54 | mpd = MultiPeriodDiscriminator().to(rank)
55 | mrd = MultiResolutionDiscriminator().to(rank)
56 | loss_fn = MultiScaleMelSpectrogramLoss().to(rank)
57 | if rank == 0:
58 | print(f"Generator params: {sum(p.numel() for p in generator.parameters()) / 1e6}")
59 | print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters()) / 1e6}")
60 | print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters()) / 1e6}")
61 |
62 | generator = DDP(generator, device_ids=[rank])
63 | mpd = DDP(mpd, device_ids=[rank])
64 | mrd = DDP(mrd, device_ids=[rank])
65 |
66 | train_dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config)
67 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
68 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config.batch_size, num_workers=4, pin_memory=False, persistent_workers=True)
69 |
70 | if rank == 0:
71 | writer = SummaryWriter(train_config.log_dir)
72 |
73 | optimizer_g = optim.AdamW(generator.parameters(), lr=train_config.learning_rate)
74 | optimizer_d = optim.AdamW(itertools.chain(mpd.parameters(), mrd.parameters()), lr=train_config.learning_rate)
75 | scheduler_g = get_cosine_schedule_with_warmup(optimizer_g, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
76 | scheduler_d = get_cosine_schedule_with_warmup(optimizer_d, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader))
77 |
78 | # load latest checkpoints if possible
79 | current_epoch = continue_training(train_config.model_save_path, generator, mpd, mrd, optimizer_d, optimizer_g)
80 |
81 | generator.train()
82 | mpd.train()
83 | mrd.train()
84 | for epoch in range(current_epoch, train_config.num_epochs): # loop over the train_dataset multiple times
85 | train_dataloader.sampler.set_epoch(epoch)
86 | if rank == 0:
87 | dataloader = tqdm(train_dataloader)
88 | else:
89 | dataloader = train_dataloader
90 |
91 | for batch_idx, datas in enumerate(dataloader):
92 | datas = [data.to(rank, non_blocking=True) for data in datas]
93 | audios, mels = datas
94 | audios_fake = generator(mels).unsqueeze(1) # shape: [batch_size, 1, segment_size]
95 | optimizer_d.zero_grad()
96 |
97 | # MPD
98 |             y_df_hat_r, y_df_hat_g, _, _ = mpd(audios, audios_fake.detach())
99 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
100 |
101 | # MRD
102 |             y_ds_hat_r, y_ds_hat_g, _, _ = mrd(audios, audios_fake.detach())
103 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
104 |
105 | loss_disc_all = loss_disc_s + loss_disc_f
106 | loss_disc_all.backward()
107 |
108 | grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000)
109 | grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000)
110 | optimizer_d.step()
111 | scheduler_d.step()
112 |
113 | # generator
114 | optimizer_g.zero_grad()
115 | loss_mel = loss_fn(audios, audios_fake) * train_config.mel_loss_factor
116 |
117 | # MPD loss
118 |             y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(audios, audios_fake)
119 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
120 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
121 |
122 | # MRD loss
123 |             y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(audios, audios_fake)
124 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
125 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
126 |
127 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
128 | loss_gen_all.backward()
129 |
130 | grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000)
131 | optimizer_g.step()
132 | scheduler_g.step()
133 |
134 | if rank == 0 and batch_idx % train_config.log_interval == 0:
135 | steps = epoch * len(dataloader) + batch_idx
136 | writer.add_scalar("training/gen_loss_total", loss_gen_all, steps)
137 | writer.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
138 | writer.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps)
139 | writer.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps)
140 | writer.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps)
141 | writer.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps)
142 | writer.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
143 | writer.add_scalar("training/mel_loss", loss_mel.item(), steps)
144 | writer.add_scalar("grad_norm/grad_norm_mpd", grad_norm_mpd, steps)
145 | writer.add_scalar("grad_norm/grad_norm_mrd", grad_norm_mrd, steps)
146 | writer.add_scalar("grad_norm/grad_norm_g", grad_norm_g, steps)
147 | writer.add_scalar("learning_rate/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
148 | writer.add_scalar("learning_rate/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
149 |
150 | if rank == 0:
151 | torch.save(generator.module.state_dict(), os.path.join(train_config.model_save_path, f'generator_{epoch}.pt'))
152 | torch.save(mpd.module.state_dict(), os.path.join(train_config.model_save_path, f'mpd_{epoch}.pt'))
153 | torch.save(mrd.module.state_dict(), os.path.join(train_config.model_save_path, f'mrd_{epoch}.pt'))
154 | torch.save(optimizer_d.state_dict(), os.path.join(train_config.model_save_path, f'optimizerd_{epoch}.pt'))
155 | torch.save(optimizer_g.state_dict(), os.path.join(train_config.model_save_path, f'optimizerg_{epoch}.pt'))
156 | print(f"Rank {rank}, Epoch {epoch}, Loss {loss_gen_all.item()}")
157 |
158 | cleanup()
159 |
160 | torch.set_num_threads(1)
161 | torch.set_num_interop_threads(1)
162 |
163 | if __name__ == "__main__":
164 | world_size = torch.cuda.device_count()
165 | torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size)
--------------------------------------------------------------------------------
/vocoders/vocos/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/vocos/utils/__init__.py
--------------------------------------------------------------------------------
/vocoders/vocos/utils/audio.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | import torchaudio
5 |
6 | class LinearSpectrogram(nn.Module):
7 | def __init__(self, n_fft, win_length, hop_length, pad, center, pad_mode):
8 | super().__init__()
9 |
10 | self.n_fft = n_fft
11 | self.win_length = win_length
12 | self.hop_length = hop_length
13 | self.pad = pad
14 | self.center = center
15 | self.pad_mode = pad_mode
16 |
17 | self.register_buffer("window", torch.hann_window(win_length))
18 |
19 | def forward(self, waveform: Tensor) -> Tensor:
20 | if waveform.ndim == 3:
21 | waveform = waveform.squeeze(1)
22 | waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (self.pad, self.pad), self.pad_mode).squeeze(1)
23 | spec = torch.stft(waveform, self.n_fft, self.hop_length, self.win_length, self.window, self.center, self.pad_mode, False, True, True)
24 | spec = torch.view_as_real(spec)
25 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
26 | return spec
27 |
28 |
29 | class LogMelSpectrogram(nn.Module):
30 | def __init__(self, sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, center, pad_mode, mel_scale):
31 | super().__init__()
32 | self.sample_rate = sample_rate
33 | self.n_fft = n_fft
34 | self.win_length = win_length
35 | self.hop_length = hop_length
36 | self.f_min = f_min
37 | self.f_max = f_max
38 | self.pad = pad
39 | self.n_mels = n_mels
40 | self.center = center
41 | self.pad_mode = pad_mode
42 | self.mel_scale = mel_scale
43 |
44 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, pad, center, pad_mode)
45 | self.mel_scale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min, f_max, (n_fft//2)+1, mel_scale, mel_scale)
46 |
47 | def compress(self, x: Tensor) -> Tensor:
48 | return torch.log(torch.clamp(x, min=1e-5))
49 |
50 | def decompress(self, x: Tensor) -> Tensor:
51 | return torch.exp(x)
52 |
53 | def forward(self, x: Tensor) -> Tensor:
54 | linear_spec = self.spectrogram(x)
55 | x = self.mel_scale(linear_spec)
56 | x = self.compress(x)
57 | return x
58 |
59 | def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Tensor:
60 | try:
61 | y, sr = torchaudio.load(audio_path)
62 | except Exception as e:
63 | print(str(e))
64 | return None
65 |
66 |     y = y.to(device) # Tensor.to is not in-place, so keep the returned tensor
67 | # Convert to mono
68 | if y.size(0) > 1:
69 | y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]
70 |
71 | # resample audio to target sample_rate
72 | if sr != target_sr:
73 | y = torchaudio.functional.resample(y, sr, target_sr)
74 | return y
--------------------------------------------------------------------------------
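
LogMelSpectrogram mirrors the feature extraction used by the loss and the dataset: a linear STFT magnitude, a mel filterbank, and log compression with a 1e-5 floor. A short sketch, assuming it is run from vocoders/vocos/ and using 'example.wav' as a placeholder for any real audio file:

from dataclasses import asdict

from config import MelConfig
from utils.audio import LogMelSpectrogram, load_and_resample_audio

mel_config = MelConfig()
mel_transform = LogMelSpectrogram(**asdict(mel_config))

# 'example.wav' is a placeholder path; the loader returns a mono (1, samples) tensor or None on failure
audio = load_and_resample_audio('example.wav', mel_config.sample_rate)
if audio is not None:
    mel = mel_transform(audio)   # (1, n_mels, frames), log-compressed
    print(mel.shape)
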
/vocoders/vocos/utils/load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.nn.parallel import DistributedDataParallel as DDP
6 |
7 | def continue_training(checkpoint_path, generator: DDP, mpd: DDP, mrd: DDP, optimizer_d: optim.Optimizer, optimizer_g: optim.Optimizer) -> int:
8 | """load the latest checkpoints and optimizers"""
9 | generator_dict = {}
10 | mpd_dict = {}
11 | mrd_dict = {}
12 | optimizer_d_dict = {}
13 | optimizer_g_dict = {}
14 |
15 |     # glob all the checkpoints in the directory
16 | for file in os.listdir(checkpoint_path):
17 | if file.endswith(".pt"):
18 | name, epoch_str = file.rsplit('_', 1)
19 | epoch = int(epoch_str.split('.')[0])
20 |
21 | if name.startswith("generator"):
22 | generator_dict[epoch] = file
23 | elif name.startswith("mpd"):
24 | mpd_dict[epoch] = file
25 | elif name.startswith("mrd"):
26 | mrd_dict[epoch] = file
27 | elif name.startswith("optimizerd"):
28 | optimizer_d_dict[epoch] = file
29 | elif name.startswith("optimizerg"):
30 | optimizer_g_dict[epoch] = file
31 |
32 | # get the largest epoch
33 | common_epochs = set(generator_dict.keys()) & set(mpd_dict.keys()) & set(mrd_dict.keys()) & set(optimizer_d_dict.keys()) & set(optimizer_g_dict.keys())
34 | if common_epochs:
35 | max_epoch = max(common_epochs)
36 | generator_path = os.path.join(checkpoint_path, generator_dict[max_epoch])
37 | mpd_path = os.path.join(checkpoint_path, mpd_dict[max_epoch])
38 | mrd_path = os.path.join(checkpoint_path, mrd_dict[max_epoch])
39 | optimizer_d_path = os.path.join(checkpoint_path, optimizer_d_dict[max_epoch])
40 | optimizer_g_path = os.path.join(checkpoint_path, optimizer_g_dict[max_epoch])
41 |
42 | # load model and optimizer
43 | generator.module.load_state_dict(torch.load(generator_path, map_location='cpu'))
44 | mpd.module.load_state_dict(torch.load(mpd_path, map_location='cpu'))
45 | mrd.module.load_state_dict(torch.load(mrd_path, map_location='cpu'))
46 | optimizer_d.load_state_dict(torch.load(optimizer_d_path, map_location='cpu'))
47 | optimizer_g.load_state_dict(torch.load(optimizer_g_path, map_location='cpu'))
48 |
49 | print(f'resume model and optimizer from {max_epoch} epoch')
50 | return max_epoch + 1
51 |
52 | else:
53 | return 0
--------------------------------------------------------------------------------
/webui.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TMPDIR'] = './temps' # use a local temp folder in case the system default one lacks access permissions
3 | # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' # use the huggingface mirror for users who cannot reach huggingface
4 |
5 | import re
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 |
9 | import torch
10 | import gradio as gr
11 |
12 | from api import StableTTSAPI
13 |
14 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
15 |
16 | tts_model_path = './checkpoints/checkpoint_0.pt'
17 | vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt'
18 | vocoder_type = 'ffgan'
19 |
20 | model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type).to(device)
21 |
22 | @torch.inference_mode()
23 | def inference(text, ref_audio, language, step, temperature, length_scale, solver, cfg):
24 | text = remove_newlines_after_punctuation(text)
25 |
26 | if language == 'chinese':
27 | text = text.replace(' ', '')
28 |
29 | audio, mel = model.inference(text, ref_audio, language, step, temperature, length_scale, solver, cfg)
30 |
31 | max_val = torch.max(torch.abs(audio))
32 | if max_val > 1:
33 | audio = audio / max_val
34 |
35 | audio_output = (model.mel_config.sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio
36 | mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel
37 |
38 | return audio_output, mel_output
39 |
40 | def plot_mel_spectrogram(mel_spectrogram):
41 | plt.close() # prevent memory leak
42 | fig, ax = plt.subplots(figsize=(20, 8))
43 | ax.imshow(mel_spectrogram, aspect='auto', origin='lower')
44 | plt.axis('off')
45 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges
46 | return fig
47 |
48 | def remove_newlines_after_punctuation(text):
49 | pattern = r'([,。!?、“”‘’《》【】;:,.!?\'\"<>()\[\]{}])\n'
50 | return re.sub(pattern, r'\1', text)
51 |
52 | def main():
53 |
54 |     # gradio webui, reference: https://huggingface.co/spaces/fishaudio/fish-speech-1
55 | gui_title = 'StableTTS'
56 | gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3."""
57 | example_text = """你指尖跳动的电光,是我永恒不变的信仰。唯我超电磁炮永世长存!"""
58 |
59 | with gr.Blocks(theme=gr.themes.Base()) as demo:
60 | demo.load(None, None, js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}")
61 |
62 | with gr.Row():
63 | with gr.Column():
64 | gr.Markdown(f"# {gui_title}")
65 | gr.Markdown(gui_description)
66 |
67 | with gr.Row():
68 | with gr.Column():
69 | input_text_gr = gr.Textbox(
70 | label="Input Text",
71 | info="Put your text here",
72 | value=example_text,
73 | )
74 |
75 | ref_audio_gr = gr.Audio(
76 | label="Reference Audio",
77 | type="filepath"
78 | )
79 |
80 | language_gr = gr.Dropdown(
81 | label='Language',
82 | choices=list(model.supported_languages),
83 | value = 'chinese'
84 | )
85 |
86 | step_gr = gr.Slider(
87 | label='Step',
88 | minimum=1,
89 | maximum=100,
90 | value=25,
91 | step=1
92 | )
93 |
94 | temperature_gr = gr.Slider(
95 | label='Temperature',
96 | minimum=0,
97 | maximum=2,
98 | value=1,
99 | )
100 |
101 | length_scale_gr = gr.Slider(
102 | label='Length_Scale',
103 | minimum=0,
104 | maximum=5,
105 | value=1,
106 | )
107 |
108 | solver_gr = gr.Dropdown(
109 | label='ODE Solver',
110 | choices=['euler', 'midpoint', 'dopri5', 'rk4', 'implicit_adams', 'bosh3', 'fehlberg2', 'adaptive_heun'],
111 | value = 'dopri5'
112 | )
113 |
114 | cfg_gr = gr.Slider(
115 | label='CFG',
116 | minimum=0,
117 | maximum=10,
118 | value=3,
119 | )
120 |
121 | with gr.Column():
122 | mel_gr = gr.Plot(label="Mel Visual")
123 | audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True)
124 | tts_button = gr.Button("\U0001F3A7 Generate / 合成", elem_id="send-btn", visible=True, variant="primary")
125 |
126 | tts_button.click(inference, [input_text_gr, ref_audio_gr, language_gr, step_gr, temperature_gr, length_scale_gr, solver_gr, cfg_gr], outputs=[audio_gr, mel_gr])
127 |
128 | demo.queue()
129 | demo.launch(debug=True, show_api=True)
130 |
131 |
132 | if __name__ == '__main__':
133 | main()
--------------------------------------------------------------------------------
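
The same StableTTSAPI that powers the web UI can be called directly from Python. A minimal sketch that mirrors the int16 conversion done in inference() above; the checkpoint paths copy the defaults at the top of webui.py, while the reference audio path and the language tag are placeholders (the language must be one of model.supported_languages; the UI defaults to 'chinese'):

import numpy as np
import torch

from api import StableTTSAPI

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = StableTTSAPI('./checkpoints/checkpoint_0.pt',
                     './vocoders/pretrained/firefly-gan-base-generator.ckpt',
                     'ffgan').to(device)

text = 'Hello, this is a StableTTS test.'   # placeholder text
audio, mel = model.inference(text, './reference.wav', 'english', 25, 1.0, 1.0, 'dopri5', 3.0)

# same conversion as the web UI: assumes audio has a leading batch dimension
audio_int16 = (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)
print(model.mel_config.sample_rate, audio_int16.shape)
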