├── .gitignore ├── LICENSE ├── README.md ├── api.py ├── checkpoints └── .keep ├── config.py ├── datas ├── __init__.py ├── dataset.py └── sampler.py ├── figures └── structure.jpg ├── filelists └── example.txt ├── inference.ipynb ├── models ├── __init__.py ├── diffusion_transformer.py ├── duration_predictor.py ├── estimator.py ├── flow_matching.py ├── model.py ├── reference_encoder.py └── text_encoder.py ├── monotonic_align ├── __init__.py └── core.py ├── preprocess.py ├── recipes ├── AiSHELL3.py ├── BZNSYP_标贝女声.py ├── VCTK_huggingface.py ├── genshin_en_小虫哥ver.py ├── genshin_zh_小虫哥ver.py ├── hifi_tts.py └── libriTTS.py ├── requirements.txt ├── text ├── LICENSE ├── __init__.py ├── cleaners.py ├── cn2an │ ├── __init__.py │ ├── an2cn.py │ ├── cn2an.py │ ├── conf.py │ └── transform.py ├── cnm3 │ └── ds_CNM3.txt ├── custom_pypinyin_dict │ ├── __init__.py │ ├── cc_cedict_0.py │ ├── cc_cedict_1.py │ ├── cc_cedict_2.py │ ├── cc_cedict_3.py │ ├── genshin.py │ └── phrase_pinyin_data.py ├── english.py ├── japanese.py ├── mandarin.py └── symbols.py ├── train.py ├── utils ├── __init__.py ├── audio.py ├── load.py ├── mask.py └── scheduler.py ├── vocoders ├── __init__.py ├── ffgan │ ├── __init__.py │ ├── backbone.py │ ├── head.py │ ├── model.py │ └── unify.py ├── pretrained │ └── .keep └── vocos │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── dataset.py │ ├── inference.ipynb │ ├── models │ ├── __init__.py │ ├── backbone.py │ ├── discriminator.py │ ├── head.py │ ├── loss.py │ ├── model.py │ └── module.py │ ├── preprocess.py │ ├── requirements.txt │ ├── train.py │ └── utils │ ├── __init__.py │ ├── audio.py │ ├── load.py │ └── scheduler.py └── webui.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 KdaiP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # StableTTS 4 | 5 | Next-generation TTS model using flow-matching and DiT, inspired by [Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3). 6 | 7 | 8 |
9 | 10 | ## Introduction 11 | 12 | As the first open-source TTS model to combine flow-matching and DiT, **StableTTS** is a fast and lightweight TTS model for Chinese, English and Japanese speech generation. It has 31M parameters. 13 | 14 | ✨ **Huggingface demo:** [🤗](https://huggingface.co/spaces/KdaiP/StableTTS1.1) 15 | 16 | ## News 17 | 18 | 2024/10: A new autoregressive TTS model is coming soon... 19 | 20 | 2024/9: 🚀 **StableTTS V1.1 Released** ⭐ Audio quality is greatly improved ⭐ 21 | 22 | ⭐ **V1.1 Release Highlights:** 23 | 24 | - Fixed critical issues that caused the audio quality to be much lower than expected (mainly in the mel spectrogram and attention mask). 25 | - Introduced U-Net-like long skip connections to the DiT in the flow-matching decoder. 26 | - Used the cosine timestep scheduler from [Cosyvoice](https://github.com/FunAudioLLM/CosyVoice). 27 | - Added support for CFG (Classifier-Free Guidance). 28 | - Added support for the [FireflyGAN vocoder](https://github.com/fishaudio/vocoder/releases/tag/1.0.0). 29 | - Switched to [torchdiffeq](https://github.com/rtqichen/torchdiffeq) for ODE solvers. 30 | - Improved Chinese text frontend (partially based on [gpt-sovits2](https://github.com/RVC-Boss/GPT-SoVITS)). 31 | - Multilingual support (Chinese, English, Japanese) in a single checkpoint. 32 | - Increased parameters: 10M -> 31M. 33 | 34 | 35 | ## Pretrained models 36 | 37 | ### Text-To-Mel model 38 | 39 | Download the model and place it in the `./checkpoints` directory; it is then ready for inference, finetuning and the webui. 40 | 41 | | Model Name | Task Details | Dataset | Download Link | 42 | |:----------:|:------------:|:-------------:|:-------------:| 43 | | StableTTS | text to mel | 600 hours | [🤗](https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/StableTTS/checkpoint_0.pt)| 44 | 45 | ### Mel-To-Wav model 46 | 47 | Choose a vocoder (`vocos` or `firefly-gan`) and place it in the `./vocoders/pretrained` directory. 48 | 49 | | Model Name | Task Details | Dataset | Download Link | 50 | |:----------:|:------------:|:-------------:|:-------------:| 51 | | Vocos | mel to wav | 2k hours | [🤗](https://huggingface.co/KdaiP/StableTTS1.1/resolve/main/vocoders/vocos.pt)| 52 | | firefly-gan-base | mel to wav | HiFi-16kh | [download from fishaudio](https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt)| 53 | 54 | ## Installation 55 | 56 | 1. **Install PyTorch**: Follow the [official PyTorch guide](https://pytorch.org/get-started/locally/) to install PyTorch and torchaudio. We recommend the latest version (tested with PyTorch 2.4 and Python 3.12). 57 | 58 | 2. **Install Dependencies**: Run the following command to install the required Python packages: 59 | 60 | ```bash 61 | pip install -r requirements.txt 62 | ``` 63 | 64 | ## Inference 65 | 66 | For detailed inference instructions, please refer to `inference.ipynb`. A minimal usage sketch is also included at the end of this README. 67 | 68 | We also provide a Gradio-based webui; please refer to `webui.py`. 69 | 70 | ## Training 71 | 72 | StableTTS is designed to be trained easily. We only need text and audio pairs, without any speaker IDs or extra feature extraction. Here's how to get started: 73 | 74 | ### Preparing Your Data 75 | 76 | 1. **Generate Text and Audio pairs**: Generate the text and audio pair filelist in the format of `./filelists/example.txt`. Recipes for some open-source datasets can be found in `./recipes`. 77 | 78 | 2. **Run Preprocessing**: Adjust the `DataConfig` in `preprocess.py` to set your input and output paths, then run the script.
This will process the audio and text according to your list, outputting a JSON file with paths to mel features and phonemes. 79 | 80 | **Note: Process multilingual data separately by changing the `language` setting in `DataConfig`.** 81 | 82 | ### Start training 83 | 84 | 1. **Adjust Training Configuration**: In `config.py`, modify `TrainConfig` to set your file list path and adjust training parameters (such as `batch_size`) as needed. 85 | 86 | 2. **Start the Training Process**: Launch `train.py` to start training your model. 87 | 88 | Note: For finetuning, download the pretrained model and place it in the `model_save_path` directory specified in `TrainConfig`. The training script will automatically detect and load the pretrained checkpoint. 89 | 90 | ### (Optional) Vocoder training 91 | 92 | The `./vocoders/vocos` folder contains the training and finetuning code for the Vocos vocoder. 93 | 94 | For other types of vocoders, we recommend training with [fishaudio vocoder](https://github.com/fishaudio/vocoder): a uniform interface for developing various vocoders. We use the same spectrogram transform, so vocoders trained there are compatible with StableTTS. 95 | 96 | ## Model structure 97 | 98 |
99 | ![Model structure](./figures/structure.jpg) 100 | 101 | 102 | 103 | 104 |
105 | 106 | - We use the Diffusion Convolution Transformer block from [Hierspeech++](https://github.com/sh-lee-prml/HierSpeechpp), which combines the original [DiT](https://github.com/facebookresearch/DiT) and the [FFT](https://arxiv.org/pdf/1905.09263.pdf) (Feed-Forward Transformer from FastSpeech) for better prosody. 107 | 108 | - In the flow-matching decoder, we add a [FiLM layer](https://arxiv.org/abs/1709.07871) before each DiT block to condition the model on the timestep embedding. 109 | 110 | ## References 111 | 112 | The development of our models heavily relies on insights and code from various projects. We express our heartfelt thanks to the creators of the following: 113 | 114 | ### Direct Inspirations 115 | 116 | [Matcha TTS](https://github.com/shivammehta25/Matcha-TTS): Essential flow-matching code. 117 | 118 | [Grad TTS](https://github.com/huawei-noah/Speech-Backbones/tree/main/Grad-TTS): Diffusion model structure. 119 | 120 | [Stable Diffusion 3](https://stability.ai/news/stable-diffusion-3): Idea of combining flow-matching and DiT. 121 | 122 | [Vits](https://github.com/jaywalnut310/vits): Code style, MAS insights and DistributedBucketSampler. 123 | 124 | ### Additional References: 125 | 126 | [pflowtts-pytorch](https://github.com/p0p4k/pflowtts_pytorch): MAS code used in training. 127 | 128 | [Bert-VITS2](https://github.com/Plachtaa/VITS-fast-fine-tuning): numba version of MAS and modern PyTorch code for VITS. 129 | 130 | [fish-speech](https://github.com/fishaudio/fish-speech): dataclass usage, torchaudio mel-spectrogram transforms and the Gradio webui. 131 | 132 | [gpt-sovits](https://github.com/RVC-Boss/GPT-SoVITS): mel-style encoder for voice cloning. 133 | 134 | [coqui xtts](https://huggingface.co/spaces/coqui/xtts): Gradio webui. 135 | 136 | Chinese dictionary of DiffSinger: [Multi-langs_Dictionary](https://github.com/colstone/Multi-langs_Dictionary) and [atonyxu's fork](https://github.com/atonyxu/Multi-langs_Dictionary). 137 | 138 | ## TODO 139 | 140 | - [x] Release pretrained models. 141 | - [x] Support Japanese language. 142 | - [x] User-friendly preprocess and inference scripts. 143 | - [x] Enhance documentation and citations. 144 | - [x] Release multilingual checkpoint. 145 | 146 | ## Disclaimer 147 | 148 | Any organization or individual is prohibited from using any technology in this repo to generate or edit someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
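The snippet below is a minimal usage sketch of the inference API defined in `api.py` (shown next in this listing); it mirrors the `__main__` block of `api.py` and `inference.ipynb`. The checkpoint paths, the reference audio file and the sampling settings are placeholders/assumptions — download the models listed under *Pretrained models* first and point the paths at them.

```python
import torch
import torchaudio

from api import StableTTSAPI
from config import MelConfig

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# placeholder paths: put the downloaded checkpoints here (see "Pretrained models")
tts_model_path = './checkpoints/checkpoint_0.pt'
vocoder_model_path = './vocoders/pretrained/vocos.pt'

model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_name='vocos').to(device)

# args: text, reference audio for voice cloning (placeholder path), language, ODE steps,
# then optional temperature, length_scale, ODE solver and CFG strength
audio, mel = model.inference('你好，世界。', './reference.wav', 'chinese', 30,
                             temperature=1.0, length_scale=1.0, solver='dopri5', cfg=3.0)

torchaudio.save('output.wav', audio, MelConfig().sample_rate)
```

`solver` accepts the torchdiffeq solvers (`euler`, `midpoint` and `dopri5` are recommended in `inference.ipynb`), and `cfg` values of roughly 1–4 are suggested there.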
-------------------------------------------------------------------------------- /api.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from dataclasses import asdict 5 | 6 | from utils.audio import LogMelSpectrogram 7 | from config import ModelConfig, MelConfig 8 | from models.model import StableTTS 9 | 10 | from text import symbols 11 | from text import cleaned_text_to_sequence 12 | from text.mandarin import chinese_to_cnm3 13 | from text.english import english_to_ipa2 14 | from text.japanese import japanese_to_ipa2 15 | 16 | from datas.dataset import intersperse 17 | from utils.audio import load_and_resample_audio 18 | 19 | def get_vocoder(model_path, model_name='ffgan') -> nn.Module: 20 | if model_name == 'ffgan': 21 | # training or changing ffgan config is not supported in this repo 22 | # you can train your own model at https://github.com/fishaudio/vocoder 23 | from vocoders.ffgan.model import FireflyGANBaseWrapper 24 | vocoder = FireflyGANBaseWrapper(model_path) 25 | 26 | elif model_name == 'vocos': 27 | from vocoders.vocos.models.model import Vocos 28 | from config import VocosConfig, MelConfig 29 | vocoder = Vocos(VocosConfig(), MelConfig()) 30 | vocoder.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu')) 31 | vocoder.eval() 32 | 33 | else: 34 | raise NotImplementedError(f"Unsupported model: {model_name}") 35 | 36 | return vocoder 37 | 38 | class StableTTSAPI(nn.Module): 39 | def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'): 40 | super().__init__() 41 | 42 | self.mel_config = MelConfig() 43 | self.tts_model_config = ModelConfig() 44 | 45 | self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config)) 46 | 47 | # text to mel spectrogram 48 | self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config)) 49 | self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True)) 50 | self.tts_model.eval() 51 | 52 | # mel spectrogram to waveform 53 | self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name) 54 | self.vocoder_model.eval() 55 | 56 | self.g2p_mapping = { 57 | 'chinese': chinese_to_cnm3, 58 | 'japanese': japanese_to_ipa2, 59 | 'english': english_to_ipa2, 60 | } 61 | self.supported_languages = self.g2p_mapping.keys() 62 | 63 | @ torch.inference_mode() 64 | def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0): 65 | device = next(self.parameters()).device 66 | phonemizer = self.g2p_mapping.get(language) 67 | 68 | text = phonemizer(text) 69 | text = torch.tensor(intersperse(cleaned_text_to_sequence(text), item=0), dtype=torch.long, device=device).unsqueeze(0) 70 | text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device) 71 | 72 | ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device) 73 | ref_audio = self.mel_extractor(ref_audio) 74 | 75 | mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs'] 76 | audio_output = self.vocoder_model(mel_output) 77 | return audio_output.cpu(), mel_output.cpu() 78 | 79 | def get_params(self): 80 | tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6 81 | vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6 82 | return tts_param, vocoder_param 83 | 84 | if __name__ == '__main__': 85 | device = 'cuda' if 
torch.cuda.is_available() else 'cpu' 86 | tts_model_path = './checkpoints/checkpoint_0.pt' 87 | vocoder_model_path = './vocoders/pretrained/vocos.pt' 88 | 89 | model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos') 90 | model.to(device) 91 | 92 | text = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……' 93 | audio = './audio_1.wav' 94 | 95 | audio_output, mel_output = model.inference(text, audio, 'chinese', 10, solver='dopri5', cfg=3) 96 | print(audio_output.shape) 97 | print(mel_output.shape) 98 | 99 | import torchaudio 100 | torchaudio.save('output.wav', audio_output, MelConfig().sample_rate) 101 | 102 | 103 | -------------------------------------------------------------------------------- /checkpoints/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/checkpoints/.keep -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class MelConfig: 5 | sample_rate: int = 44100 6 | n_fft: int = 2048 7 | win_length: int = 2048 8 | hop_length: int = 512 9 | f_min: float = 0.0 10 | f_max: float = None 11 | pad: int = 0 12 | n_mels: int = 128 13 | center: bool = False 14 | pad_mode: str = "reflect" 15 | mel_scale: str = "slaney" 16 | 17 | def __post_init__(self): 18 | if self.pad == 0: 19 | self.pad = (self.n_fft - self.hop_length) // 2 20 | 21 | @dataclass 22 | class ModelConfig: 23 | hidden_channels: int = 256 24 | filter_channels: int = 1024 25 | n_heads: int = 4 26 | n_enc_layers: int = 3 27 | n_dec_layers: int = 6 28 | kernel_size: int = 3 29 | p_dropout: int = 0.1 30 | gin_channels: int = 256 31 | 32 | @dataclass 33 | class TrainConfig: 34 | train_dataset_path: str = 'filelists/filelist.json' 35 | test_dataset_path: str = 'filelists/filelist.json' # not used 36 | batch_size: int = 32 37 | learning_rate: float = 1e-4 38 | num_epochs: int = 10000 39 | model_save_path: str = './checkpoints' 40 | log_dir: str = './runs' 41 | log_interval: int = 16 42 | save_interval: int = 1 43 | warmup_steps: int = 200 44 | 45 | @dataclass 46 | class VocosConfig: 47 | input_channels: int = 128 48 | dim: int = 512 49 | intermediate_dim: int = 1536 50 | num_layers: int = 8 -------------------------------------------------------------------------------- /datas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/datas/__init__.py -------------------------------------------------------------------------------- /datas/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import json 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | from text import cleaned_text_to_sequence 9 | 10 | def intersperse(lst: list, item: int): 11 | """ 12 | putting a blank token between any two input tokens to improve pronunciation 13 | see https://github.com/jaywalnut310/glow-tts/issues/43 for more details 14 | """ 15 | result = [item] * (len(lst) * 2 + 1) 16 | result[1::2] = lst 17 | return result 18 | 19 | class StableDataset(Dataset): 20 | def __init__(self, filelist_path, hop_length): 21 | self.filelist_path = filelist_path 22 | self.hop_length = hop_length 23 | 24 | self._load_filelist(filelist_path) 25 | 26 | 
def _load_filelist(self, filelist_path): 27 | filelist, lengths = [], [] 28 | with open(filelist_path, 'r', encoding='utf-8') as f: 29 | for line in f: 30 | line = json.loads(line.strip()) 31 | filelist.append((line['mel_path'], line['phone'])) 32 | lengths.append(line['mel_length']) 33 | 34 | self.filelist = filelist 35 | self.lengths = lengths # length is used for DistributedBucketSampler 36 | 37 | def __len__(self): 38 | return len(self.filelist) 39 | 40 | def __getitem__(self, idx): 41 | mel_path, phone = self.filelist[idx] 42 | mel = torch.load(mel_path, map_location='cpu', weights_only=True) 43 | phone = torch.tensor(intersperse(cleaned_text_to_sequence(phone), 0), dtype=torch.long) 44 | return mel, phone 45 | 46 | def collate_fn(batch): 47 | texts = [item[1] for item in batch] 48 | mels = [item[0] for item in batch] 49 | mels_sliced = [random_slice_tensor(mel) for mel in mels] 50 | 51 | text_lengths = torch.tensor([text.size(-1) for text in texts], dtype=torch.long) 52 | mel_lengths = torch.tensor([mel.size(-1) for mel in mels], dtype=torch.long) 53 | mels_sliced_lengths = torch.tensor([mel_sliced.size(-1) for mel_sliced in mels_sliced], dtype=torch.long) 54 | 55 | # pad to the same length 56 | texts_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(texts), padding=0) 57 | mels_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels), padding=0) 58 | mels_sliced_padded = torch.nested.to_padded_tensor(torch.nested.nested_tensor(mels_sliced), padding=0) 59 | 60 | return texts_padded, text_lengths, mels_padded, mel_lengths, mels_sliced_padded, mels_sliced_lengths 61 | 62 | # random slice mel for reference encoder to prevent overfitting 63 | def random_slice_tensor(x: torch.Tensor): 64 | length = x.size(-1) 65 | if length < 12: 66 | return x 67 | segmnt_size = random.randint(length // 12, length // 3) 68 | start = random.randint(0, length - segmnt_size) 69 | return x[..., start : start + segmnt_size] -------------------------------------------------------------------------------- /datas/sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # reference: https://github.com/jaywalnut310/vits/blob/main/data_utils.py 4 | class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): 5 | """ 6 | Maintain similar input lengths in a batch. 7 | Length groups are specified by boundaries. 8 | Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. 9 | 10 | It removes samples which are not included in the boundaries. 11 | Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. 
12 | """ 13 | 14 | def __init__( 15 | self, 16 | dataset, 17 | batch_size, 18 | boundaries, 19 | num_replicas=None, 20 | rank=None, 21 | shuffle=True, 22 | ): 23 | super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) 24 | self.lengths = dataset.lengths 25 | self.batch_size = batch_size 26 | self.boundaries = boundaries 27 | 28 | self.buckets, self.num_samples_per_bucket = self._create_buckets() 29 | self.total_size = sum(self.num_samples_per_bucket) 30 | self.num_samples = self.total_size // self.num_replicas 31 | 32 | def _create_buckets(self): 33 | buckets = [[] for _ in range(len(self.boundaries) - 1)] 34 | for i in range(len(self.lengths)): 35 | length = self.lengths[i] 36 | idx_bucket = self._bisect(length) 37 | if idx_bucket != -1: 38 | buckets[idx_bucket].append(i) 39 | 40 | # from https://github.com/Plachtaa/VITS-fast-fine-tuning/blob/main/data_utils.py 41 | # avoid "integer division or modulo by zero" error for very small dataset 42 | # see https://github.com/Plachtaa/VITS-fast-fine-tuning/pull/228 for more details 43 | try: 44 | for i in range(len(buckets) - 1, 0, -1): 45 | if len(buckets[i]) == 0: 46 | buckets.pop(i) 47 | self.boundaries.pop(i + 1) 48 | assert all(len(bucket) > 0 for bucket in buckets) 49 | # When one bucket is not traversed 50 | except Exception as e: 51 | print('Bucket warning ', e) 52 | for i in range(len(buckets) - 1, -1, -1): 53 | if len(buckets[i]) == 0: 54 | buckets.pop(i) 55 | self.boundaries.pop(i + 1) 56 | 57 | num_samples_per_bucket = [] 58 | for i in range(len(buckets)): 59 | len_bucket = len(buckets[i]) 60 | total_batch_size = self.num_replicas * self.batch_size 61 | rem = ( 62 | total_batch_size - (len_bucket % total_batch_size) 63 | ) % total_batch_size 64 | num_samples_per_bucket.append(len_bucket + rem) 65 | return buckets, num_samples_per_bucket 66 | 67 | def __iter__(self): 68 | # deterministically shuffle based on epoch 69 | g = torch.Generator() 70 | g.manual_seed(self.epoch) 71 | 72 | indices = [] 73 | if self.shuffle: 74 | for bucket in self.buckets: 75 | indices.append(torch.randperm(len(bucket), generator=g).tolist()) 76 | else: 77 | for bucket in self.buckets: 78 | indices.append(list(range(len(bucket)))) 79 | 80 | batches = [] 81 | for i in range(len(self.buckets)): 82 | bucket = self.buckets[i] 83 | len_bucket = len(bucket) 84 | ids_bucket = indices[i] 85 | num_samples_bucket = self.num_samples_per_bucket[i] 86 | 87 | # add extra samples to make it evenly divisible 88 | rem = num_samples_bucket - len_bucket 89 | ids_bucket = ( 90 | ids_bucket 91 | + ids_bucket * (rem // len_bucket) 92 | + ids_bucket[: (rem % len_bucket)] 93 | ) 94 | 95 | # subsample 96 | ids_bucket = ids_bucket[self.rank :: self.num_replicas] 97 | 98 | # batching 99 | for j in range(len(ids_bucket) // self.batch_size): 100 | batch = [ 101 | bucket[idx] 102 | for idx in ids_bucket[ 103 | j * self.batch_size : (j + 1) * self.batch_size 104 | ] 105 | ] 106 | batches.append(batch) 107 | 108 | if self.shuffle: 109 | batch_ids = torch.randperm(len(batches), generator=g).tolist() 110 | batches = [batches[i] for i in batch_ids] 111 | self.batches = batches 112 | 113 | assert len(self.batches) * self.batch_size == self.num_samples 114 | return iter(self.batches) 115 | 116 | def _bisect(self, x, lo=0, hi=None): 117 | if hi is None: 118 | hi = len(self.boundaries) - 1 119 | 120 | if hi > lo: 121 | mid = (hi + lo) // 2 122 | if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: 123 | return mid 124 | elif x <= self.boundaries[mid]: 125 | 
return self._bisect(x, lo, mid) 126 | else: 127 | return self._bisect(x, mid + 1, hi) 128 | else: 129 | return -1 130 | 131 | def __len__(self): 132 | return self.num_samples // self.batch_size -------------------------------------------------------------------------------- /figures/structure.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/figures/structure.jpg -------------------------------------------------------------------------------- /filelists/example.txt: -------------------------------------------------------------------------------- 1 | ./audio1.wav|你好,世界。 2 | ./audio2.wav|Hello, world. -------------------------------------------------------------------------------- /inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "from IPython.display import Audio, display\n", 10 | "import torch\n", 11 | "\n", 12 | "from api import StableTTSAPI\n", 13 | "\n", 14 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 15 | "\n", 16 | "tts_model_path = './checkpoints/checkpoint_0.pt' # path to StableTTS checkpoint\n", 17 | "vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt' # path to vocoder checkpoint\n", 18 | "vocoder_type = 'ffgan' # ffgan or vocos\n", 19 | "\n", 20 | "# vocoder_model_path = './vocoders/pretrained/vocos.pt'\n", 21 | "# vocoder_type = 'vocos'\n", 22 | "\n", 23 | "model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type)\n", 24 | "model.to(device)\n", 25 | "\n", 26 | "tts_param, vocoder_param = model.get_params()\n", 27 | "print(f'tts_param: {tts_param}, vocoder_param: {vocoder_param}')" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "text = '你指尖跳动的电光,是我永恒不变的信仰。唯我超电磁炮永世长存!'\n", 37 | "ref_audio = './audio_1.wav'\n", 38 | "language = 'chinese' # support chinese, japanese and english\n", 39 | "solver = 'dopri5' # recommend using euler, midpoint or dopri5\n", 40 | "steps = 30\n", 41 | "cfg = 3 # recommend 1-4\n", 42 | "\n", 43 | "audio_output, mel_output = model.inference(text, ref_audio, language, steps, 1, 1, solver, cfg)\n", 44 | "\n", 45 | "display(Audio(ref_audio))\n", 46 | "display(Audio(audio_output, rate=model.mel_config.sample_rate))" 47 | ] 48 | } 49 | ], 50 | "metadata": { 51 | "kernelspec": { 52 | "display_name": "lxn_vits", 53 | "language": "python", 54 | "name": "python3" 55 | }, 56 | "language_info": { 57 | "codemirror_mode": { 58 | "name": "ipython", 59 | "version": 3 60 | }, 61 | "file_extension": ".py", 62 | "mimetype": "text/x-python", 63 | "name": "python", 64 | "nbconvert_exporter": "python", 65 | "pygments_lexer": "ipython3", 66 | "version": "3.11.8" 67 | } 68 | }, 69 | "nbformat": 4, 70 | "nbformat_minor": 2 71 | } 72 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/models/__init__.py -------------------------------------------------------------------------------- /models/diffusion_transformer.py: -------------------------------------------------------------------------------- 1 | # 
References: 2 | # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py 3 | # https://github.com/jaywalnut310/vits/blob/main/attentions.py 4 | # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class FFN(nn.Module): 11 | def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0): 12 | super().__init__() 13 | self.in_channels = in_channels 14 | self.out_channels = out_channels 15 | self.filter_channels = filter_channels 16 | self.kernel_size = kernel_size 17 | self.p_dropout = p_dropout 18 | self.gin_channels = gin_channels 19 | 20 | self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2) 21 | self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=kernel_size // 2) 22 | self.drop = nn.Dropout(p_dropout) 23 | self.act1 = nn.SiLU(inplace=True) 24 | 25 | def forward(self, x, x_mask): 26 | x = self.conv_1(x * x_mask) 27 | x = self.act1(x) 28 | x = self.drop(x) 29 | x = self.conv_2(x * x_mask) 30 | return x * x_mask 31 | 32 | class MultiHeadAttention(nn.Module): 33 | def __init__(self, channels, out_channels, n_heads, p_dropout=0.): 34 | super().__init__() 35 | assert channels % n_heads == 0 36 | 37 | self.channels = channels 38 | self.out_channels = out_channels 39 | self.n_heads = n_heads 40 | self.p_dropout = p_dropout 41 | 42 | self.k_channels = channels // n_heads 43 | self.conv_q = torch.nn.Conv1d(channels, channels, 1) 44 | self.conv_k = torch.nn.Conv1d(channels, channels, 1) 45 | self.conv_v = torch.nn.Conv1d(channels, channels, 1) 46 | 47 | # from https://nn.labml.ai/transformers/rope/index.html 48 | self.query_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) 49 | self.key_rotary_pe = RotaryPositionalEmbeddings(self.k_channels * 0.5) 50 | 51 | self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) 52 | self.drop = torch.nn.Dropout(p_dropout) 53 | 54 | torch.nn.init.xavier_uniform_(self.conv_q.weight) 55 | torch.nn.init.xavier_uniform_(self.conv_k.weight) 56 | torch.nn.init.xavier_uniform_(self.conv_v.weight) 57 | 58 | def forward(self, x, attn_mask=None): 59 | q = self.conv_q(x) 60 | k = self.conv_k(x) 61 | v = self.conv_v(x) 62 | 63 | x = self.attention(q, k, v, mask=attn_mask) 64 | 65 | x = self.conv_o(x) 66 | return x 67 | 68 | def attention(self, query, key, value, mask=None): 69 | b, d, t_s, t_t = (*key.size(), query.size(2)) 70 | query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) 71 | key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 72 | value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) 73 | 74 | query = self.query_rotary_pe(query) # [b, n_head, t, c // n_head] 75 | key = self.key_rotary_pe(key) 76 | 77 | output = F.scaled_dot_product_attention(query, key, value, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0) 78 | output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] 79 | return output 80 | 81 | # modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390 82 | class DiTConVBlock(nn.Module): 83 | """ 84 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
85 | """ 86 | def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0): 87 | super().__init__() 88 | self.norm1 = nn.LayerNorm(hidden_channels, elementwise_affine=False) 89 | self.attn = MultiHeadAttention(hidden_channels, hidden_channels, num_heads, p_dropout) 90 | self.norm2 = nn.LayerNorm(hidden_channels, elementwise_affine=False) 91 | self.mlp = FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout) 92 | self.adaLN_modulation = nn.Sequential( 93 | nn.Linear(gin_channels, hidden_channels) if gin_channels != hidden_channels else nn.Identity(), 94 | nn.SiLU(), 95 | nn.Linear(hidden_channels, 6 * hidden_channels, bias=True) 96 | ) 97 | 98 | def forward(self, x, c, x_mask): 99 | """ 100 | Args: 101 | x : [batch_size, channel, time] 102 | c : [batch_size, channel] 103 | x_mask : [batch_size, 1, time] 104 | return the same shape as x 105 | """ 106 | x = x * x_mask 107 | attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1) # shape: [batch_size, 1, time, time] 108 | attn_mask = torch.zeros_like(attn_mask).masked_fill(attn_mask == 0, -torch.finfo(x.dtype).max) 109 | 110 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1) # shape: [batch_size, channel, 1] 111 | x = x + gate_msa * self.attn(self.modulate(self.norm1(x.transpose(1,2)).transpose(1,2), shift_msa, scale_msa), attn_mask) * x_mask 112 | x = x + gate_mlp * self.mlp(self.modulate(self.norm2(x.transpose(1,2)).transpose(1,2), shift_mlp, scale_mlp), x_mask) 113 | 114 | # no condition version 115 | # x = x + self.attn(self.norm1(x.transpose(1,2)).transpose(1,2), attn_mask) 116 | # x = x + self.mlp(self.norm2(x.transpose(1,2)).transpose(1,2), x_mask) 117 | return x 118 | 119 | @staticmethod 120 | def modulate(x, shift, scale): 121 | return x * (1 + scale) + shift 122 | 123 | class RotaryPositionalEmbeddings(nn.Module): 124 | """ 125 | ## RoPE module 126 | 127 | Rotary encoding transforms pairs of features by rotating in the 2D plane. 128 | That is, it organizes the $d$ features as $\frac{d}{2}$ pairs. 129 | Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it 130 | by an angle depending on the position of the token. 
131 | """ 132 | 133 | def __init__(self, d: int, base: int = 10_000): 134 | r""" 135 | * `d` is the number of features $d$ 136 | * `base` is the constant used for calculating $\Theta$ 137 | """ 138 | super().__init__() 139 | 140 | self.base = base 141 | self.d = int(d) 142 | self.cos_cached = None 143 | self.sin_cached = None 144 | 145 | def _build_cache(self, x: torch.Tensor): 146 | r""" 147 | Cache $\cos$ and $\sin$ values 148 | """ 149 | # Return if cache is already built 150 | if self.cos_cached is not None and x.shape[0] <= self.cos_cached.shape[0]: 151 | return 152 | 153 | # Get sequence length 154 | seq_len = x.shape[0] 155 | 156 | # $\Theta = {\theta_i = 10000^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ 157 | theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device) 158 | 159 | # Create position indexes `[0, 1, ..., seq_len - 1]` 160 | seq_idx = torch.arange(seq_len, device=x.device).float().to(x.device) 161 | 162 | # Calculate the product of position index and $\theta_i$ 163 | idx_theta = torch.einsum("n,d->nd", seq_idx, theta) 164 | 165 | # Concatenate so that for row $m$ we have 166 | # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$ 167 | idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1) 168 | 169 | # Cache them 170 | self.cos_cached = idx_theta2.cos()[:, None, None, :] 171 | self.sin_cached = idx_theta2.sin()[:, None, None, :] 172 | 173 | def _neg_half(self, x: torch.Tensor): 174 | # $\frac{d}{2}$ 175 | d_2 = self.d // 2 176 | 177 | # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ 178 | return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1) 179 | 180 | def forward(self, x: torch.Tensor): 181 | """ 182 | * `x` is the Tensor at the head of a key or a query with shape `[seq_len, batch_size, n_heads, d]` 183 | """ 184 | # Cache $\cos$ and $\sin$ values 185 | x = x.permute(2, 0, 1, 3) # b h t d -> t b h d 186 | 187 | self._build_cache(x) 188 | 189 | # Split the features, we can choose to apply rotary embeddings only to a partial set of features. 
190 | x_rope, x_pass = x[..., : self.d], x[..., self.d :] 191 | 192 | # Calculate 193 | # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$ 194 | neg_half_x = self._neg_half(x_rope) 195 | 196 | x_rope = (x_rope * self.cos_cached[: x.shape[0]]) + (neg_half_x * self.sin_cached[: x.shape[0]]) 197 | 198 | return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3) # t b h d -> b h t d 199 | 200 | class Transpose(nn.Identity): 201 | """(N, T, D) -> (N, D, T)""" 202 | 203 | def forward(self, input: torch.Tensor) -> torch.Tensor: 204 | return input.transpose(1, 2) 205 | 206 | -------------------------------------------------------------------------------- /models/duration_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | # modified from https://github.com/jaywalnut310/vits/blob/main/models.py#L98 5 | class DurationPredictor(nn.Module): 6 | def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): 7 | super().__init__() 8 | 9 | self.in_channels = in_channels 10 | self.filter_channels = filter_channels 11 | self.kernel_size = kernel_size 12 | self.p_dropout = p_dropout 13 | self.gin_channels = gin_channels 14 | 15 | self.drop = nn.Dropout(p_dropout) 16 | self.conv1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) 17 | self.norm1 = nn.LayerNorm(filter_channels) 18 | self.conv2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) 19 | self.norm2 = nn.LayerNorm(filter_channels) 20 | self.proj = nn.Conv1d(filter_channels, 1, 1) 21 | 22 | self.cond = nn.Conv1d(gin_channels, in_channels, 1) 23 | 24 | def forward(self, x, x_mask, g): 25 | x = x.detach() 26 | x = x + self.cond(g.unsqueeze(2).detach()) 27 | x = self.conv1(x * x_mask) 28 | x = torch.relu(x) 29 | x = self.norm1(x.transpose(1,2)).transpose(1,2) 30 | x = self.drop(x) 31 | x = self.conv2(x * x_mask) 32 | x = torch.relu(x) 33 | x = self.norm2(x.transpose(1,2)).transpose(1,2) 34 | x = self.drop(x) 35 | x = self.proj(x * x_mask) 36 | return x * x_mask 37 | 38 | def duration_loss(logw, logw_, lengths): 39 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 40 | return loss -------------------------------------------------------------------------------- /models/estimator.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from models.diffusion_transformer import DiTConVBlock 7 | 8 | class DitWrapper(nn.Module): 9 | """ add FiLM layer to condition time embedding to DiT """ 10 | def __init__(self, hidden_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0, time_channels=0): 11 | super().__init__() 12 | self.time_fusion = FiLMLayer(hidden_channels, time_channels) 13 | self.block = DiTConVBlock(hidden_channels, filter_channels, num_heads, kernel_size, p_dropout, gin_channels) 14 | 15 | def forward(self, x, c, t, x_mask): 16 | x = self.time_fusion(x, t) * x_mask 17 | x = self.block(x, c, x_mask) 18 | return x 19 | 20 | class FiLMLayer(nn.Module): 21 | """ 22 | Feature-wise Linear Modulation (FiLM) layer 23 | Reference: https://arxiv.org/abs/1709.07871 24 | """ 25 | def __init__(self, in_channels, cond_channels): 26 | 27 | super(FiLMLayer, self).__init__() 28 | self.in_channels = in_channels 29 | self.film = nn.Conv1d(cond_channels, in_channels * 2, 1) 30 | 31 | def 
forward(self, x, c): 32 | gamma, beta = torch.chunk(self.film(c.unsqueeze(2)), chunks=2, dim=1) 33 | return gamma * x + beta 34 | 35 | class SinusoidalPosEmb(nn.Module): 36 | def __init__(self, dim): 37 | super().__init__() 38 | self.dim = dim 39 | assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even" 40 | 41 | def forward(self, x, scale=1000): 42 | if x.ndim < 1: 43 | x = x.unsqueeze(0) 44 | half_dim = self.dim // 2 45 | emb = math.log(10000) / (half_dim - 1) 46 | emb = torch.exp(torch.arange(half_dim, device=x.device).float() * -emb) 47 | emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) 48 | emb = torch.cat((emb.sin(), emb.cos()), dim=-1) 49 | return emb 50 | 51 | class TimestepEmbedding(nn.Module): 52 | def __init__(self, in_channels, out_channels, filter_channels): 53 | super().__init__() 54 | 55 | self.layer = nn.Sequential( 56 | nn.Linear(in_channels, filter_channels), 57 | nn.SiLU(inplace=True), 58 | nn.Linear(filter_channels, out_channels) 59 | ) 60 | 61 | def forward(self, x): 62 | return self.layer(x) 63 | 64 | # reference: https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/decoder.py 65 | class Decoder(nn.Module): 66 | def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, dropout=0.1, n_layers=1, n_heads=4, kernel_size=3, gin_channels=0, use_lsc=True): 67 | super().__init__() 68 | self.noise_channels = noise_channels 69 | self.cond_channels = cond_channels 70 | self.hidden_channels = hidden_channels 71 | self.out_channels = out_channels 72 | self.filter_channels = filter_channels 73 | self.use_lsc = use_lsc # whether to use unet-like long skip connection 74 | 75 | self.time_embeddings = SinusoidalPosEmb(hidden_channels) 76 | self.time_mlp = TimestepEmbedding(hidden_channels, hidden_channels, filter_channels) 77 | 78 | self.in_proj = nn.Conv1d(hidden_channels + noise_channels, hidden_channels, 1) # cat noise and encoder output as input 79 | self.blocks = nn.ModuleList([DitWrapper(hidden_channels, filter_channels, n_heads, kernel_size, dropout, gin_channels, hidden_channels) for _ in range(n_layers)]) 80 | self.final_proj = nn.Conv1d(hidden_channels, out_channels, 1) 81 | 82 | # prenet for encoder output 83 | self.cond_proj = nn.Sequential( 84 | nn.Conv1d(cond_channels, filter_channels, kernel_size, padding=kernel_size//2), 85 | nn.SiLU(inplace=True), 86 | nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2), # add about 3M params 87 | nn.SiLU(inplace=True), 88 | nn.Conv1d(filter_channels, hidden_channels, kernel_size, padding=kernel_size//2) 89 | ) 90 | 91 | if use_lsc: 92 | assert n_layers % 2 == 0 93 | self.n_lsc_layers = n_layers // 2 94 | self.lsc_layers = nn.ModuleList([nn.Conv1d(hidden_channels + hidden_channels, hidden_channels, kernel_size, padding = kernel_size // 2) for _ in range(self.n_lsc_layers)]) 95 | 96 | self.initialize_weights() 97 | 98 | def initialize_weights(self): 99 | for block in self.blocks: 100 | nn.init.constant_(block.block.adaLN_modulation[-1].weight, 0) 101 | nn.init.constant_(block.block.adaLN_modulation[-1].bias, 0) 102 | 103 | def forward(self, t, x, mask, mu, c): 104 | """Forward pass of the DiT model. 
105 | 106 | Args: 107 | t (torch.Tensor): timestep, shape (batch_size) 108 | x (torch.Tensor): noise, shape (batch_size, in_channels, time) 109 | mask (torch.Tensor): shape (batch_size, 1, time) 110 | mu (torch.Tensor): output of encoder, shape (batch_size, in_channels, time) 111 | c (torch.Tensor): shape (batch_size, gin_channels) 112 | 113 | Returns: 114 | _type_: _description_ 115 | """ 116 | 117 | t = self.time_mlp(self.time_embeddings(t)) 118 | mu = self.cond_proj(mu) 119 | 120 | x = torch.cat((x, mu), dim=1) 121 | x = self.in_proj(x) 122 | 123 | lsc_outputs = [] if self.use_lsc else None 124 | 125 | for idx, block in enumerate(self.blocks): 126 | # add long skip connection, see https://arxiv.org/pdf/2209.12152 for more details 127 | if self.use_lsc: 128 | if idx < self.n_lsc_layers: 129 | lsc_outputs.append(x) 130 | else: 131 | x = torch.cat((x, lsc_outputs.pop()), dim=1) 132 | x = self.lsc_layers[idx - self.n_lsc_layers](x) 133 | 134 | x = block(x, c, t, mask) 135 | 136 | output = self.final_proj(x * mask) 137 | 138 | return output * mask -------------------------------------------------------------------------------- /models/flow_matching.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import functools 6 | from torchdiffeq import odeint 7 | 8 | from models.estimator import Decoder 9 | 10 | # modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/flow_matching.py 11 | class CFMDecoder(torch.nn.Module): 12 | def __init__(self, noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels): 13 | super().__init__() 14 | self.noise_channels = noise_channels 15 | self.cond_channels = cond_channels 16 | self.hidden_channels = hidden_channels 17 | self.out_channels = out_channels 18 | self.filter_channels = filter_channels 19 | self.gin_channels = gin_channels 20 | self.sigma_min = 1e-4 21 | 22 | self.estimator = Decoder(noise_channels, cond_channels, hidden_channels, out_channels, filter_channels, p_dropout, n_layers, n_heads, kernel_size, gin_channels) 23 | 24 | @torch.inference_mode() 25 | def forward(self, mu, mask, n_timesteps, temperature=1.0, c=None, solver=None, cfg_kwargs=None): 26 | """Forward diffusion 27 | 28 | Args: 29 | mu (torch.Tensor): output of encoder 30 | shape: (batch_size, n_feats, mel_timesteps) 31 | mask (torch.Tensor): output_mask 32 | shape: (batch_size, 1, mel_timesteps) 33 | n_timesteps (int): number of diffusion steps 34 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 
35 | c (torch.Tensor, optional): speaker embedding 36 | shape: (batch_size, gin_channels) 37 | solver: see https://github.com/rtqichen/torchdiffeq for supported solvers 38 | cfg_kwargs: used for cfg inference 39 | 40 | Returns: 41 | sample: generated mel-spectrogram 42 | shape: (batch_size, n_feats, mel_timesteps) 43 | """ 44 | 45 | z = torch.randn_like(mu) * temperature 46 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 47 | 48 | # cfg control 49 | if cfg_kwargs is None: 50 | estimator = functools.partial(self.estimator, mask=mask, mu=mu, c=c) 51 | else: 52 | estimator = functools.partial(self.cfg_wrapper, mask=mask, mu=mu, c=c, cfg_kwargs=cfg_kwargs) 53 | 54 | trajectory = odeint(estimator, z, t_span, method=solver, rtol=1e-5, atol=1e-5) 55 | return trajectory[-1] 56 | 57 | # cfg inference 58 | def cfg_wrapper(self, t, x, mask, mu, c, cfg_kwargs): 59 | fake_speaker = cfg_kwargs['fake_speaker'].repeat(x.size(0), 1) 60 | fake_content = cfg_kwargs['fake_content'].repeat(x.size(0), 1, x.size(-1)) 61 | cfg_strength = cfg_kwargs['cfg_strength'] 62 | 63 | cond_output = self.estimator(t, x, mask, mu, c) 64 | uncond_output = self.estimator(t, x, mask, fake_content, fake_speaker) 65 | 66 | output = uncond_output + cfg_strength * (cond_output - uncond_output) 67 | return output 68 | 69 | def compute_loss(self, x1, mask, mu, c): 70 | """Computes diffusion loss 71 | 72 | Args: 73 | x1 (torch.Tensor): Target 74 | shape: (batch_size, n_feats, mel_timesteps) 75 | mask (torch.Tensor): target mask 76 | shape: (batch_size, 1, mel_timesteps) 77 | mu (torch.Tensor): output of encoder 78 | shape: (batch_size, n_feats, mel_timesteps) 79 | c (torch.Tensor, optional): speaker condition. 80 | 81 | Returns: 82 | loss: conditional flow matching loss 83 | y: conditional flow 84 | shape: (batch_size, n_feats, mel_timesteps) 85 | """ 86 | b, _, t = mu.shape 87 | 88 | # random timestep 89 | # use cosine timestep scheduler from cosyvoice: https://github.com/FunAudioLLM/CosyVoice/blob/main/cosyvoice/flow/flow_matching.py 90 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 91 | t = 1 - torch.cos(t * 0.5 * torch.pi) 92 | 93 | # sample noise p(x_0) 94 | z = torch.randn_like(x1) 95 | 96 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 97 | u = x1 - (1 - self.sigma_min) * z 98 | 99 | loss = F.mse_loss(self.estimator(t.squeeze(), y, mask, mu, c), u, reduction="sum") / (torch.sum(mask) * u.size(1)) 100 | return loss, y 101 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | import monotonic_align 6 | from models.text_encoder import TextEncoder 7 | from models.flow_matching import CFMDecoder 8 | from models.reference_encoder import MelStyleEncoder 9 | from models.duration_predictor import DurationPredictor, duration_loss 10 | from utils.mask import sequence_mask 11 | 12 | def convert_pad_shape(pad_shape): 13 | inverted_shape = pad_shape[::-1] 14 | pad_shape = [item for sublist in inverted_shape for item in sublist] 15 | return pad_shape 16 | 17 | def generate_path(duration, mask): 18 | b, t_x, t_y = mask.shape 19 | cum_duration = torch.cumsum(duration, 1) 20 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype, device=duration.device) 21 | 22 | cum_duration_flat = cum_duration.view(b * t_x) 23 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 24 | path = path.view(b, t_x, t_y) 25 | path = path 
- torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 26 | path = path * mask 27 | return path 28 | 29 | # modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py 30 | class StableTTS(nn.Module): 31 | def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels): 32 | super().__init__() 33 | 34 | self.n_vocab = n_vocab 35 | self.mel_channels = mel_channels 36 | 37 | self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels) 38 | self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=5, dropout=0.25) 39 | self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, 0.5, gin_channels) 40 | self.decoder = CFMDecoder(mel_channels, mel_channels, hidden_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels) 41 | 42 | # uncondition input for cfg 43 | self.fake_speaker = nn.Parameter(torch.zeros(1, gin_channels)) 44 | self.fake_content = nn.Parameter(torch.zeros(1, mel_channels, 1)) 45 | 46 | self.cfg_dropout = 0.2 47 | 48 | @torch.inference_mode() 49 | def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0, solver=None, cfg=1.0): 50 | """ 51 | Generates mel-spectrogram from text. Returns: 52 | 1. encoder outputs 53 | 2. decoder outputs 54 | 3. generated alignment 55 | 56 | Args: 57 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 58 | shape: (batch_size, max_text_length) 59 | x_lengths (torch.Tensor): lengths of texts in batch. 60 | shape: (batch_size,) 61 | n_timesteps (int): number of steps to use for reverse diffusion in decoder. 62 | temperature (float, optional): controls variance of terminal distribution. 63 | y (torch.Tensor): mel spectrogram of reference audio 64 | shape: (batch_size, mel_channels, time) 65 | length_scale (float, optional): controls speech pace. 66 | Increase value to slow down generated speech and vice versa. 
67 | 68 | Returns: 69 | dict: { 70 | "encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 71 | # Average mel spectrogram generated by the encoder 72 | "decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length), 73 | # Refined mel spectrogram improved by the CFM 74 | "attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length), 75 | # Alignment map between text and mel spectrogram 76 | """ 77 | 78 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 79 | c = self.ref_encoder(y, None) 80 | x, mu_x, x_mask = self.encoder(x, c, x_lengths) 81 | logw = self.dp(x, x_mask, c) 82 | 83 | w = torch.exp(logw) * x_mask 84 | w_ceil = torch.ceil(w) * length_scale 85 | y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() 86 | y_max_length = y_lengths.max() 87 | 88 | # Using obtained durations `w` construct alignment map `attn` 89 | y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype) 90 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 91 | attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1) 92 | 93 | # Align encoded text and get mu_y 94 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 95 | mu_y = mu_y.transpose(1, 2) 96 | encoder_outputs = mu_y[:, :, :y_max_length] 97 | 98 | # Generate sample tracing the probability flow 99 | if cfg == 1.0: 100 | decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver) 101 | else: 102 | cfg_kwargs = {'fake_speaker': self.fake_speaker, 'fake_content': self.fake_content, 'cfg_strength': cfg} 103 | decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver, cfg_kwargs) 104 | 105 | decoder_outputs = decoder_outputs[:, :, :y_max_length] 106 | 107 | 108 | return { 109 | "encoder_outputs": encoder_outputs, 110 | "decoder_outputs": decoder_outputs, 111 | "attn": attn[:, :, :y_max_length], 112 | } 113 | 114 | def forward(self, x, x_lengths, y, y_lengths, z, z_lengths): 115 | """ 116 | Computes 3 losses: 117 | 1. duration loss: loss between predicted token durations and those extracted by Monotinic Alignment Search (MAS). 118 | 2. prior loss: loss between mel-spectrogram and encoder outputs. 119 | 3. flow matching loss: loss between mel-spectrogram and decoder outputs. 120 | 121 | Args: 122 | x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids. 123 | shape: (batch_size, max_text_length) 124 | x_lengths (torch.Tensor): lengths of texts in batch. 125 | shape: (batch_size,) 126 | y (torch.Tensor): batch of corresponding mel-spectrograms. 127 | shape: (batch_size, n_feats, max_mel_length) 128 | y_lengths (torch.Tensor): lengths of mel-spectrograms in batch. 129 | shape: (batch_size,) 130 | z (torch.Tensor): batch of cliced mel-spectrograms. 131 | shape: (batch_size, n_feats, max_mel_length) 132 | z_lengths (torch.Tensor): lengths of sliced mel-spectrograms in batch. 
133 | shape: (batch_size,) 134 | """ 135 | # Get encoder_outputs `mu_x` and log-scaled token durations `logw` 136 | y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype) 137 | z_mask = sequence_mask(z_lengths, z.size(2)).unsqueeze(1).to(z.dtype) 138 | cfg_mask = torch.rand(y.size(0), 1, device=y.device) > self.cfg_dropout 139 | 140 | # compute global speaker embedding 141 | c = self.ref_encoder(z, z_mask) * cfg_mask + ~cfg_mask * self.fake_speaker.repeat(z.size(0), 1) 142 | 143 | x, mu_x, x_mask = self.encoder(x, c, x_lengths) 144 | logw = self.dp(x, x_mask, c) 145 | 146 | attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2) 147 | 148 | # Use MAS to find most likely alignment `attn` between text and mel-spectrogram 149 | with torch.no_grad(): 150 | s_p_sq_r = torch.ones_like(mu_x) # [b, d, t] 151 | neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi)- torch.zeros_like(mu_x), [1], keepdim=True) 152 | neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r) 153 | neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r)) 154 | neg_cent4 = torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True) 155 | neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 156 | 157 | attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) 158 | attn = (monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()) 159 | 160 | # Compute loss between predicted log-scaled durations and those obtained from MAS 161 | # refered to as prior loss in the paper 162 | logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask 163 | dur_loss = duration_loss(logw, logw_, x_lengths) 164 | 165 | # Align encoded text with mel-spectrogram and get mu_y segment 166 | attn = attn.squeeze(1).transpose(1,2) 167 | mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2)) 168 | mu_y = mu_y.transpose(1, 2) 169 | 170 | # Compute loss of the decoder 171 | cfg_mask = cfg_mask.unsqueeze(-1) 172 | mu_y_masked = mu_y * cfg_mask + ~cfg_mask * self.fake_content.repeat(mu_y.size(0), 1, mu_y.size(-1)) # mask content information for better diversity for flow-matching 173 | diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y_masked, c) 174 | 175 | prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask) 176 | prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels) 177 | 178 | return dur_loss, diff_loss, prior_loss, attn -------------------------------------------------------------------------------- /models/reference_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Conv1dGLU(nn.Module): 5 | """ 6 | Conv1d + GLU(Gated Linear Unit) with residual connection. 7 | For GLU refer to https://arxiv.org/abs/1612.08083 paper. 
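The convolution outputs 2 * out_channels; the two halves a and b are combined as a * sigmoid(b), followed by dropout and a residual connection.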
8 | """ 9 | 10 | def __init__(self, in_channels, out_channels, kernel_size, dropout): 11 | super(Conv1dGLU, self).__init__() 12 | self.out_channels = out_channels 13 | self.conv1 = nn.Conv1d(in_channels, 2 * out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 14 | self.dropout = nn.Dropout(dropout) 15 | 16 | def forward(self, x): 17 | residual = x 18 | x = self.conv1(x) 19 | x1, x2 = torch.split(x, self.out_channels, dim=1) 20 | x = x1 * torch.sigmoid(x2) 21 | x = residual + self.dropout(x) 22 | return x 23 | 24 | # modified from https://github.com/RVC-Boss/GPT-SoVITS/blob/main/GPT_SoVITS/module/modules.py#L766 25 | class MelStyleEncoder(nn.Module): 26 | """MelStyleEncoder""" 27 | 28 | def __init__( 29 | self, 30 | n_mel_channels=80, 31 | style_hidden=128, 32 | style_vector_dim=256, 33 | style_kernel_size=5, 34 | style_head=2, 35 | dropout=0.1, 36 | ): 37 | super(MelStyleEncoder, self).__init__() 38 | self.in_dim = n_mel_channels 39 | self.hidden_dim = style_hidden 40 | self.out_dim = style_vector_dim 41 | self.kernel_size = style_kernel_size 42 | self.n_head = style_head 43 | self.dropout = dropout 44 | 45 | self.spectral = nn.Sequential( 46 | nn.Linear(self.in_dim, self.hidden_dim), 47 | nn.Mish(inplace=True), 48 | nn.Dropout(self.dropout), 49 | nn.Linear(self.hidden_dim, self.hidden_dim), 50 | nn.Mish(inplace=True), 51 | nn.Dropout(self.dropout), 52 | ) 53 | 54 | self.temporal = nn.Sequential( 55 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 56 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 57 | ) 58 | 59 | self.slf_attn = nn.MultiheadAttention( 60 | self.hidden_dim, 61 | self.n_head, 62 | self.dropout, 63 | batch_first=True 64 | ) 65 | 66 | self.fc = nn.Linear(self.hidden_dim, self.out_dim) 67 | 68 | def temporal_avg_pool(self, x, mask=None): 69 | if mask is None: 70 | return torch.mean(x, dim=1) 71 | else: 72 | return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1) 73 | 74 | def forward(self, x, x_mask=None): 75 | x = x.transpose(1, 2) 76 | 77 | # spectral 78 | x = self.spectral(x) 79 | # temporal 80 | x = x.transpose(1, 2) 81 | x = self.temporal(x) 82 | x = x.transpose(1, 2) 83 | # self-attention 84 | if x_mask is not None: 85 | x_mask = ~x_mask.squeeze(1).to(torch.bool) 86 | x, _ = self.slf_attn(x, x, x, key_padding_mask=x_mask, need_weights=False) 87 | # fc 88 | x = self.fc(x) 89 | # temoral average pooling 90 | w = self.temporal_avg_pool(x, mask=x_mask) 91 | 92 | return w 93 | 94 | # Attention Pool version of MelStyleEncoder, not used 95 | class AttnMelStyleEncoder(nn.Module): 96 | """MelStyleEncoder""" 97 | 98 | def __init__( 99 | self, 100 | n_mel_channels=80, 101 | style_hidden=128, 102 | style_vector_dim=256, 103 | style_kernel_size=5, 104 | style_head=2, 105 | dropout=0.1, 106 | ): 107 | super().__init__() 108 | self.in_dim = n_mel_channels 109 | self.hidden_dim = style_hidden 110 | self.out_dim = style_vector_dim 111 | self.kernel_size = style_kernel_size 112 | self.n_head = style_head 113 | self.dropout = dropout 114 | 115 | self.spectral = nn.Sequential( 116 | nn.Linear(self.in_dim, self.hidden_dim), 117 | nn.Mish(inplace=True), 118 | nn.Dropout(self.dropout), 119 | nn.Linear(self.hidden_dim, self.hidden_dim), 120 | nn.Mish(inplace=True), 121 | nn.Dropout(self.dropout), 122 | ) 123 | 124 | self.temporal = nn.Sequential( 125 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, self.dropout), 126 | Conv1dGLU(self.hidden_dim, self.hidden_dim, self.kernel_size, 
self.dropout), 127 | ) 128 | 129 | self.slf_attn = nn.MultiheadAttention( 130 | self.hidden_dim, 131 | self.n_head, 132 | self.dropout, 133 | batch_first=True 134 | ) 135 | 136 | self.fc = nn.Linear(self.hidden_dim, self.out_dim) 137 | 138 | def temporal_avg_pool(self, x, mask=None): 139 | if mask is None: 140 | return torch.mean(x, dim=1) 141 | else: 142 | return torch.sum(x * ~mask.unsqueeze(-1), dim=1) / (~mask).sum(dim=1).unsqueeze(1) 143 | 144 | def forward(self, x, x_mask=None): 145 | x = x.transpose(1, 2) 146 | 147 | # spectral 148 | x = self.spectral(x) 149 | # temporal 150 | x = x.transpose(1, 2) 151 | x = self.temporal(x) 152 | x = x.transpose(1, 2) 153 | # self-attention 154 | if x_mask is not None: 155 | x_mask = ~x_mask.squeeze(1).to(torch.bool) 156 | zeros = torch.zeros(x_mask.size(0), 1, device=x_mask.device, dtype=x_mask.dtype) 157 | x_attn_mask = torch.cat((zeros, x_mask), dim=1) 158 | else: 159 | x_attn_mask = None 160 | 161 | avg = self.temporal_avg_pool(x, x_mask).unsqueeze(1) 162 | x = torch.cat([avg, x], dim=1) 163 | x, _ = self.slf_attn(x, x, x, key_padding_mask=x_attn_mask, need_weights=False) 164 | x = x[:, 0, :] 165 | # fc 166 | x = self.fc(x) 167 | 168 | return x -------------------------------------------------------------------------------- /models/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.diffusion_transformer import DiTConVBlock 5 | from utils.mask import sequence_mask 6 | 7 | # modified from https://github.com/jaywalnut310/vits/blob/main/models.py 8 | class TextEncoder(nn.Module): 9 | def __init__(self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, gin_channels): 10 | super().__init__() 11 | self.n_vocab = n_vocab 12 | self.out_channels = out_channels 13 | self.hidden_channels = hidden_channels 14 | self.filter_channels = filter_channels 15 | self.n_heads = n_heads 16 | self.n_layers = n_layers 17 | self.kernel_size = kernel_size 18 | self.p_dropout = p_dropout 19 | 20 | self.scale = self.hidden_channels ** 0.5 21 | 22 | self.emb = nn.Embedding(n_vocab, hidden_channels) 23 | nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) 24 | 25 | self.encoder = nn.ModuleList([DiTConVBlock(hidden_channels, filter_channels, n_heads, kernel_size, p_dropout, gin_channels) for _ in range(n_layers)]) 26 | self.proj = nn.Conv1d(hidden_channels, out_channels, 1) 27 | 28 | self.initialize_weights() 29 | 30 | def initialize_weights(self): 31 | for block in self.encoder: 32 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 33 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 34 | 35 | def forward(self, x: torch.Tensor, c: torch.Tensor, x_lengths: torch.Tensor): 36 | x = self.emb(x) * self.scale # [b, t, h] 37 | x = x.transpose(1, -1) # [b, h, t] 38 | x_mask = sequence_mask(x_lengths, x.size(2)).unsqueeze(1).to(x.dtype) 39 | 40 | for layer in self.encoder: 41 | x = layer(x, c, x_mask) 42 | mu_x = self.proj(x) * x_mask 43 | 44 | return x, mu_x, x_mask 45 | -------------------------------------------------------------------------------- /monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | from numpy import zeros, int32, float32 2 | from torch import from_numpy 3 | 4 | from .core import maximum_path_jit 5 | 6 | 7 | def maximum_path(neg_cent, mask): 8 | device = neg_cent.device 9 | dtype = neg_cent.dtype 10 | neg_cent = 
neg_cent.data.cpu().numpy().astype(float32) 11 | path = zeros(neg_cent.shape, dtype=int32) 12 | 13 | t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32) 14 | t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32) 15 | maximum_path_jit(path, neg_cent, t_t_max, t_s_max) 16 | return from_numpy(path).to(device=device, dtype=dtype) 17 | -------------------------------------------------------------------------------- /monotonic_align/core.py: -------------------------------------------------------------------------------- 1 | import numba 2 | 3 | 4 | @numba.jit( 5 | numba.void( 6 | numba.int32[:, :, ::1], 7 | numba.float32[:, :, ::1], 8 | numba.int32[::1], 9 | numba.int32[::1], 10 | ), 11 | nopython=True, 12 | nogil=True, 13 | ) 14 | def maximum_path_jit(paths, values, t_ys, t_xs): 15 | b = paths.shape[0] 16 | max_neg_val = -1e9 17 | for i in range(int(b)): 18 | path = paths[i] 19 | value = values[i] 20 | t_y = t_ys[i] 21 | t_x = t_xs[i] 22 | 23 | v_prev = v_cur = 0.0 24 | index = t_x - 1 25 | 26 | for y in range(t_y): 27 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 28 | if x == y: 29 | v_cur = max_neg_val 30 | else: 31 | v_cur = value[y - 1, x] 32 | if x == 0: 33 | if y == 0: 34 | v_prev = 0.0 35 | else: 36 | v_prev = max_neg_val 37 | else: 38 | v_prev = value[y - 1, x - 1] 39 | value[y, x] += max(v_prev, v_cur) 40 | 41 | for y in range(t_y - 1, -1, -1): 42 | path[y, index] = 1 43 | if index != 0 and ( 44 | index == y or value[y - 1, index] < value[y - 1, index - 1] 45 | ): 46 | index = index - 1 47 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | from dataclasses import dataclass, asdict 5 | 6 | import torch 7 | from torch.multiprocessing import Pool, set_start_method 8 | import torchaudio 9 | 10 | from config import MelConfig, TrainConfig 11 | from utils.audio import LogMelSpectrogram, load_and_resample_audio 12 | 13 | from text.mandarin import chinese_to_cnm3 14 | from text.english import english_to_ipa2 15 | from text.japanese import japanese_to_ipa2 16 | 17 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 18 | 19 | @dataclass 20 | class DataConfig: 21 | input_filelist_path = './filelists/filelist.txt' # a filelist contains 'audiopath | text' 22 | output_filelist_path = './filelists/filelist.json' # path to save filelist 23 | output_feature_path = './stableTTS_datasets' # path to save resampled audios and mel features 24 | language = 'english' # chinese, japanese or english 25 | resample = False # waveform is not used in training, so save resampled results is not necessary. 
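# Each line of input_filelist_path holds one 'audio_path|text' pair and is split on the first '|' by load_filelist() below,
# e.g. (hypothetical entry): ./raw_datasets/my_corpus/wavs/0001.wav|Printing, in the only sense with which we are at present concerned.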
26 | 27 | g2p_mapping = { 28 | 'chinese': chinese_to_cnm3, 29 | 'japanese': japanese_to_ipa2, 30 | 'english': english_to_ipa2, 31 | } 32 | 33 | data_config = DataConfig() 34 | train_config = TrainConfig() 35 | mel_config = MelConfig() 36 | 37 | input_filelist_path = data_config.input_filelist_path 38 | output_filelist_path = data_config.output_filelist_path 39 | output_feature_path = data_config.output_feature_path 40 | 41 | # Ensure output directories exist 42 | output_mel_dir = os.path.join(output_feature_path, 'mels') 43 | os.makedirs(output_mel_dir, exist_ok=True) 44 | os.makedirs(os.path.dirname(output_filelist_path), exist_ok=True) 45 | 46 | if data_config.resample: 47 | output_wav_dir = os.path.join(output_feature_path, 'waves') 48 | os.makedirs(output_wav_dir, exist_ok=True) 49 | 50 | mel_extractor = LogMelSpectrogram(**asdict(mel_config)).to(device) 51 | 52 | g2p = g2p_mapping.get(data_config.language) 53 | 54 | def load_filelist(path) -> list: 55 | file_list = [] 56 | with open(path, 'r', encoding='utf-8') as f: 57 | for idx, line in enumerate(f): 58 | audio_path, text = line.strip().split('|', maxsplit=1) 59 | file_list.append((str(idx), audio_path, text)) 60 | return file_list 61 | 62 | @ torch.inference_mode() 63 | def process_filelist(line) -> str: 64 | idx, audio_path, text = line 65 | audio = load_and_resample_audio(audio_path, mel_config.sample_rate, device=device) # shape: [1, time] 66 | if audio is not None: 67 | # get output path 68 | audio_name, _ = os.path.splitext(os.path.basename(audio_path)) 69 | 70 | try: 71 | phone = g2p(text) 72 | if len(phone) > 0: 73 | mel = mel_extractor(audio.to(device)).cpu().squeeze(0) # shape: [n_mels, time // hop_length] 74 | output_mel_path = os.path.join(output_mel_dir, f'{idx}_{audio_name}.pt') 75 | torch.save(mel, output_mel_path) 76 | 77 | if data_config.resample: 78 | audio_path = os.path.join(output_wav_dir, f'{idx}_{audio_name}.wav') 79 | torchaudio.save(audio_path, audio.cpu(), mel_config.sample_rate) 80 | return json.dumps({'mel_path': output_mel_path, 'phone': phone, 'audio_path': audio_path, 'text': text, 'mel_length': mel.size(-1)}, ensure_ascii=False, allow_nan=False) 81 | except Exception as e: 82 | print(f'Error processing {audio_path}: {str(e)}') 83 | 84 | 85 | def main(): 86 | set_start_method('spawn') # CUDA must use spawn method 87 | input_filelist = load_filelist(input_filelist_path) 88 | results = [] 89 | 90 | with Pool(processes=2) as pool: 91 | for result in tqdm(pool.imap(process_filelist, input_filelist), total=len(input_filelist)): 92 | if result is not None: 93 | results.append(f'{result}\n') 94 | 95 | # save filelist 96 | with open(output_filelist_path, 'w', encoding='utf-8') as f: 97 | f.writelines(results) 98 | print(f"filelist file has been saved to {output_filelist_path}") 99 | 100 | # faster and use much less CPU 101 | torch.set_num_threads(1) 102 | torch.set_num_interop_threads(1) 103 | 104 | if __name__ == '__main__': 105 | main() -------------------------------------------------------------------------------- /recipes/AiSHELL3.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import re 4 | from dataclasses import dataclass 5 | import concurrent.futures 6 | 7 | from tqdm.auto import tqdm 8 | 9 | # download_link: https://www.openslr.org/93/ 10 | @dataclass 11 | class DataConfig: 12 | dataset_path = './raw_datasets/Aishell3/train/wav' 13 | txt_path = './raw_datasets/Aishell3/train/content.txt' 14 | output_filelist_path = './filelists/aishell3.txt' 
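# Each line of content.txt pairs a wav name with its transcript interleaved with per-character pinyin,
# e.g. (illustrative, not copied from the dataset): SSB00050001.wav 广 guang3 州 zhou1 女 nv3 ...
# The first 7 characters of the wav name are the speaker id (e.g. SSB0005), which is also the wav sub-folder
# under dataset_path; the regex below strips the pinyin and tone digits so only the Chinese transcript is kept.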
15 | 16 | data_config = DataConfig() 17 | 18 | def process_filelist(line): 19 | dir_name, audio_path, text = line 20 | input_audio_path = os.path.abspath(os.path.join(data_config.dataset_path, dir_name, audio_path)) 21 | if os.path.exists(input_audio_path): 22 | return f'{input_audio_path}|{text}\n' 23 | 24 | if __name__ == '__main__': 25 | filelist = [] 26 | results = [] 27 | 28 | with open(data_config.txt_path, 'r', encoding='utf-8') as f: 29 | for idx, line in enumerate(f): 30 | audio_path, text = line.strip().split(maxsplit=1) 31 | dir_name = audio_path[:7] 32 | text = re.sub(r'[a-zA-Z0-9\s]', '', text) # remove pinyin and tone 33 | filelist.append((dir_name, audio_path, text)) 34 | 35 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: 36 | futures = [executor.submit(process_filelist, line) for line in filelist] 37 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(filelist)): 38 | result = future.result() 39 | if result is not None: 40 | results.append(result) 41 | 42 | # make sure that the parent dir exists, raising error at the last step is quite terrible OVO 43 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True) 44 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f: 45 | f.writelines(results) -------------------------------------------------------------------------------- /recipes/BZNSYP_标贝女声.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import re 4 | from dataclasses import dataclass 5 | import concurrent.futures 6 | 7 | from tqdm.auto import tqdm 8 | 9 | # submit the form on: https://www.data-baker.com/data/index/TNtts/ 10 | # then you will get the download link 11 | @dataclass 12 | class DataConfig: 13 | dataset_path = './raw_datasets/BZNSYP/Wave' 14 | txt_path = './raw_datasets/BZNSYP/ProsodyLabeling/000001-010000.txt' 15 | output_filelist_path = './filelists/bznsyp.txt' 16 | 17 | data_config = DataConfig() 18 | 19 | def process_filelist(line): 20 | audio_name, text = line.split('\t') 21 | text = re.sub('[#\d]+', '', text) # remove '#' and numbers 22 | input_audio_path = os.path.abspath(os.path.join(data_config.dataset_path, f'{audio_name}.wav')) 23 | if os.path.exists(input_audio_path): 24 | return f'{input_audio_path}|{text}\n' 25 | 26 | if __name__ == '__main__': 27 | filelist = [] 28 | results = [] 29 | 30 | with open(data_config.txt_path, 'r', encoding='utf-8') as f: 31 | for idx, line in enumerate(f): 32 | if idx % 2 == 0: 33 | filelist.append(line.strip()) 34 | 35 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: 36 | futures = [executor.submit(process_filelist, line) for line in filelist] 37 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(filelist)): 38 | result = future.result() 39 | if result is not None: 40 | results.append(result) 41 | 42 | # make sure that the parent dir exists, raising error at the last step is quite terrible OVO 43 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True) 44 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f: 45 | f.writelines(results) -------------------------------------------------------------------------------- /recipes/VCTK_huggingface.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | from pathlib import Path 4 | from dataclasses import dataclass 5 | import concurrent.futures 6 | 7 | from tqdm.auto import tqdm 8 
| import pandas as pd 9 | import torchaudio 10 | 11 | # download_link: https://huggingface.co/datasets/CSTR-Edinburgh/vctk/tree/063f48e28abda80b2fdc4d4433af8a99e71bfe16 12 | # other huggingface TTS parquet datasets could use the same script 13 | @dataclass 14 | class DataConfig: 15 | dataset_path = './raw_datasets/VCTK' 16 | output_filelist_path = './filelists/VCTK.txt' 17 | output_audio_path = './raw_datasets/VCTK_audios' # to extract audios from parquet files 18 | 19 | data_config = DataConfig() 20 | 21 | def process_parquet(parquet_path: Path): 22 | df = pd.read_parquet(parquet_path) 23 | filelist = [] 24 | for idx, data in tqdm(df.iterrows(), total=len(df)): 25 | audio = io.BytesIO(data['audio']['bytes']) 26 | audio, sample_rate = torchaudio.load(audio) 27 | text = data['text'] 28 | 29 | path = os.path.abspath(os.path.join(data_config.output_audio_path, data['audio']['path'])) 30 | torchaudio.save(path, audio, sample_rate) 31 | 32 | filelist.append(f'{path}|{text}\n') 33 | 34 | return filelist 35 | 36 | if __name__ == '__main__': 37 | filelist = [] 38 | results = [] 39 | 40 | dataset_path = Path(data_config.dataset_path) 41 | parquets = list(dataset_path.rglob('*.parquet')) 42 | 43 | with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor: 44 | futures = [executor.submit(process_parquet, parquet_path) for parquet_path in parquets] 45 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(parquets)): 46 | result = future.result() 47 | if result is not None: 48 | results.extend(result) 49 | 50 | # make sure that the parent dir exists, raising error at the last step is quite terrible OVO 51 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True) 52 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f: 53 | f.writelines(results) 54 | -------------------------------------------------------------------------------- /recipes/genshin_en_小虫哥ver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from dataclasses import dataclass 4 | import concurrent.futures 5 | 6 | from tqdm.auto import tqdm 7 | import openpyxl # use to open excel. run ! 
pip install openpyxl 8 | 9 | # download_link: https://www.bilibili.com/read/cv23965717 10 | @dataclass 11 | class DataConfig: 12 | dataset_path = './raw_datasets/Genshin_chinese4.5/原神语音包4.5(英)' 13 | excel_path = './raw_datasets/Genshin_chinese4.5/原神4.5语音包对应文本(英).xlsx' 14 | output_filelist_path = './filelists/genshin_en.txt' 15 | 16 | # if any of the following strings appears in the text, the line almost never matches the audio 17 | FORBIDDEN_TEXTS = ["……", "{NICKNAME}", "#", "(", ")", "♪", "test", "{0}", "█", "*", "█", "+", "Gohus"] 18 | REPLACEMENTS = {"$UNRELEASED": ""} 19 | escaped_forbidden_texts = [re.escape(text) for text in FORBIDDEN_TEXTS] 20 | pattern = re.compile("|".join(escaped_forbidden_texts)) 21 | 22 | data_config = DataConfig() 23 | 24 | def clean_text(text): 25 | cleaned_text = text 26 | if pattern.search(cleaned_text): 27 | return None 28 | for old, new in REPLACEMENTS.items(): 29 | cleaned_text = cleaned_text.replace(old, new) 30 | return cleaned_text 31 | 32 | def read_excel(excel): 33 | wb = openpyxl.load_workbook(excel) 34 | sheet_names = wb.sheetnames 35 | main_sheet = wb[sheet_names[0]] 36 | npc_names = [cell.value for cell in main_sheet['B'] if cell.value][1:] 37 | npc_audio_number = [cell.value for cell in main_sheet['C'] if cell.value][1:] 38 | return wb, npc_names, npc_audio_number 39 | 40 | def process_filelist(data): 41 | audio_path, text, npc_path = data 42 | input_audio_path = os.path.abspath(os.path.join(npc_path, audio_path)) 43 | if os.path.exists(input_audio_path): 44 | text = clean_text(text) 45 | if text is not None: 46 | return f'{input_audio_path}|{text}\n' 47 | 48 | if __name__ == '__main__': 49 | wb, npc_names, npc_audio_number = read_excel(data_config.excel_path) 50 | datas_list = [] 51 | results = [] 52 | 53 | for index, npc_name in enumerate(tqdm(npc_names)): 54 | sheet = wb[npc_name] 55 | audio_names = [cell.value for cell in sheet['C'] if cell.value][1:] 56 | texts = [cell.value for cell in sheet['D'] if cell.value][1:] 57 | npc_path = os.path.join(data_config.dataset_path, npc_name) 58 | datas_list.extend([(audio_name, text, npc_path) for audio_name, text in zip(audio_names, texts)]) 59 | 60 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: 61 | futures = [executor.submit(process_filelist, data) for data in datas_list] 62 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(datas_list)): 63 | result = future.result() 64 | if result is not None: 65 | results.append(result) 66 | 67 | # make sure that the parent dir exists, raising error at the last step is quite terrible OVO 68 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True) 69 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f: 70 | f.writelines(results) -------------------------------------------------------------------------------- /recipes/genshin_zh_小虫哥ver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from dataclasses import dataclass 4 | import concurrent.futures 5 | 6 | from tqdm.auto import tqdm 7 | import openpyxl # use to open excel. run !
pip install openpyxl 8 | 9 | # download_link: https://www.bilibili.com/read/cv23965717 10 | @dataclass 11 | class DataConfig: 12 | dataset_path = './raw_datasets/Genshin_chinese4.5/原神语音包4.5(中)' 13 | excel_path = './raw_datasets/Genshin_chinese4.5/原神4.5语音包对应文本(中).xlsx' 14 | output_filelist_path = './filelists/genshin_zh.txt' 15 | 16 | # if any of the following strings appears in the text, the line almost never matches the audio 17 | FORBIDDEN_TEXTS = ["……", "{NICKNAME}", "#", "(", ")", "♪", "test", "{0}", "█", "*", "█", "+", "Gohus"] 18 | REPLACEMENTS = {"$UNRELEASED": ""} 19 | escaped_forbidden_texts = [re.escape(text) for text in FORBIDDEN_TEXTS] 20 | pattern = re.compile("|".join(escaped_forbidden_texts)) 21 | 22 | data_config = DataConfig() 23 | 24 | def clean_text(text): 25 | cleaned_text = text 26 | # drop all lines that contain English letters or digits 27 | if re.search(r'[A-Za-z0-9]', cleaned_text): 28 | return None 29 | if pattern.search(cleaned_text): 30 | return None 31 | for old, new in REPLACEMENTS.items(): 32 | cleaned_text = cleaned_text.replace(old, new) 33 | return cleaned_text 34 | 35 | def read_excel(excel): 36 | wb = openpyxl.load_workbook(excel) 37 | sheet_names = wb.sheetnames 38 | main_sheet = wb[sheet_names[0]] 39 | npc_names = [cell.value for cell in main_sheet['B'] if cell.value][1:] 40 | npc_audio_number = [cell.value for cell in main_sheet['C'] if cell.value][1:] 41 | return wb, npc_names, npc_audio_number 42 | 43 | def process_filelist(data): 44 | audio_path, text, npc_path = data 45 | input_audio_path = os.path.abspath(os.path.join(npc_path, audio_path)) 46 | if os.path.exists(input_audio_path): 47 | text = clean_text(text) 48 | if text is not None: 49 | return f'{input_audio_path}|{text}\n' 50 | 51 | if __name__ == '__main__': 52 | wb, npc_names, npc_audio_number = read_excel(data_config.excel_path) 53 | datas_list = [] 54 | results = [] 55 | 56 | for index, npc_name in enumerate(tqdm(npc_names)): 57 | sheet = wb[npc_name] 58 | audio_names = [cell.value for cell in sheet['C'] if cell.value][1:] 59 | texts = [cell.value for cell in sheet['D'] if cell.value][1:] 60 | npc_path = os.path.join(data_config.dataset_path, npc_name) 61 | datas_list.extend([(audio_name, text, npc_path) for audio_name, text in zip(audio_names, texts)]) 62 | 63 | with concurrent.futures.ProcessPoolExecutor(max_workers=2) as executor: 64 | futures = [executor.submit(process_filelist, data) for data in datas_list] 65 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(datas_list)): 66 | result = future.result() 67 | if result is not None: 68 | results.append(result) 69 | 70 | # make sure that the parent dir exists, raising error at the last step is quite terrible OVO 71 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True) 72 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f: 73 | f.writelines(results) -------------------------------------------------------------------------------- /recipes/hifi_tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from pathlib import Path 4 | from dataclasses import dataclass 5 | import concurrent.futures 6 | 7 | from tqdm.auto import tqdm 8 | 9 | # download_link: https://www.openslr.org/109/ 10 | @dataclass 11 | class DataConfig: 12 | dataset_path = './raw_datasets/hi_fi_tts_v0' 13 | output_filelist_path = './filelists/hifi_tts.txt' 14 | 15 | data_config = DataConfig() 16 | 17 | def process_filelist(speaker): 18 | filelist = [] 19 | with open(speaker, 'r', encoding='utf-8') as f: 20 | for line in f: 21 | line = json.loads(line.strip())
22 | audio_path = os.path.abspath(os.path.join(data_config.dataset_path, line['audio_filepath'])) 23 | text = line['text_normalized'] 24 | if os.path.exists(audio_path): 25 | filelist.append(f'{audio_path}|{text}\n') 26 | return filelist 27 | 28 | if __name__ == '__main__': 29 | filelist = [] 30 | results = [] 31 | 32 | dataset_path = Path(data_config.dataset_path) 33 | speakers = list(dataset_path.rglob('*.json')) 34 | 35 | with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: 36 | futures = [executor.submit(process_filelist, speaker) for speaker in speakers] 37 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(speakers)): 38 | result = future.result() 39 | if result is not None: 40 | results.extend(result) 41 | 42 | # make sure that the parent dir exists, raising error at the last step is quite terrible OVO 43 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True) 44 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f: 45 | f.writelines(results) -------------------------------------------------------------------------------- /recipes/libriTTS.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from dataclasses import dataclass 4 | import concurrent.futures 5 | 6 | from tqdm.auto import tqdm 7 | 8 | # download_link: https://openslr.org/60/ 9 | @dataclass 10 | class DataConfig: 11 | dataset_path = './raw_datasets/LibriTTS/train-other-500' 12 | output_filelist_path = './filelists/libri_tts.txt' 13 | 14 | data_config = DataConfig() 15 | 16 | def process_filelist(wav_path: Path): 17 | text_path = wav_path.with_suffix('.normalized.txt') 18 | if text_path.exists(): 19 | with open(text_path, 'r', encoding='utf-8') as f: 20 | text = f.read().strip() 21 | return f'{wav_path.as_posix()}|{text}\n' 22 | 23 | if __name__ == '__main__': 24 | filelist = [] 25 | results = [] 26 | 27 | dataset_path = Path(data_config.dataset_path) 28 | waves = list(dataset_path.rglob('*.wav')) 29 | 30 | with concurrent.futures.ProcessPoolExecutor(max_workers=8) as executor: 31 | futures = [executor.submit(process_filelist, wav_path) for wav_path in waves] 32 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(waves)): 33 | result = future.result() 34 | if result is not None: 35 | results.append(result) 36 | 37 | # make sure that the parent dir exists, raising error at the last step is quite terrible OVO 38 | os.makedirs(os.path.dirname(data_config.output_filelist_path), exist_ok=True) 39 | with open(data_config.output_filelist_path, 'w', encoding='utf-8') as f: 40 | f.writelines(results) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio 3 | 4 | tqdm 5 | numpy 6 | soundfile # to make sure that torchaudio has at least one valid backend 7 | 8 | tensorboard 9 | 10 | # for monotonic_align 11 | numba 12 | 13 | # ODE-solver 14 | torchdiffeq 15 | 16 | # for g2p 17 | # chinese 18 | pypinyin 19 | jieba 20 | # english 21 | eng_to_ipa 22 | unidecode 23 | inflect 24 | # japanese 25 | # if pyopenjtalk fail to download open_jtalk_dic_utf_8-1.11.tar.gz, manually download and unzip the file below 26 | # https://github.com/r9y9/open_jtalk/releases/download/v1.11.1/open_jtalk_dic_utf_8-1.11.tar.gz 27 | # and set os.environ['OPEN_JTALK_DICT_DIR'] to the folder path 28 | pyopenjtalk-prebuilt # if using 
python >= 3.12, install pyopenjtalk instead 29 | 30 | # for webui 31 | # gradio 32 | # matplotlib 33 | 34 | -------------------------------------------------------------------------------- /text/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2017 Keith Ito 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from text import cleaners 3 | from text.symbols import symbols 4 | 5 | 6 | # Mappings from symbol to numeric ID and vice versa: 7 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 8 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} 9 | 10 | 11 | def text_to_sequence(text, symbols, cleaner_names): 12 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 13 | Args: 14 | text: string to convert to a sequence 15 | cleaner_names: names of the cleaner functions to run the text through 16 | Returns: 17 | List of integers corresponding to the symbols in the text 18 | ''' 19 | sequence = [] 20 | symbol_to_id = {s: i for i, s in enumerate(symbols)} 21 | clean_text = _clean_text(text, cleaner_names) 22 | print(clean_text) 23 | print(f" length:{len(clean_text)}") 24 | for symbol in clean_text: 25 | if symbol not in symbol_to_id.keys(): 26 | continue 27 | symbol_id = symbol_to_id[symbol] 28 | sequence += [symbol_id] 29 | print(f" length:{len(sequence)}") 30 | return sequence 31 | 32 | 33 | def cleaned_text_to_sequence(cleaned_text): 34 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 35 | Args: 36 | text: string to convert to a sequence 37 | Returns: 38 | List of integers corresponding to the symbols in the text 39 | ''' 40 | # symbol_to_id = {s: i for i, s in enumerate(symbols)} 41 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text if symbol in _symbol_to_id.keys()] 42 | return sequence 43 | 44 | def cleaned_text_to_sequence_chinese(cleaned_text): 45 | '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
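Unlike cleaned_text_to_sequence, the cleaned text is split on spaces, so each symbol is a whole phoneme token (e.g. a CNM3 token produced by chinese_to_cnm3) rather than a single character.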
46 | Args: 47 | text: string to convert to a sequence 48 | Returns: 49 | List of integers corresponding to the symbols in the text 50 | ''' 51 | # symbol_to_id = {s: i for i, s in enumerate(symbols)} 52 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text.split(' ') if symbol in _symbol_to_id.keys()] 53 | return sequence 54 | 55 | 56 | def sequence_to_text(sequence): 57 | '''Converts a sequence of IDs back to a string''' 58 | result = '' 59 | for symbol_id in sequence: 60 | s = _id_to_symbol[symbol_id] 61 | result += s 62 | return result 63 | 64 | 65 | def _clean_text(text, cleaner_names): 66 | for name in cleaner_names: 67 | cleaner = getattr(cleaners, name) 68 | if not cleaner: 69 | raise Exception('Unknown cleaner: %s' % name) 70 | text = cleaner(text) 71 | return text 72 | -------------------------------------------------------------------------------- /text/cleaners.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from text.english import english_to_ipa2 4 | from text.mandarin import chinese_to_cnm3 5 | from text.japanese import japanese_to_ipa2 6 | 7 | language_module_map = {"PAD":0, "ZH": 1, "EN": 2, "JA": 3} 8 | 9 | # 预编译正则表达式 10 | ZH_PATTERN = re.compile(r'[\u3400-\u4DBF\u4e00-\u9FFF\uF900-\uFAFF\u3000-\u303F]') 11 | EN_PATTERN = re.compile(r'[a-zA-Z.,!?\'"(){}[\]<>:;@#$%^&*-_+=/\\|~`]+') 12 | JP_PATTERN = re.compile(r'[\u3040-\u309F\u30A0-\u30FF\u4E00-\u9FAF\u31F0-\u31FF\uFF00-\uFFEF\u3000-\u303F]') 13 | CLEANER_PATTERN = re.compile(r'\[(ZH|EN|JA)\]') 14 | 15 | def detect_language(text: str, prev_lang=None): 16 | """ 17 | 根据给定的文本检测语言 18 | 19 | :param text: 输入文本 20 | :param prev_lang: 上一个检测到的语言 21 | :return: 'ZH' for Chinese, 'EN' for English, 'JA' for Japanese, or prev_lang for spaces 22 | """ 23 | if ZH_PATTERN.search(text): return 'ZH' 24 | if EN_PATTERN.search(text): return 'EN' 25 | if JP_PATTERN.search(text): return 'JA' 26 | if text.isspace(): return prev_lang # 若是空格,则返回前一个语言 27 | return None 28 | 29 | # auto detect language using re 30 | def cjke_cleaners4(text: str): 31 | """ 32 | 根据文本内容自动检测语言并转换为IPA音标 33 | 34 | :param text: 输入文本 35 | :return: 转换为IPA音标的文本 36 | """ 37 | text = CLEANER_PATTERN.sub('', text) 38 | pointer = 0 39 | output = '' 40 | current_language = detect_language(text[pointer]) 41 | 42 | while pointer < len(text): 43 | temp_text = '' 44 | while pointer < len(text) and detect_language(text[pointer], current_language) == current_language: 45 | temp_text += text[pointer] 46 | pointer += 1 47 | if current_language == 'ZH': 48 | output += chinese_to_cnm3(temp_text) 49 | elif current_language == 'JA': 50 | output += japanese_to_ipa2(temp_text) 51 | elif current_language == 'EN': 52 | output += english_to_ipa2(temp_text) 53 | if pointer < len(text): 54 | current_language = detect_language(text[pointer]) 55 | 56 | output = re.sub(r'\s+$', '', output) 57 | output = re.sub(r'([^\.,!\?\-…~])$', r'\1.', output) 58 | return output 59 | -------------------------------------------------------------------------------- /text/cn2an/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.5.22" 2 | 3 | from .cn2an import Cn2An 4 | from .an2cn import An2Cn 5 | from .transform import Transform 6 | 7 | cn2an = Cn2An().cn2an 8 | an2cn = An2Cn().an2cn 9 | transform = Transform().transform 10 | 11 | __all__ = [ 12 | "__version__", 13 | "cn2an", 14 | "an2cn", 15 | "transform" 16 | ] 
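# Usage sketch (illustrative; outputs follow the usual cn2an conventions):
#   from text.cn2an import an2cn, cn2an, transform
#   an2cn(123)                        # -> '一百二十三'
#   cn2an('一百二十三')                # -> 123
#   transform('价格是123元', 'an2cn')  # -> '价格是一百二十三元'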
-------------------------------------------------------------------------------- /text/cn2an/an2cn.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from warnings import warn 3 | 4 | # from proces import preprocess 5 | 6 | from .conf import NUMBER_LOW_AN2CN, NUMBER_UP_AN2CN, UNIT_LOW_ORDER_AN2CN, UNIT_UP_ORDER_AN2CN 7 | 8 | 9 | class An2Cn(object): 10 | def __init__(self) -> None: 11 | self.all_num = "0123456789" 12 | self.number_low = NUMBER_LOW_AN2CN 13 | self.number_up = NUMBER_UP_AN2CN 14 | self.mode_list = ["low", "up", "rmb", "direct"] 15 | 16 | def an2cn(self, inputs: Union[str, int, float] = None, mode: str = "low") -> str: 17 | """阿拉伯数字转中文数字 18 | 19 | :param inputs: 阿拉伯数字 20 | :param mode: low 小写数字,up 大写数字,rmb 人民币大写,direct 直接转化 21 | :return: 中文数字 22 | """ 23 | if inputs is not None and inputs != "": 24 | if mode not in self.mode_list: 25 | raise ValueError(f"mode 仅支持 {str(self.mode_list)} !") 26 | 27 | # 将数字转化为字符串,这里会有Python会自动做转化 28 | # 1. -> 1.0 1.00 -> 1.0 -0 -> 0 29 | if not isinstance(inputs, str): 30 | inputs = self.__number_to_string(inputs) 31 | 32 | # 数据预处理: 33 | # 1. 繁体转简体 34 | # 2. 全角转半角 35 | # inputs = preprocess(inputs, pipelines=[ 36 | # "traditional_to_simplified", 37 | # "full_angle_to_half_angle" 38 | # ]) 39 | 40 | # 检查数据是否有效 41 | self.__check_inputs_is_valid(inputs) 42 | 43 | # 判断正负 44 | if inputs[0] == "-": 45 | sign = "负" 46 | inputs = inputs[1:] 47 | else: 48 | sign = "" 49 | 50 | if mode == "direct": 51 | output = self.__direct_convert(inputs) 52 | else: 53 | # 切割整数部分和小数部分 54 | split_result = inputs.split(".") 55 | len_split_result = len(split_result) 56 | if len_split_result == 1: 57 | # 不包含小数的输入 58 | integer_data = split_result[0] 59 | if mode == "rmb": 60 | output = self.__integer_convert(integer_data, "up") + "元整" 61 | else: 62 | output = self.__integer_convert(integer_data, mode) 63 | elif len_split_result == 2: 64 | # 包含小数的输入 65 | integer_data, decimal_data = split_result 66 | if mode == "rmb": 67 | int_data = self.__integer_convert(integer_data, "up") 68 | dec_data = self.__decimal_convert(decimal_data, "up") 69 | len_dec_data = len(dec_data) 70 | 71 | if len_dec_data == 0: 72 | output = int_data + "元整" 73 | elif len_dec_data == 1: 74 | raise ValueError(f"异常输出:{dec_data}") 75 | elif len_dec_data == 2: 76 | if dec_data[1] != "零": 77 | if int_data == "零": 78 | output = dec_data[1] + "角" 79 | else: 80 | output = int_data + "元" + dec_data[1] + "角" 81 | else: 82 | output = int_data + "元整" 83 | else: 84 | if dec_data[1] != "零": 85 | if dec_data[2] != "零": 86 | if int_data == "零": 87 | output = dec_data[1] + "角" + dec_data[2] + "分" 88 | else: 89 | output = int_data + "元" + dec_data[1] + "角" + dec_data[2] + "分" 90 | else: 91 | if int_data == "零": 92 | output = dec_data[1] + "角" 93 | else: 94 | output = int_data + "元" + dec_data[1] + "角" 95 | else: 96 | if dec_data[2] != "零": 97 | if int_data == "零": 98 | output = dec_data[2] + "分" 99 | else: 100 | output = int_data + "元" + "零" + dec_data[2] + "分" 101 | else: 102 | output = int_data + "元整" 103 | else: 104 | output = self.__integer_convert(integer_data, mode) + self.__decimal_convert(decimal_data, mode) 105 | else: 106 | raise ValueError(f"输入格式错误:{inputs}!") 107 | else: 108 | raise ValueError("输入数据为空!") 109 | 110 | return sign + output 111 | 112 | def __direct_convert(self, inputs: str) -> str: 113 | _output = "" 114 | for d in inputs: 115 | if d == ".": 116 | _output += "点" 117 | else: 118 | _output += self.number_low[int(d)] 119 | return 
_output 120 | 121 | @staticmethod 122 | def __number_to_string(number_data: Union[int, float]) -> str: 123 | # 小数处理:python 会自动把 0.00005 转化成 5e-05,因此 str(0.00005) != "0.00005" 124 | string_data = str(number_data) 125 | if "e" in string_data: 126 | string_data_list = string_data.split("e") 127 | string_key = string_data_list[0] 128 | string_value = string_data_list[1] 129 | if string_value[0] == "-": 130 | string_data = "0." + "0" * (int(string_value[1:]) - 1) + string_key 131 | else: 132 | string_data = string_key + "0" * int(string_value) 133 | return string_data 134 | 135 | def __check_inputs_is_valid(self, check_data: str) -> None: 136 | # 检查输入数据是否在规定的字典中 137 | all_check_keys = self.all_num + ".-" 138 | for data in check_data: 139 | if data not in all_check_keys: 140 | raise ValueError(f"输入的数据不在转化范围内:{data}!") 141 | 142 | def __integer_convert(self, integer_data: str, mode: str) -> str: 143 | if mode == "low": 144 | numeral_list = NUMBER_LOW_AN2CN 145 | unit_list = UNIT_LOW_ORDER_AN2CN 146 | elif mode == "up": 147 | numeral_list = NUMBER_UP_AN2CN 148 | unit_list = UNIT_UP_ORDER_AN2CN 149 | else: 150 | raise ValueError(f"error mode: {mode}") 151 | 152 | # 去除前面的 0,比如 007 => 7 153 | integer_data = str(int(integer_data)) 154 | 155 | len_integer_data = len(integer_data) 156 | if len_integer_data > len(unit_list): 157 | raise ValueError(f"超出数据范围,最长支持 {len(unit_list)} 位") 158 | 159 | output_an = "" 160 | for i, d in enumerate(integer_data): 161 | if int(d): 162 | output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1] 163 | else: 164 | if not (len_integer_data - i - 1) % 4: 165 | output_an += numeral_list[int(d)] + unit_list[len_integer_data - i - 1] 166 | 167 | if i > 0 and not output_an[-1] == "零": 168 | output_an += numeral_list[int(d)] 169 | 170 | output_an = output_an.replace("零零", "零").replace("零万", "万").replace("零亿", "亿").replace("亿万", "亿") \ 171 | .strip("零") 172 | 173 | # 解决「一十几」问题 174 | if output_an[:2] in ["一十"]: 175 | output_an = output_an[1:] 176 | 177 | # 0 - 1 之间的小数 178 | if not output_an: 179 | output_an = "零" 180 | 181 | return output_an 182 | 183 | def __decimal_convert(self, decimal_data: str, o_mode: str) -> str: 184 | len_decimal_data = len(decimal_data) 185 | 186 | if len_decimal_data > 16: 187 | warn(f"注意:小数部分长度为 {len_decimal_data} ,将自动截取前 16 位有效精度!") 188 | decimal_data = decimal_data[:16] 189 | 190 | if len_decimal_data: 191 | output_an = "点" 192 | else: 193 | output_an = "" 194 | 195 | if o_mode == "low": 196 | numeral_list = NUMBER_LOW_AN2CN 197 | elif o_mode == "up": 198 | numeral_list = NUMBER_UP_AN2CN 199 | else: 200 | raise ValueError(f"error mode: {o_mode}") 201 | 202 | for data in decimal_data: 203 | output_an += numeral_list[int(data)] 204 | return output_an -------------------------------------------------------------------------------- /text/cn2an/conf.py: -------------------------------------------------------------------------------- 1 | NUMBER_CN2AN = { 2 | "零": 0, 3 | "〇": 0, 4 | "一": 1, 5 | "壹": 1, 6 | "幺": 1, 7 | "二": 2, 8 | "贰": 2, 9 | "两": 2, 10 | "三": 3, 11 | "叁": 3, 12 | "四": 4, 13 | "肆": 4, 14 | "五": 5, 15 | "伍": 5, 16 | "六": 6, 17 | "陆": 6, 18 | "七": 7, 19 | "柒": 7, 20 | "八": 8, 21 | "捌": 8, 22 | "九": 9, 23 | "玖": 9, 24 | } 25 | UNIT_CN2AN = { 26 | "十": 10, 27 | "拾": 10, 28 | "百": 100, 29 | "佰": 100, 30 | "千": 1000, 31 | "仟": 1000, 32 | "万": 10000, 33 | "亿": 100000000, 34 | } 35 | UNIT_LOW_AN2CN = { 36 | 10: "十", 37 | 100: "百", 38 | 1000: "千", 39 | 10000: "万", 40 | 100000000: "亿", 41 | } 42 | NUMBER_LOW_AN2CN = { 43 | 0: "零", 44 | 
1: "一", 45 | 2: "二", 46 | 3: "三", 47 | 4: "四", 48 | 5: "五", 49 | 6: "六", 50 | 7: "七", 51 | 8: "八", 52 | 9: "九", 53 | } 54 | NUMBER_UP_AN2CN = { 55 | 0: "零", 56 | 1: "壹", 57 | 2: "贰", 58 | 3: "叁", 59 | 4: "肆", 60 | 5: "伍", 61 | 6: "陆", 62 | 7: "柒", 63 | 8: "捌", 64 | 9: "玖", 65 | } 66 | UNIT_LOW_ORDER_AN2CN = [ 67 | "", 68 | "十", 69 | "百", 70 | "千", 71 | "万", 72 | "十", 73 | "百", 74 | "千", 75 | "亿", 76 | "十", 77 | "百", 78 | "千", 79 | "万", 80 | "十", 81 | "百", 82 | "千", 83 | ] 84 | UNIT_UP_ORDER_AN2CN = [ 85 | "", 86 | "拾", 87 | "佰", 88 | "仟", 89 | "万", 90 | "拾", 91 | "佰", 92 | "仟", 93 | "亿", 94 | "拾", 95 | "佰", 96 | "仟", 97 | "万", 98 | "拾", 99 | "佰", 100 | "仟", 101 | ] 102 | STRICT_CN_NUMBER = { 103 | "零": "零", 104 | "一": "一壹", 105 | "二": "二贰", 106 | "三": "三叁", 107 | "四": "四肆", 108 | "五": "五伍", 109 | "六": "六陆", 110 | "七": "七柒", 111 | "八": "八捌", 112 | "九": "九玖", 113 | "十": "十拾", 114 | "百": "百佰", 115 | "千": "千仟", 116 | "万": "万", 117 | "亿": "亿", 118 | } 119 | NORMAL_CN_NUMBER = { 120 | "零": "零〇", 121 | "一": "一壹幺", 122 | "二": "二贰两", 123 | "三": "三叁仨", 124 | "四": "四肆", 125 | "五": "五伍", 126 | "六": "六陆", 127 | "七": "七柒", 128 | "八": "八捌", 129 | "九": "九玖", 130 | "十": "十拾", 131 | "百": "百佰", 132 | "千": "千仟", 133 | "万": "万", 134 | "亿": "亿", 135 | } -------------------------------------------------------------------------------- /text/cn2an/transform.py: -------------------------------------------------------------------------------- 1 | import re 2 | from warnings import warn 3 | 4 | from .cn2an import Cn2An 5 | from .an2cn import An2Cn 6 | from .conf import UNIT_CN2AN 7 | 8 | 9 | class Transform(object): 10 | def __init__(self) -> None: 11 | self.all_num = "零一二三四五六七八九" 12 | self.all_unit = "".join(list(UNIT_CN2AN.keys())) 13 | self.cn2an = Cn2An().cn2an 14 | self.an2cn = An2Cn().an2cn 15 | self.cn_pattern = f"负?([{self.all_num}{self.all_unit}]+点)?[{self.all_num}{self.all_unit}]+" 16 | self.smart_cn_pattern = f"-?([0-9]+.)?[0-9]+[{self.all_unit}]+" 17 | 18 | def transform(self, inputs: str, method: str = "cn2an") -> str: 19 | if method == "cn2an": 20 | inputs = inputs.replace("廿", "二十").replace("半", "0.5").replace("两", "2") 21 | # date 22 | inputs = re.sub( 23 | fr"((({self.smart_cn_pattern})|({self.cn_pattern}))年)?([{self.all_num}十]+月)?([{self.all_num}十]+日)?", 24 | lambda x: self.__sub_util(x.group(), "cn2an", "date"), inputs) 25 | # fraction 26 | inputs = re.sub(fr"{self.cn_pattern}分之{self.cn_pattern}", 27 | lambda x: self.__sub_util(x.group(), "cn2an", "fraction"), inputs) 28 | # percent 29 | inputs = re.sub(fr"百分之{self.cn_pattern}", 30 | lambda x: self.__sub_util(x.group(), "cn2an", "percent"), inputs) 31 | # celsius 32 | inputs = re.sub(fr"{self.cn_pattern}摄氏度", 33 | lambda x: self.__sub_util(x.group(), "cn2an", "celsius"), inputs) 34 | # number 35 | output = re.sub(self.cn_pattern, 36 | lambda x: self.__sub_util(x.group(), "cn2an", "number"), inputs) 37 | 38 | elif method == "an2cn": 39 | # date 40 | inputs = re.sub(r"(\d{2,4}年)?(\d{1,2}月)?(\d{1,2}日)?", 41 | lambda x: self.__sub_util(x.group(), "an2cn", "date"), inputs) 42 | # fraction 43 | inputs = re.sub(r"\d+/\d+", 44 | lambda x: self.__sub_util(x.group(), "an2cn", "fraction"), inputs) 45 | # percent 46 | inputs = re.sub(r"-?(\d+\.)?\d+%", 47 | lambda x: self.__sub_util(x.group(), "an2cn", "percent"), inputs) 48 | # celsius 49 | inputs = re.sub(r"\d+℃", 50 | lambda x: self.__sub_util(x.group(), "an2cn", "celsius"), inputs) 51 | # number 52 | output = re.sub(r"-?(\d+\.)?\d+", 53 | lambda x: self.__sub_util(x.group(), "an2cn", "number"), inputs) 54 
| else: 55 | raise ValueError(f"error method: {method}, only support 'cn2an' and 'an2cn'!") 56 | 57 | return output 58 | 59 | def __sub_util(self, inputs, method: str = "cn2an", sub_mode: str = "number") -> str: 60 | try: 61 | if inputs: 62 | if method == "cn2an": 63 | if sub_mode == "date": 64 | return re.sub(fr"(({self.smart_cn_pattern})|({self.cn_pattern}))", 65 | lambda x: str(self.cn2an(x.group(), "smart")), inputs) 66 | elif sub_mode == "fraction": 67 | if inputs[0] != "百": 68 | frac_result = re.sub(self.cn_pattern, 69 | lambda x: str(self.cn2an(x.group(), "smart")), inputs) 70 | numerator, denominator = frac_result.split("分之") 71 | return f"{denominator}/{numerator}" 72 | else: 73 | return inputs 74 | elif sub_mode == "percent": 75 | return re.sub(f"(?<=百分之){self.cn_pattern}", 76 | lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("百分之", "") + "%" 77 | elif sub_mode == "celsius": 78 | return re.sub(f"{self.cn_pattern}(?=摄氏度)", 79 | lambda x: str(self.cn2an(x.group(), "smart")), inputs).replace("摄氏度", "℃") 80 | elif sub_mode == "number": 81 | return str(self.cn2an(inputs, "smart")) 82 | else: 83 | raise Exception(f"error sub_mode: {sub_mode} !") 84 | else: 85 | if sub_mode == "date": 86 | inputs = re.sub(r"\d+(?=年)", 87 | lambda x: self.an2cn(x.group(), "direct"), inputs) 88 | return re.sub(r"\d+", 89 | lambda x: self.an2cn(x.group(), "low"), inputs) 90 | elif sub_mode == "fraction": 91 | frac_result = re.sub(r"\d+", lambda x: self.an2cn(x.group(), "low"), inputs) 92 | numerator, denominator = frac_result.split("/") 93 | return f"{denominator}分之{numerator}" 94 | elif sub_mode == "celsius": 95 | return self.an2cn(inputs[:-1], "low") + "摄氏度" 96 | elif sub_mode == "percent": 97 | return "百分之" + self.an2cn(inputs[:-1], "low") 98 | elif sub_mode == "number": 99 | return self.an2cn(inputs, "low") 100 | else: 101 | raise Exception(f"error sub_mode: {sub_mode} !") 102 | except Exception as e: 103 | warn(str(e)) 104 | return inputs -------------------------------------------------------------------------------- /text/custom_pypinyin_dict/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /text/custom_pypinyin_dict/cc_cedict_3.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | 4 | # Warning: Auto-generated file, don't edit. 
5 | phrases_dict = { 6 | '𰻝𰻝面': [['biáng'], ['biáng'], ['miàn']], 7 | } 8 | 9 | 10 | from pypinyin import load_phrases_dict 11 | 12 | 13 | def load(): 14 | load_phrases_dict(phrases_dict) 15 | -------------------------------------------------------------------------------- /text/custom_pypinyin_dict/genshin.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | 4 | phrases_dict = { 5 | '㐖毒': [['xié'], ['dú']], 6 | '若陀': [['rě'], ['tuó']], 7 | '平藏': [['píng'], ['zàng']], 8 | '派蒙': [['pài'], ['méng']], 9 | '安柏': [['ān'], ['bó']], 10 | '一斗': [['yī'], ['dǒu']] 11 | } -------------------------------------------------------------------------------- /text/custom_pypinyin_dict/phrase_pinyin_data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from __future__ import unicode_literals 3 | 4 | from pypinyin import load_phrases_dict 5 | 6 | from text.custom_pypinyin_dict import cc_cedict_0 7 | from text.custom_pypinyin_dict import cc_cedict_1 8 | from text.custom_pypinyin_dict import cc_cedict_2 9 | from text.custom_pypinyin_dict import cc_cedict_3 10 | from text.custom_pypinyin_dict import genshin 11 | 12 | phrases_dict = {} 13 | phrases_dict.update(cc_cedict_0.phrases_dict) 14 | phrases_dict.update(cc_cedict_1.phrases_dict) 15 | phrases_dict.update(cc_cedict_2.phrases_dict) 16 | phrases_dict.update(cc_cedict_3.phrases_dict) 17 | phrases_dict.update(genshin.phrases_dict) 18 | 19 | def load(): 20 | load_phrases_dict(phrases_dict) 21 | print("加载自定义词典成功") 22 | 23 | if __name__ == '__main__': 24 | print(phrases_dict) -------------------------------------------------------------------------------- /text/english.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | ''' 4 | Cleaners are transformations that run over the input text at both training and eval time. 5 | 6 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 7 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 8 | 1. "english_cleaners" for English text 9 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 10 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 11 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 12 | the symbols in symbols.py to match your data). 13 | ''' 14 | 15 | 16 | # Regular expression matching whitespace: 17 | 18 | 19 | import re 20 | import inflect 21 | from unidecode import unidecode 22 | import eng_to_ipa as ipa 23 | _inflect = inflect.engine() 24 | _comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') 25 | _decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') 26 | _pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') 27 | _dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') 28 | _ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') 29 | _number_re = re.compile(r'[0-9]+') 30 | 31 | # List of (regular expression, replacement) pairs for abbreviations: 32 | _abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ 33 | ('mrs', 'misess'), 34 | ('mr', 'mister'), 35 | ('dr', 'doctor'), 36 | ('st', 'saint'), 37 | ('co', 'company'), 38 | ('jr', 'junior'), 39 | ('maj', 'major'), 40 | ('gen', 'general'), 41 | ('drs', 'doctors'), 42 | ('rev', 'reverend'), 43 | ('lt', 'lieutenant'), 44 | ('hon', 'honorable'), 45 | ('sgt', 'sergeant'), 46 | ('capt', 'captain'), 47 | ('esq', 'esquire'), 48 | ('ltd', 'limited'), 49 | ('col', 'colonel'), 50 | ('ft', 'fort'), 51 | ]] 52 | 53 | 54 | # List of (ipa, lazy ipa) pairs: 55 | _lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 56 | ('r', 'ɹ'), 57 | ('æ', 'e'), 58 | ('ɑ', 'a'), 59 | ('ɔ', 'o'), 60 | ('ð', 'z'), 61 | ('θ', 's'), 62 | ('ɛ', 'e'), 63 | ('ɪ', 'i'), 64 | ('ʊ', 'u'), 65 | ('ʒ', 'ʥ'), 66 | ('ʤ', 'ʥ'), 67 | ('ˈ', '↓'), 68 | ]] 69 | 70 | # List of (ipa, lazy ipa2) pairs: 71 | _lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 72 | ('r', 'ɹ'), 73 | ('ð', 'z'), 74 | ('θ', 's'), 75 | ('ʒ', 'ʑ'), 76 | ('ʤ', 'dʑ'), 77 | ('ˈ', '↓'), 78 | ]] 79 | 80 | # List of (ipa, ipa2) pairs 81 | _ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 82 | ('r', 'ɹ'), 83 | ('ʤ', 'dʒ'), 84 | ('ʧ', 'tʃ') 85 | ]] 86 | 87 | 88 | def expand_abbreviations(text): 89 | for regex, replacement in _abbreviations: 90 | text = re.sub(regex, replacement, text) 91 | return text 92 | 93 | 94 | def collapse_whitespace(text): 95 | return re.sub(r'\s+', ' ', text) 96 | 97 | 98 | def _remove_commas(m): 99 | return m.group(1).replace(',', '') 100 | 101 | 102 | def _expand_decimal_point(m): 103 | return m.group(1).replace('.', ' point ') 104 | 105 | 106 | def _expand_dollars(m): 107 | match = m.group(1) 108 | parts = match.split('.') 109 | if len(parts) > 2: 110 | return match + ' dollars' # Unexpected format 111 | dollars = int(parts[0]) if parts[0] else 0 112 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 113 | if dollars and cents: 114 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 115 | cent_unit = 'cent' if cents == 1 else 'cents' 116 | return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) 117 | elif dollars: 118 | dollar_unit = 'dollar' if dollars == 1 else 'dollars' 119 | return '%s %s' % (dollars, dollar_unit) 120 | elif cents: 121 | cent_unit = 'cent' if cents == 1 else 'cents' 122 | return '%s %s' % (cents, cent_unit) 123 | else: 124 | return 'zero dollars' 125 | 126 | 127 | def _expand_ordinal(m): 128 | return _inflect.number_to_words(m.group(0)) 129 | 130 | 131 | def _expand_number(m): 132 | num = int(m.group(0)) 133 | if num > 1000 and num < 3000: 134 | if num == 2000: 135 | return 'two thousand' 136 | elif num > 2000 and num < 2010: 137 | return 'two thousand ' + _inflect.number_to_words(num % 100) 138 | elif num % 100 == 0: 139 | return _inflect.number_to_words(num // 100) + ' hundred' 140 | else: 141 | return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') 142 | else: 143 | return _inflect.number_to_words(num, andword='') 144 | 145 | 146 | def normalize_numbers(text): 147 | text = re.sub(_comma_number_re, _remove_commas, text) 148 | text = re.sub(_pounds_re, r'\1 pounds', text) 149 | text = re.sub(_dollars_re, _expand_dollars, text) 150 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 151 | text = re.sub(_ordinal_re, _expand_ordinal, text) 152 | text = re.sub(_number_re, _expand_number, text) 153 | return text 154 | 155 | 156 | def mark_dark_l(text): 157 | return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ'+x.group(1), text) 158 | 159 | 160 | 
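# Pipeline sketch (illustrative example): english_to_ipa() below transliterates and lowercases the text,
# expands abbreviations and numbers (e.g. 'Dr. Smith paid $3.' -> 'doctor smith paid three dollars.'),
# converts it to IPA with eng_to_ipa and collapses whitespace; english_to_ipa2() additionally marks
# dark 'l' and applies the _ipa_to_ipa2 substitutions, returning the result as a list of characters.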
def english_to_ipa(text): 161 | text = unidecode(text).lower() 162 | text = expand_abbreviations(text) 163 | text = normalize_numbers(text) 164 | phonemes = ipa.convert(text) 165 | phonemes = collapse_whitespace(phonemes) 166 | return phonemes 167 | 168 | 169 | def english_to_ipa2(text): 170 | text = english_to_ipa(text) 171 | text = mark_dark_l(text) 172 | for regex, replacement in _ipa_to_ipa2: 173 | text = re.sub(regex, replacement, text) 174 | return list(text.replace('...', '…')) 175 | 176 | -------------------------------------------------------------------------------- /text/japanese.py: -------------------------------------------------------------------------------- 1 | import re 2 | from unidecode import unidecode 3 | import pyopenjtalk 4 | 5 | 6 | # Regular expression matching Japanese without punctuation marks: 7 | _japanese_characters = re.compile( 8 | r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 9 | 10 | # Regular expression matching non-Japanese characters or punctuation marks: 11 | _japanese_marks = re.compile( 12 | r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]') 13 | 14 | # List of (symbol, Japanese) pairs for marks: 15 | _symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [ 16 | ('%', 'パーセント') 17 | ]] 18 | 19 | # List of (romaji, ipa) pairs for marks: 20 | _romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ 21 | ('ts', 'ʦ'), 22 | ('u', 'ɯ'), 23 | ('j', 'ʥ'), 24 | ('y', 'j'), 25 | ('ni', 'n^i'), 26 | ('nj', 'n^'), 27 | ('hi', 'çi'), 28 | ('hj', 'ç'), 29 | ('f', 'ɸ'), 30 | ('I', 'i*'), 31 | ('U', 'ɯ*'), 32 | ('r', 'ɾ') 33 | ]] 34 | 35 | # List of (romaji, ipa2) pairs for marks: 36 | _romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ 37 | ('u', 'ɯ'), 38 | ('ʧ', 'tʃ'), 39 | ('j', 'dʑ'), 40 | ('y', 'j'), 41 | ('ni', 'n^i'), 42 | ('nj', 'n^'), 43 | ('hi', 'çi'), 44 | ('hj', 'ç'), 45 | ('f', 'ɸ'), 46 | ('I', 'i*'), 47 | ('U', 'ɯ*'), 48 | ('r', 'ɾ') 49 | ]] 50 | 51 | # List of (consonant, sokuon) pairs: 52 | _real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 53 | (r'Q([↑↓]*[kg])', r'k#\1'), 54 | (r'Q([↑↓]*[tdjʧ])', r't#\1'), 55 | (r'Q([↑↓]*[sʃ])', r's\1'), 56 | (r'Q([↑↓]*[pb])', r'p#\1') 57 | ]] 58 | 59 | # List of (consonant, hatsuon) pairs: 60 | _real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [ 61 | (r'N([↑↓]*[pbm])', r'm\1'), 62 | (r'N([↑↓]*[ʧʥj])', r'n^\1'), 63 | (r'N([↑↓]*[tdn])', r'n\1'), 64 | (r'N([↑↓]*[kg])', r'ŋ\1') 65 | ]] 66 | 67 | 68 | def symbols_to_japanese(text): 69 | for regex, replacement in _symbols_to_japanese: 70 | text = re.sub(regex, replacement, text) 71 | return text 72 | 73 | 74 | def japanese_to_romaji_with_accent(text): 75 | '''Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html''' 76 | text = symbols_to_japanese(text) 77 | sentences = re.split(_japanese_marks, text) 78 | marks = re.findall(_japanese_marks, text) 79 | text = '' 80 | for i, sentence in enumerate(sentences): 81 | if re.match(_japanese_characters, sentence): 82 | if text != '': 83 | text += ' ' 84 | labels = pyopenjtalk.extract_fullcontext(sentence) 85 | for n, label in enumerate(labels): 86 | phoneme = re.search(r'\-([^\+]*)\+', label).group(1) 87 | if phoneme not in ['sil', 'pau']: 88 | text += phoneme.replace('ch', 'ʧ').replace('sh', 89 | 'ʃ').replace('cl', 'Q') 90 | else: 91 | continue 92 | # n_moras = int(re.search(r'/F:(\d+)_', label).group(1)) 93 | a1 = int(re.search(r"/A:(\-?[0-9]+)\+", 
label).group(1)) 94 | a2 = int(re.search(r"\+(\d+)\+", label).group(1)) 95 | a3 = int(re.search(r"\+(\d+)/", label).group(1)) 96 | if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']: 97 | a2_next = -1 98 | else: 99 | a2_next = int( 100 | re.search(r"\+(\d+)\+", labels[n + 1]).group(1)) 101 | # Accent phrase boundary 102 | if a3 == 1 and a2_next == 1: 103 | text += ' ' 104 | # Falling 105 | elif a1 == 0 and a2_next == a2 + 1: 106 | text += '↓' 107 | # Rising 108 | elif a2 == 1 and a2_next == 2: 109 | text += '↑' 110 | if i < len(marks): 111 | text += unidecode(marks[i]).replace(' ', '') 112 | return text 113 | 114 | 115 | def get_real_sokuon(text): 116 | for regex, replacement in _real_sokuon: 117 | text = re.sub(regex, replacement, text) 118 | return text 119 | 120 | 121 | def get_real_hatsuon(text): 122 | for regex, replacement in _real_hatsuon: 123 | text = re.sub(regex, replacement, text) 124 | return text 125 | 126 | 127 | def japanese_to_ipa(text): 128 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 129 | text = re.sub( 130 | r'([aiueo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 131 | text = get_real_sokuon(text) 132 | text = get_real_hatsuon(text) 133 | for regex, replacement in _romaji_to_ipa: 134 | text = re.sub(regex, replacement, text) 135 | return text 136 | 137 | 138 | def japanese_to_ipa2(text): 139 | text = japanese_to_romaji_with_accent(text).replace('...', '…') 140 | text = get_real_sokuon(text) 141 | text = get_real_hatsuon(text) 142 | for regex, replacement in _romaji_to_ipa2: 143 | text = re.sub(regex, replacement, text) 144 | return list(text) 145 | 146 | 147 | def japanese_to_ipa3(text): 148 | text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace( 149 | 'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a') 150 | text = re.sub( 151 | r'([aiɯeo])\1+', lambda x: x.group(0)[0]+'ː'*(len(x.group(0))-1), text) 152 | text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text) 153 | return text 154 | 155 | if __name__ == '__main__': 156 | a = japanese_to_romaji_with_accent('こんにちは!はい、元気です。あなたは?') 157 | print(a) 158 | -------------------------------------------------------------------------------- /text/mandarin.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Dict, List 3 | from pypinyin import lazy_pinyin, Style 4 | from .custom_pypinyin_dict import phrase_pinyin_data 5 | import jieba 6 | from .cn2an import an2cn 7 | 8 | # 加载自定义拼音词典数据 9 | phrase_pinyin_data.load() 10 | 11 | # 标点符号正则 12 | PUNC_MAP: Dict[str, str] = { 13 | ":": ",", 14 | ";": ",", 15 | ",": ",", 16 | "。": ".", 17 | "!": "!", 18 | "?": "?", 19 | "\n": ".", 20 | "·": ",", 21 | "、": ",", 22 | "$": ".", 23 | "/": ",", 24 | "“": "'", 25 | "”": "'", 26 | '"': "'", 27 | "‘": "'", 28 | "’": "'", 29 | "(": "'", 30 | ")": "'", 31 | "(": "'", 32 | ")": "'", 33 | "《": "'", 34 | "》": "'", 35 | "【": "'", 36 | "】": "'", 37 | "[": "'", 38 | "]": "'", 39 | "—": "-", 40 | "~": "~", 41 | "「": "'", 42 | "」": "'", 43 | "『": "'", 44 | "』": "'", 45 | } 46 | 47 | # from GPT_SoVITS.text.zh_normalization.text_normlization 48 | PUNC_MAP.update ({ 49 | '/': '每', 50 | '①': '一', 51 | '②': '二', 52 | '③': '三', 53 | '④': '四', 54 | '⑤': '五', 55 | '⑥': '六', 56 | '⑦': '七', 57 | '⑧': '八', 58 | '⑨': '九', 59 | '⑩': '十', 60 | 'α': '阿尔法', 61 | 'β': '贝塔', 62 | 'γ': '伽玛', 63 | 'Γ': '伽玛', 64 | 'δ': '德尔塔', 65 | 'Δ': '德尔塔', 66 | 'ε': '艾普西龙', 67 | 'ζ': '捷塔', 68 | 'η': '依塔', 69 | 'θ': '西塔', 70 | 'Θ': '西塔', 71 | 'ι': '艾欧塔', 72 | 
'κ': '喀帕', 73 | 'λ': '拉姆达', 74 | 'Λ': '拉姆达', 75 | 'μ': '缪', 76 | 'ν': '拗', 77 | 'ξ': '克西', 78 | 'Ξ': '克西', 79 | 'ο': '欧米克伦', 80 | 'π': '派', 81 | 'Π': '派', 82 | 'ρ': '肉', 83 | 'ς': '西格玛', 84 | 'σ': '西格玛', 85 | 'Σ': '西格玛', 86 | 'τ': '套', 87 | 'υ': '宇普西龙', 88 | 'φ': '服艾', 89 | 'Φ': '服艾', 90 | 'χ': '器', 91 | 'ψ': '普赛', 92 | 'Ψ': '普赛', 93 | 'ω': '欧米伽', 94 | 'Ω': '欧米伽', 95 | '+': '加', 96 | '-': '减', 97 | '×': '乘', 98 | '÷': '除', 99 | '=': '等', 100 | 101 | "嗯": "恩", 102 | "呣": "母" 103 | }) 104 | 105 | PUNC_TABLE = str.maketrans(PUNC_MAP) 106 | 107 | # 数字正则化 108 | NUMBER_PATTERN: re.Pattern = re.compile(r'\d+(?:\.?\d+)?') 109 | 110 | # 阿拉伯数字转汉字 111 | def replace_number(match: re.Match) -> str: 112 | return an2cn(match.group()) 113 | 114 | def normalize_number(text: str) -> str: 115 | return NUMBER_PATTERN.sub(replace_number, text) 116 | 117 | # get symbols of phones, not used 118 | def load_pinyin_symbols(path): 119 | pinyin_dict={} 120 | temp = [] 121 | with open(path, "r", encoding='utf-8') as f: 122 | content = f.readlines() 123 | for line in content: 124 | cuts = line.strip().split(',') 125 | pinyin = cuts[0] 126 | phones = cuts[1].split(' ') 127 | pinyin_dict[pinyin] = phones 128 | temp.extend(phones) 129 | temp = list(set(temp)) 130 | tone = [] 131 | for phone in temp: 132 | for i in range(1, 6): 133 | phone2 = phone + str(i) 134 | tone.append(phone2) 135 | print(sorted(tone, key=lambda x: len(x))) 136 | return pinyin_dict 137 | 138 | def load_pinyin_dict(path: str) -> Dict[str, List[str]]: 139 | pinyin_dict = {} 140 | with open(path, "r", encoding='utf-8') as f: 141 | for line in f: 142 | key, value = line.strip().split(',', 1) 143 | pinyin_dict[key] = value.split() 144 | return pinyin_dict 145 | 146 | import os 147 | pinyin_dict = load_pinyin_dict(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cnm3', 'ds_CNM3.txt')) 148 | # pinyin_dict = load_pinyin_dict('text/cnm3/ds_CNM3.txt') 149 | 150 | def chinese_to_cnm3(text: str) -> List[str]: 151 | # 标点符号和数字正则化 152 | text = text.translate(PUNC_TABLE) 153 | text = normalize_number(text) 154 | # 过滤掉特殊字符 155 | text = re.sub(r'[#&@“”^_|\\]', '', text) 156 | 157 | words = jieba.lcut(text, cut_all=False) 158 | 159 | phones = [] 160 | for word in words: 161 | pinyin_list: List[str] = lazy_pinyin(word, style=Style.TONE3, neutral_tone_with_five=True) 162 | for pinyin in pinyin_list: 163 | if pinyin[-1].isdigit(): 164 | tone = pinyin[-1] 165 | syllable = pinyin[:-1] 166 | phone = pinyin_dict[syllable] 167 | phones.extend([ph + tone for ph in phone]) 168 | elif pinyin[-1].isalpha(): 169 | pass 170 | else: 171 | phones.extend(pinyin) 172 | 173 | return phones -------------------------------------------------------------------------------- /text/symbols.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Defines the set of symbols used in text input to the model. 
3 | ''' 4 | 5 | # japanese_cleaners 6 | # _pad = '_' 7 | # _punctuation = ',.!?-' 8 | # _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧ↓↑ ' 9 | 10 | 11 | '''# japanese_cleaners2 12 | _pad = '_' 13 | _punctuation = ',.!?-~…' 14 | _letters = 'AEINOQUabdefghijkmnoprstuvwyzʃʧʦ↓↑ ' 15 | ''' 16 | 17 | 18 | '''# korean_cleaners 19 | _pad = '_' 20 | _punctuation = ',.!?…~' 21 | _letters = 'ㄱㄴㄷㄹㅁㅂㅅㅇㅈㅊㅋㅌㅍㅎㄲㄸㅃㅆㅉㅏㅓㅗㅜㅡㅣㅐㅔ ' 22 | ''' 23 | 24 | '''# chinese_cleaners 25 | _pad = '_' 26 | _punctuation = ',。!?—…' 27 | _letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ ' 28 | ''' 29 | 30 | # # zh_ja_mixture_cleaners 31 | # _pad = '_' 32 | # _punctuation = ',.!?-~…' 33 | # _letters = 'AEINOQUabdefghijklmnoprstuvwyzʃʧʦɯɹəɥ⁼ʰ`→↓↑ ' 34 | 35 | 36 | '''# sanskrit_cleaners 37 | _pad = '_' 38 | _punctuation = '।' 39 | _letters = 'ँंःअआइईउऊऋएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऽािीुूृॄेैोौ्ॠॢ ' 40 | ''' 41 | 42 | '''# cjks_cleaners 43 | _pad = '_' 44 | _punctuation = ',.!?-~…' 45 | _letters = 'NQabdefghijklmnopstuvwxyzʃʧʥʦɯɹəɥçɸɾβŋɦː⁼ʰ`^#*=→↓↑ ' 46 | ''' 47 | 48 | '''# thai_cleaners 49 | _pad = '_' 50 | _punctuation = '.!? ' 51 | _letters = 'กขฃคฆงจฉชซฌญฎฏฐฑฒณดตถทธนบปผฝพฟภมยรฤลวศษสหฬอฮฯะัาำิีึืุูเแโใไๅๆ็่้๊๋์' 52 | ''' 53 | 54 | # # cjke_cleaners2 55 | _pad = '_' 56 | _punctuation = ',.!?-~…' + "'" 57 | _IPA_letters = 'NQabdefghijklmnopstuvwxyzɑæʃʑçɯɪɔɛɹðəɫɥɸʊɾʒθβŋɦ⁼ʰ`^#*=ˈˌ→↓↑ ' 58 | _CNM3_letters = ['y1', 'y2', 'y3', 'y4', 'y5', 'n1', 'n2', 'n3', 'n4', 'n5', 'p1', 'p2', 'p3', 'p4', 'p5', 'x1', 'x2', 'x3', 'x4', 'x5', 'k1', 'k2', 'k3', 'k4', 'k5', 'l1', 'l2', 'l3', 'l4', 'l5', 'q1', 'q2', 'q3', 'q4', 'q5', 'w1', 'w2', 'w3', 'w4', 'w5', 'E1', 'E2', 'E3', 'E4', 'E5', 'b1', 'b2', 'b3', 'b4', 'b5', 'c1', 'c2', 'c3', 'c4', 'c5', 'z1', 'z2', 'z3', 'z4', 'z5', 'e1', 'e2', 'e3', 'e4', 'e5', 'f1', 'f2', 'f3', 'f4', 'f5', 's1', 's2', 's3', 's4', 's5', 'j1', 'j2', 'j3', 'j4', 'j5', 'o1', 'o2', 'o3', 'o4', 'o5', 'i1', 'i2', 'i3', 'i4', 'i5', 'd1', 'd2', 'd3', 'd4', 'd5', 'm1', 'm2', 'm3', 'm4', 'm5', 't1', 't2', 't3', 't4', 't5', 'h1', 'h2', 'h3', 'h4', 'h5', 'g1', 'g2', 'g3', 'g4', 'g5', 'v1', 'v2', 'v3', 'v4', 'v5', 'r1', 'r2', 'r3', 'r4', 'r5', 'a1', 'a2', 'a3', 'a4', 'a5', 'u1', 'u2', 'u3', 'u4', 'u5', 'I01', 'I02', 'I03', 'I04', 'I05', 'i01', 'i02', 'i03', 'i04', 'i05', 'uo1', 'uo2', 'uo3', 'uo4', 'uo5', 'o01', 'o02', 'o03', 'o04', 'o05', 'U01', 'U02', 'U03', 'U04', 'U05', 'v01', 'v02', 'v03', 'v04', 'v05', 'er1', 'er2', 'er3', 'er4', 'er5', 'A01', 'A02', 'A03', 'A04', 'A05', 'ai1', 'ai2', 'ai3', 'ai4', 'ai5', 'e01', 'e02', 'e03', 'e04', 'e05', 'sh1', 'sh2', 'sh3', 'sh4', 'sh5', 'an1', 'an2', 'an3', 'an4', 'an5', 'ou1', 'ou2', 'ou3', 'ou4', 'ou5', 'ch1', 'ch2', 'ch3', 'ch4', 'ch5', 'a01', 'a02', 'a03', 'a04', 'a05', 'N01', 'N02', 'N03', 'N04', 'N05', 'ao1', 'ao2', 'ao3', 'ao4', 'ao5', 've1', 've2', 've3', 've4', 've5', 'ir1', 'ir2', 'ir3', 'ir4', 'ir5', 'ng1', 'ng2', 'ng3', 'ng4', 'ng5', 'ua1', 'ua2', 'ua3', 'ua4', 'ua5', 'zh1', 'zh2', 'zh3', 'zh4', 'zh5', 'O01', 'O02', 'O03', 'O04', 'O05', 'ie1', 'ie2', 'ie3', 'ie4', 'ie5', 'E01', 'E02', 'E03', 'E04', 'E05', 'ia1', 'ia2', 'ia3', 'ia4', 'ia5', 'iE01', 'iE02', 'iE03', 'iE04', 'iE05', 'ang1', 'ang2', 'ang3', 'ang4', 'ang5', 'ng01', 'ng02', 'ng03', 'ng04', 'ng05', 'io01', 'io02', 'io03', 'io04', 'io05', 'iA01', 'iA02', 'iA03', 'iA04', 'iA05', 'uA01', 'uA02', 'uA03', 'uA04', 'uA05', 'ong1', 'ong2', 'ong3', 'ong4', 'ong5', 'oo01', 'oo02', 'oo03', 'oo04', 'oo05', 'uE01', 'uE02', 'uE03', 'uE04', 'uE05', 'vE01', 'vE02', 'vE03', 'vE04', 'vE05', 'ue01', 'ue02', 'ue03', 'ue04', 
'ue05', 'ua01', 'ua02', 'ua03', 'ua04', 'ua05', 'iO01', 'iO02', 'iO03', 'iO04', 'iO05'] 59 | _additional = ['', ''] 60 | # _CNM3_letters = [] 61 | 62 | 63 | '''# shanghainese_cleaners 64 | _pad = '_' 65 | _punctuation = ',.!?…' 66 | _letters = 'abdfghiklmnopstuvyzøŋȵɑɔɕəɤɦɪɿʑʔʰ̩̃ᴀᴇ15678 ' 67 | ''' 68 | 69 | '''# chinese_dialect_cleaners 70 | _pad = '_' 71 | _punctuation = ',.!?~…─' 72 | _letters = '#Nabdefghijklmnoprstuvwxyzæçøŋœȵɐɑɒɓɔɕɗɘəɚɛɜɣɤɦɪɭɯɵɷɸɻɾɿʂʅʊʋʌʏʑʔʦʮʰʷˀː˥˦˧˨˩̥̩̃̚ᴀᴇ↑↓∅ⱼ ' 73 | ''' 74 | 75 | # Export all symbols: 76 | symbols = [_pad] + list(_punctuation) + list(_IPA_letters) + _CNM3_letters + _additional 77 | 78 | # Special symbol ids 79 | SPACE_ID = symbols.index(" ") 80 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0,1' 3 | 4 | import torch 5 | import torch.optim as optim 6 | import torch.distributed as dist 7 | from torch.nn.parallel import DistributedDataParallel as DDP 8 | from torch.utils.data import DataLoader 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from tqdm import tqdm 12 | from dataclasses import asdict 13 | 14 | from datas.dataset import StableDataset, collate_fn 15 | from datas.sampler import DistributedBucketSampler 16 | from text import symbols 17 | from config import MelConfig, ModelConfig, TrainConfig 18 | from models.model import StableTTS 19 | 20 | from utils.scheduler import get_cosine_schedule_with_warmup 21 | from utils.load import continue_training 22 | 23 | torch.backends.cudnn.benchmark = True 24 | 25 | def setup(rank, world_size): 26 | os.environ['MASTER_ADDR'] = 'localhost' 27 | os.environ['MASTER_PORT'] = '12345' 28 | dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size) 29 | 30 | def cleanup(): 31 | dist.destroy_process_group() 32 | 33 | def _init_config(model_config: ModelConfig, mel_config: MelConfig, train_config: TrainConfig): 34 | 35 | if not os.path.exists(train_config.model_save_path): 36 | print(f'Creating {train_config.model_save_path}') 37 | os.makedirs(train_config.model_save_path, exist_ok=True) 38 | 39 | def train(rank, world_size): 40 | setup(rank, world_size) 41 | torch.cuda.set_device(rank) 42 | 43 | model_config = ModelConfig() 44 | mel_config = MelConfig() 45 | train_config = TrainConfig() 46 | 47 | _init_config(model_config, mel_config, train_config) 48 | 49 | model = StableTTS(len(symbols), mel_config.n_mels, **asdict(model_config)).to(rank) 50 | 51 | model = DDP(model, device_ids=[rank]) 52 | 53 | train_dataset = StableDataset(train_config.train_dataset_path, mel_config.hop_length) 54 | train_sampler = DistributedBucketSampler(train_dataset, train_config.batch_size, [32,300,400,500,600,700,800,900,1000], num_replicas=world_size, rank=rank) 55 | train_dataloader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=4, pin_memory=True, collate_fn=collate_fn, persistent_workers=True) 56 | 57 | if rank == 0: 58 | writer = SummaryWriter(train_config.log_dir) 59 | 60 | optimizer = optim.AdamW(model.parameters(), lr=train_config.learning_rate) 61 | scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader)) 62 | 63 | # load latest checkpoints if possible 64 | current_epoch = continue_training(train_config.model_save_path, model, optimizer) 65 | 66 | 
model.train() 67 | for epoch in range(current_epoch, train_config.num_epochs): # loop over the train_dataset multiple times 68 | train_dataloader.batch_sampler.set_epoch(epoch) 69 | if rank == 0: 70 | dataloader = tqdm(train_dataloader) 71 | else: 72 | dataloader = train_dataloader 73 | 74 | for batch_idx, datas in enumerate(dataloader): 75 | datas = [data.to(rank, non_blocking=True) for data in datas] 76 | x, x_lengths, y, y_lengths, z, z_lengths = datas 77 | optimizer.zero_grad() 78 | dur_loss, diff_loss, prior_loss, _ = model(x, x_lengths, y, y_lengths, z, z_lengths) 79 | loss = dur_loss + diff_loss + prior_loss 80 | loss.backward() 81 | optimizer.step() 82 | scheduler.step() 83 | 84 | if rank == 0 and batch_idx % train_config.log_interval == 0: 85 | steps = epoch * len(dataloader) + batch_idx 86 | writer.add_scalar("training/diff_loss", diff_loss.item(), steps) 87 | writer.add_scalar("training/dur_loss", dur_loss.item(), steps) 88 | writer.add_scalar("training/prior_loss", prior_loss.item(), steps) 89 | writer.add_scalar("learning_rate/learning_rate", scheduler.get_last_lr()[0], steps) 90 | 91 | if rank == 0 and epoch % train_config.save_interval == 0: 92 | torch.save(model.module.state_dict(), os.path.join(train_config.model_save_path, f'checkpoint_{epoch}.pt')) 93 | torch.save(optimizer.state_dict(), os.path.join(train_config.model_save_path, f'optimizer_{epoch}.pt')) 94 | print(f"Rank {rank}, Epoch {epoch}, Loss {loss.item()}") 95 | 96 | cleanup() 97 | 98 | torch.set_num_threads(1) 99 | torch.set_num_interop_threads(1) 100 | 101 | if __name__ == "__main__": 102 | world_size = torch.cuda.device_count() 103 | torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size) -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/utils/__init__.py -------------------------------------------------------------------------------- /utils/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | import torchaudio 5 | 6 | class LinearSpectrogram(nn.Module): 7 | def __init__(self, n_fft, win_length, hop_length, pad, center, pad_mode): 8 | super().__init__() 9 | 10 | self.n_fft = n_fft 11 | self.win_length = win_length 12 | self.hop_length = hop_length 13 | self.pad = pad 14 | self.center = center 15 | self.pad_mode = pad_mode 16 | 17 | self.register_buffer("window", torch.hann_window(win_length)) 18 | 19 | def forward(self, waveform: Tensor) -> Tensor: 20 | if waveform.ndim == 3: 21 | waveform = waveform.squeeze(1) 22 | waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (self.pad, self.pad), self.pad_mode).squeeze(1) 23 | spec = torch.stft(waveform, self.n_fft, self.hop_length, self.win_length, self.window, self.center, self.pad_mode, False, True, True) 24 | spec = torch.view_as_real(spec) 25 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 26 | return spec 27 | 28 | 29 | class LogMelSpectrogram(nn.Module): 30 | def __init__(self, sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, center, pad_mode, mel_scale): 31 | super().__init__() 32 | self.sample_rate = sample_rate 33 | self.n_fft = n_fft 34 | self.win_length = win_length 35 | self.hop_length = hop_length 36 | self.f_min = f_min 37 | self.f_max = f_max 38 | self.pad = pad 39 | 
self.n_mels = n_mels 40 | self.center = center 41 | self.pad_mode = pad_mode 42 | self.mel_scale = mel_scale 43 | 44 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, pad, center, pad_mode) 45 | self.mel_scale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min, f_max, (n_fft//2)+1, mel_scale, mel_scale) 46 | 47 | def compress(self, x: Tensor) -> Tensor: 48 | return torch.log(torch.clamp(x, min=1e-5)) 49 | 50 | def decompress(self, x: Tensor) -> Tensor: 51 | return torch.exp(x) 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | linear_spec = self.spectrogram(x) 55 | x = self.mel_scale(linear_spec) 56 | x = self.compress(x) 57 | return x 58 | 59 | def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Tensor: 60 | try: 61 | y, sr = torchaudio.load(audio_path) 62 | except Exception as e: 63 | print(str(e)) 64 | return None 65 | 66 | y = y.to(device) # move to the target device (Tensor.to is not in-place) 67 | # Convert to mono 68 | if y.size(0) > 1: 69 | y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time] 70 | 71 | # resample audio to target sample_rate 72 | if sr != target_sr: 73 | y = torchaudio.functional.resample(y, sr, target_sr) 74 | return y -------------------------------------------------------------------------------- /utils/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | 7 | def continue_training(checkpoint_path, model: DDP, optimizer: optim.Optimizer) -> int: 8 | """Load the latest checkpoint and optimizer state and return the next epoch to train.""" 9 | model_dict = {} 10 | optimizer_dict = {} 11 | 12 | # glob all the checkpoints in the directory 13 | for file in os.listdir(checkpoint_path): 14 | if file.endswith(".pt") and '_' in file: 15 | name, epoch_str = file.rsplit('_', 1) 16 | epoch = int(epoch_str.split('.')[0]) 17 | 18 | if name.startswith("checkpoint"): 19 | model_dict[epoch] = file 20 | elif name.startswith("optimizer"): 21 | optimizer_dict[epoch] = file 22 | 23 | # get the largest epoch 24 | common_epochs = set(model_dict.keys()) & set(optimizer_dict.keys()) 25 | if common_epochs: 26 | max_epoch = max(common_epochs) 27 | model_path = os.path.join(checkpoint_path, model_dict[max_epoch]) 28 | optimizer_path = os.path.join(checkpoint_path, optimizer_dict[max_epoch]) 29 | 30 | # load model and optimizer 31 | model.module.load_state_dict(torch.load(model_path, map_location='cpu')) 32 | optimizer.load_state_dict(torch.load(optimizer_path, map_location='cpu')) 33 | 34 | print(f'resuming model and optimizer from epoch {max_epoch}') 35 | return max_epoch + 1 36 | 37 | else: 38 | # load pretrained checkpoint 39 | if model_dict: 40 | model_path = os.path.join(checkpoint_path, model_dict[max(model_dict.keys())]) 41 | model.module.load_state_dict(torch.load(model_path, map_location='cpu')) 42 | 43 | return 0 -------------------------------------------------------------------------------- /utils/mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # copied from https://github.com/jaywalnut310/vits/blob/main/commons.py#L121 4 | def sequence_mask(length: torch.Tensor, max_length: int = None) -> torch.Tensor: 5 | if max_length is None: 6 | max_length = length.max() 7 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 8 | return x.unsqueeze(0) < length.unsqueeze(1) --------------------------------------------------------------------------------
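A minimal usage sketch of the utilities above — illustrative only, not a file in the repository — assuming the top-level `MelConfig` dataclass mirrors the fields shown in `vocoders/vocos/config.py`:

import torch
from dataclasses import asdict

from config import MelConfig              # assumption: same fields as vocoders/vocos/config.py
from utils.audio import LogMelSpectrogram
from utils.mask import sequence_mask

# Boolean padding mask: one row per sequence, True on valid positions.
lengths = torch.tensor([3, 5])
mask = sequence_mask(lengths)             # shape [2, 5], dtype bool
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

# Log-mel features from raw audio, as consumed by the model and vocoder code.
mel_extractor = LogMelSpectrogram(**asdict(MelConfig()))
wav = torch.randn(1, 44100)               # ~1 second of audio at the default 44.1 kHz
mel = mel_extractor(wav)                  # shape [1, n_mels, frames], here [1, 128, 86]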
/vocoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/__init__.py -------------------------------------------------------------------------------- /vocoders/ffgan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/ffgan/__init__.py -------------------------------------------------------------------------------- /vocoders/ffgan/backbone.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | # DropPath copied from timm library 7 | def drop_path( 8 | x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True 9 | ): 10 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 11 | 12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 16 | 'survival rate' as the argument. 17 | 18 | """ # noqa: E501 19 | 20 | if drop_prob == 0.0 or not training: 21 | return x 22 | keep_prob = 1 - drop_prob 23 | shape = (x.shape[0],) + (1,) * ( 24 | x.ndim - 1 25 | ) # work with diff dim tensors, not just 2D ConvNets 26 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 27 | if keep_prob > 0.0 and scale_by_keep: 28 | random_tensor.div_(keep_prob) 29 | return x * random_tensor 30 | 31 | 32 | class DropPath(nn.Module): 33 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501 34 | 35 | def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): 36 | super(DropPath, self).__init__() 37 | self.drop_prob = drop_prob 38 | self.scale_by_keep = scale_by_keep 39 | 40 | def forward(self, x): 41 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 42 | 43 | def extra_repr(self): 44 | return f"drop_prob={round(self.drop_prob,3):0.3f}" 45 | 46 | 47 | class LayerNorm(nn.Module): 48 | r"""LayerNorm that supports two data formats: channels_last (default) or channels_first. 49 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with 50 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs 51 | with shape (batch_size, channels, height, width). 
52 | """ # noqa: E501 53 | 54 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): 55 | super().__init__() 56 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 57 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 58 | self.eps = eps 59 | self.data_format = data_format 60 | if self.data_format not in ["channels_last", "channels_first"]: 61 | raise NotImplementedError 62 | self.normalized_shape = (normalized_shape,) 63 | 64 | def forward(self, x): 65 | if self.data_format == "channels_last": 66 | return F.layer_norm( 67 | x, self.normalized_shape, self.weight, self.bias, self.eps 68 | ) 69 | elif self.data_format == "channels_first": 70 | u = x.mean(1, keepdim=True) 71 | s = (x - u).pow(2).mean(1, keepdim=True) 72 | x = (x - u) / torch.sqrt(s + self.eps) 73 | x = self.weight[:, None] * x + self.bias[:, None] 74 | return x 75 | 76 | 77 | # ConvNeXt Block copied from https://github.com/fishaudio/fish-diffusion/blob/main/fish_diffusion/modules/convnext.py 78 | class ConvNeXtBlock(nn.Module): 79 | r"""ConvNeXt Block. There are two equivalent implementations: 80 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 81 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 82 | We use (2) as we find it slightly faster in PyTorch 83 | 84 | Args: 85 | dim (int): Number of input channels. 86 | drop_path (float): Stochastic depth rate. Default: 0.0 87 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 88 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. 89 | kernel_size (int): Kernel size for depthwise conv. Default: 7. 90 | dilation (int): Dilation for depthwise conv. Default: 1. 
91 | """ # noqa: E501 92 | 93 | def __init__( 94 | self, 95 | dim: int, 96 | drop_path: float = 0.0, 97 | layer_scale_init_value: float = 1e-6, 98 | mlp_ratio: float = 4.0, 99 | kernel_size: int = 7, 100 | dilation: int = 1, 101 | ): 102 | super().__init__() 103 | 104 | self.dwconv = nn.Conv1d( 105 | dim, 106 | dim, 107 | kernel_size=kernel_size, 108 | padding=int(dilation * (kernel_size - 1) / 2), 109 | groups=dim, 110 | ) # depthwise conv 111 | self.norm = LayerNorm(dim, eps=1e-6) 112 | self.pwconv1 = nn.Linear( 113 | dim, int(mlp_ratio * dim) 114 | ) # pointwise/1x1 convs, implemented with linear layers 115 | self.act = nn.GELU() 116 | self.pwconv2 = nn.Linear(int(mlp_ratio * dim), dim) 117 | self.gamma = ( 118 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 119 | if layer_scale_init_value > 0 120 | else None 121 | ) 122 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 123 | 124 | def forward(self, x, apply_residual: bool = True): 125 | input = x 126 | 127 | x = self.dwconv(x) 128 | x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C) 129 | x = self.norm(x) 130 | x = self.pwconv1(x) 131 | x = self.act(x) 132 | x = self.pwconv2(x) 133 | 134 | if self.gamma is not None: 135 | x = self.gamma * x 136 | 137 | x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L) 138 | x = self.drop_path(x) 139 | 140 | if apply_residual: 141 | x = input + x 142 | 143 | return x 144 | 145 | 146 | class ConvNeXtEncoder(nn.Module): 147 | def __init__( 148 | self, 149 | input_channels: int = 3, 150 | depths: list[int] = [3, 3, 9, 3], 151 | dims: list[int] = [96, 192, 384, 768], 152 | drop_path_rate: float = 0.0, 153 | layer_scale_init_value: float = 1e-6, 154 | kernel_size: int = 7, 155 | ): 156 | super().__init__() 157 | assert len(depths) == len(dims) 158 | 159 | self.downsample_layers = nn.ModuleList() 160 | stem = nn.Sequential( 161 | nn.Conv1d( 162 | input_channels, 163 | dims[0], 164 | kernel_size=kernel_size, 165 | padding=kernel_size // 2, 166 | padding_mode="zeros", 167 | ), 168 | LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), 169 | ) 170 | self.downsample_layers.append(stem) 171 | 172 | for i in range(len(depths) - 1): 173 | mid_layer = nn.Sequential( 174 | LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), 175 | nn.Conv1d(dims[i], dims[i + 1], kernel_size=1), 176 | ) 177 | self.downsample_layers.append(mid_layer) 178 | 179 | self.stages = nn.ModuleList() 180 | dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 181 | 182 | cur = 0 183 | for i in range(len(depths)): 184 | stage = nn.Sequential( 185 | *[ 186 | ConvNeXtBlock( 187 | dim=dims[i], 188 | drop_path=dp_rates[cur + j], 189 | layer_scale_init_value=layer_scale_init_value, 190 | kernel_size=kernel_size, 191 | ) 192 | for j in range(depths[i]) 193 | ] 194 | ) 195 | self.stages.append(stage) 196 | cur += depths[i] 197 | 198 | self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first") 199 | self.apply(self._init_weights) 200 | 201 | def _init_weights(self, m): 202 | if isinstance(m, (nn.Conv1d, nn.Linear)): 203 | nn.init.trunc_normal_(m.weight, std=0.02) 204 | nn.init.constant_(m.bias, 0) 205 | 206 | def forward( 207 | self, 208 | x: torch.Tensor, 209 | ) -> torch.Tensor: 210 | for i in range(len(self.downsample_layers)): 211 | x = self.downsample_layers[i](x) 212 | x = self.stages[i](x) 213 | 214 | return self.norm(x) 215 | -------------------------------------------------------------------------------- /vocoders/ffgan/head.py: 
-------------------------------------------------------------------------------- 1 | from functools import partial 2 | from math import prod 3 | from typing import Callable 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import Conv1d 10 | from torch.nn.utils.parametrizations import weight_norm 11 | from torch.nn.utils.parametrize import remove_parametrizations 12 | from torch.utils.checkpoint import checkpoint 13 | 14 | 15 | def init_weights(m, mean=0.0, std=0.01): 16 | classname = m.__class__.__name__ 17 | if classname.find("Conv") != -1: 18 | m.weight.data.normal_(mean, std) 19 | 20 | 21 | def get_padding(kernel_size, dilation=1): 22 | return (kernel_size * dilation - dilation) // 2 23 | 24 | 25 | class ResBlock1(torch.nn.Module): 26 | def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): 27 | super().__init__() 28 | 29 | self.convs1 = nn.ModuleList( 30 | [ 31 | weight_norm( 32 | Conv1d( 33 | channels, 34 | channels, 35 | kernel_size, 36 | 1, 37 | dilation=dilation[0], 38 | padding=get_padding(kernel_size, dilation[0]), 39 | ) 40 | ), 41 | weight_norm( 42 | Conv1d( 43 | channels, 44 | channels, 45 | kernel_size, 46 | 1, 47 | dilation=dilation[1], 48 | padding=get_padding(kernel_size, dilation[1]), 49 | ) 50 | ), 51 | weight_norm( 52 | Conv1d( 53 | channels, 54 | channels, 55 | kernel_size, 56 | 1, 57 | dilation=dilation[2], 58 | padding=get_padding(kernel_size, dilation[2]), 59 | ) 60 | ), 61 | ] 62 | ) 63 | self.convs1.apply(init_weights) 64 | 65 | self.convs2 = nn.ModuleList( 66 | [ 67 | weight_norm( 68 | Conv1d( 69 | channels, 70 | channels, 71 | kernel_size, 72 | 1, 73 | dilation=1, 74 | padding=get_padding(kernel_size, 1), 75 | ) 76 | ), 77 | weight_norm( 78 | Conv1d( 79 | channels, 80 | channels, 81 | kernel_size, 82 | 1, 83 | dilation=1, 84 | padding=get_padding(kernel_size, 1), 85 | ) 86 | ), 87 | weight_norm( 88 | Conv1d( 89 | channels, 90 | channels, 91 | kernel_size, 92 | 1, 93 | dilation=1, 94 | padding=get_padding(kernel_size, 1), 95 | ) 96 | ), 97 | ] 98 | ) 99 | self.convs2.apply(init_weights) 100 | 101 | def forward(self, x): 102 | for c1, c2 in zip(self.convs1, self.convs2): 103 | xt = F.silu(x) 104 | xt = c1(xt) 105 | xt = F.silu(xt) 106 | xt = c2(xt) 107 | x = xt + x 108 | return x 109 | 110 | def remove_parametrizations(self): 111 | for conv in self.convs1: 112 | remove_parametrizations(conv) 113 | for conv in self.convs2: 114 | remove_parametrizations(conv) 115 | 116 | 117 | class ParralelBlock(nn.Module): 118 | def __init__( 119 | self, 120 | channels: int, 121 | kernel_sizes: tuple[int] = (3, 7, 11), 122 | dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), 123 | ): 124 | super().__init__() 125 | 126 | assert len(kernel_sizes) == len(dilation_sizes) 127 | 128 | self.blocks = nn.ModuleList() 129 | for k, d in zip(kernel_sizes, dilation_sizes): 130 | self.blocks.append(ResBlock1(channels, k, d)) 131 | 132 | def forward(self, x): 133 | return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0) 134 | 135 | 136 | class HiFiGANGenerator(nn.Module): 137 | def __init__( 138 | self, 139 | *, 140 | hop_length: int = 512, 141 | upsample_rates: tuple[int] = (8, 8, 2, 2, 2), 142 | upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2), 143 | resblock_kernel_sizes: tuple[int] = (3, 7, 11), 144 | resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)), 145 | num_mels: int = 128, 146 | upsample_initial_channel: int = 512, 147 | 
use_template: bool = True, 148 | pre_conv_kernel_size: int = 7, 149 | post_conv_kernel_size: int = 7, 150 | post_activation: Callable = partial(nn.SiLU, inplace=True), 151 | ): 152 | super().__init__() 153 | 154 | assert ( 155 | prod(upsample_rates) == hop_length 156 | ), f"hop_length must be {prod(upsample_rates)}" 157 | 158 | self.conv_pre = weight_norm( 159 | nn.Conv1d( 160 | num_mels, 161 | upsample_initial_channel, 162 | pre_conv_kernel_size, 163 | 1, 164 | padding=get_padding(pre_conv_kernel_size), 165 | ) 166 | ) 167 | 168 | self.num_upsamples = len(upsample_rates) 169 | self.num_kernels = len(resblock_kernel_sizes) 170 | 171 | self.noise_convs = nn.ModuleList() 172 | self.use_template = use_template 173 | self.ups = nn.ModuleList() 174 | 175 | for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): 176 | c_cur = upsample_initial_channel // (2 ** (i + 1)) 177 | self.ups.append( 178 | weight_norm( 179 | nn.ConvTranspose1d( 180 | upsample_initial_channel // (2**i), 181 | upsample_initial_channel // (2 ** (i + 1)), 182 | k, 183 | u, 184 | padding=(k - u) // 2, 185 | ) 186 | ) 187 | ) 188 | 189 | if not use_template: 190 | continue 191 | 192 | if i + 1 < len(upsample_rates): 193 | stride_f0 = np.prod(upsample_rates[i + 1 :]) 194 | self.noise_convs.append( 195 | Conv1d( 196 | 1, 197 | c_cur, 198 | kernel_size=stride_f0 * 2, 199 | stride=stride_f0, 200 | padding=stride_f0 // 2, 201 | ) 202 | ) 203 | else: 204 | self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) 205 | 206 | self.resblocks = nn.ModuleList() 207 | for i in range(len(self.ups)): 208 | ch = upsample_initial_channel // (2 ** (i + 1)) 209 | self.resblocks.append( 210 | ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes) 211 | ) 212 | 213 | self.activation_post = post_activation() 214 | self.conv_post = weight_norm( 215 | nn.Conv1d( 216 | ch, 217 | 1, 218 | post_conv_kernel_size, 219 | 1, 220 | padding=get_padding(post_conv_kernel_size), 221 | ) 222 | ) 223 | self.ups.apply(init_weights) 224 | self.conv_post.apply(init_weights) 225 | 226 | def forward(self, x, template=None): 227 | x = self.conv_pre(x) 228 | 229 | for i in range(self.num_upsamples): 230 | x = F.silu(x, inplace=True) 231 | x = self.ups[i](x) 232 | 233 | if self.use_template: 234 | x = x + self.noise_convs[i](template) 235 | 236 | if self.training and self.checkpointing: 237 | x = checkpoint( 238 | self.resblocks[i], 239 | x, 240 | use_reentrant=False, 241 | ) 242 | else: 243 | x = self.resblocks[i](x) 244 | 245 | x = self.activation_post(x) 246 | x = self.conv_post(x) 247 | x = torch.tanh(x) 248 | 249 | return x 250 | 251 | def remove_parametrizations(self): 252 | for up in self.ups: 253 | remove_parametrizations(up) 254 | for block in self.resblocks: 255 | block.remove_parametrizations() 256 | remove_parametrizations(self.conv_pre) 257 | remove_parametrizations(self.conv_post) 258 | -------------------------------------------------------------------------------- /vocoders/ffgan/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .backbone import ConvNeXtEncoder 5 | from .head import HiFiGANGenerator 6 | 7 | config_dict = { 8 | "backbone": { 9 | # "input_channels": "${model.num_mels}", 10 | "input_channels": 128, 11 | "depths": [3, 3, 9, 3], 12 | "dims": [128, 256, 384, 512], 13 | "drop_path_rate": 0.2, 14 | "kernel_size": 7, 15 | }, 16 | "head": { 17 | # "hop_length": "${model.hop_length}", 18 | "hop_length": 512, 19 | 
"upsample_rates": [8, 8, 2, 2, 2], 20 | "upsample_kernel_sizes": [16, 16, 4, 4, 4], 21 | "resblock_kernel_sizes": [3, 7, 11], 22 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 23 | "num_mels": 512, # consistent with the output of the backbone 24 | "upsample_initial_channel": 512, 25 | "use_template": False, 26 | "pre_conv_kernel_size": 13, 27 | "post_conv_kernel_size": 13, 28 | } 29 | } 30 | 31 | # download_link: https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt 32 | class FireflyGANBaseWrapper(nn.Module): 33 | def __init__(self, model_path): 34 | super().__init__() 35 | self.model = FireflyGANBase() 36 | self.model.load_state_dict(torch.load(model_path, weights_only=True, map_location='cpu')) 37 | 38 | self.model.eval() 39 | 40 | @ torch.inference_mode() 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | return self.model(x) 43 | 44 | class FireflyGANBase(nn.Module): 45 | def __init__(self): 46 | super().__init__() 47 | self.backbone = ConvNeXtEncoder(**config_dict["backbone"]) 48 | self.head = HiFiGANGenerator(**config_dict["head"]) 49 | 50 | self.head.checkpointing = False 51 | 52 | @ torch.inference_mode() 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | x = self.backbone(x) 55 | x = self.head(x) 56 | 57 | return x.squeeze(1) -------------------------------------------------------------------------------- /vocoders/ffgan/unify.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class UnifyGenerator(nn.Module): 6 | def __init__( 7 | self, 8 | backbone: nn.Module, 9 | head: nn.Module, 10 | vq: nn.Module | None = None, 11 | ): 12 | super().__init__() 13 | 14 | self.backbone = backbone 15 | self.head = head 16 | self.vq = vq 17 | 18 | def forward(self, x: torch.Tensor, template=None) -> torch.Tensor: 19 | x = self.backbone(x) 20 | 21 | if self.vq is not None: 22 | vq_result = self.vq(x) 23 | x = vq_result.z 24 | 25 | x = self.head(x, template=template) 26 | 27 | if x.ndim == 2: 28 | x = x[:, None, :] 29 | 30 | if self.vq is not None: 31 | return x, vq_result 32 | 33 | return x 34 | 35 | def encode(self, x: torch.Tensor) -> torch.Tensor: 36 | if self.vq is None: 37 | raise ValueError("VQ module is not present in the model.") 38 | 39 | x = self.backbone(x) 40 | vq_result = self.vq(x) 41 | return vq_result.codes 42 | 43 | def decode(self, codes: torch.Tensor, template=None) -> torch.Tensor: 44 | if self.vq is None: 45 | raise ValueError("VQ module is not present in the model.") 46 | 47 | x = self.vq.from_codes(codes)[0] 48 | x = self.head(x, template=template) 49 | 50 | if x.ndim == 2: 51 | x = x[:, None, :] 52 | 53 | return x 54 | 55 | def remove_parametrizations(self): 56 | if hasattr(self.backbone, "remove_parametrizations"): 57 | self.backbone.remove_parametrizations() 58 | 59 | if hasattr(self.head, "remove_parametrizations"): 60 | self.head.remove_parametrizations() -------------------------------------------------------------------------------- /vocoders/pretrained/.keep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/pretrained/.keep -------------------------------------------------------------------------------- /vocoders/vocos/README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Vocos for StableTTS 4 | 5 | Modified from the official implementation of [Vocos](https://github.com/gemelo-ai/vocos/tree/main). 6 | 7 |
8 | 9 | ## Introduction 10 | 11 | Vocos is a fast neural vocoder designed to synthesize audio waveforms from acoustic features. Trained using a Generative Adversarial Network (GAN) objective, Vocos can generate waveforms in a single forward pass. Unlike other typical GAN-based vocoders, Vocos does not model audio samples in the time domain. Instead, it generates spectral coefficients, facilitating rapid audio reconstruction through inverse Fourier transform. 12 | 13 | 14 | ## Inference 15 | 16 | For detailed inference instructions, please refer to `inference.ipynb`. 17 | 18 | ## Training 19 | 20 | Setting up and training your model with Vocos is straightforward. Follow these steps to get started: 21 | 22 | ### Preparing Your Data 23 | 24 | 1. **Configure Data Settings**: Update the `DataConfig` in `preprocess.py`. Specifically, adjust the `audio_dir` to point to your collection of audio files. 25 | 26 | 2. **Run Preprocessing**: Run `preprocess.py`. This script will search (glob) for all audio files in the specified directory, resample them to the target `sample_rate` (configurable in `config.py`), and generate a file list for training. 27 | 28 | ### Start Training 29 | 30 | 1. **Adjust Training Configuration**: Edit `TrainConfig` in `config.py` to specify the file list path and tweak training hyperparameters to your needs. 31 | 32 | 2. **Start the Training Process**: Launch `train.py` to begin training your model. 33 | 34 | ### Experiment with Configurations 35 | 36 | Feel free to explore the settings in `config.py` to adjust the hyperparameters of Vocos! 37 | 38 | 39 | ## References 40 | 41 | [Vocos](https://github.com/gemelo-ai/vocos/tree/main) -------------------------------------------------------------------------------- /vocoders/vocos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/vocos/__init__.py -------------------------------------------------------------------------------- /vocoders/vocos/config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | @dataclass 4 | class MelConfig: 5 | sample_rate: int = 44100 6 | n_fft: int = 2048 7 | win_length: int = 2048 8 | hop_length: int = 512 9 | f_min: float = 0.0 10 | f_max: float = None 11 | pad: int = 0 12 | n_mels: int = 128 13 | center: bool = False 14 | pad_mode: str = "reflect" 15 | mel_scale: str = "slaney" 16 | 17 | def __post_init__(self): 18 | if self.pad == 0: 19 | self.pad = (self.n_fft - self.hop_length) // 2 20 | 21 | @dataclass 22 | class VocosConfig: 23 | input_channels: int = 128 24 | dim: int = 768 25 | intermediate_dim: int = 2048 26 | num_layers: int = 12 27 | 28 | @dataclass 29 | class TrainConfig: 30 | train_dataset_path: str = './filelists/filelist.txt' 31 | test_dataset_path: str = './filelists/filelist.txt' 32 | batch_size: int = 32 33 | learning_rate: float = 1e-4 34 | num_epochs: int = 10000 35 | model_save_path: str = './checkpoints' 36 | log_dir: str = './runs' 37 | log_interval: int = 64 38 | warmup_steps: int = 200 39 | 40 | segment_size = 20480 41 | mel_loss_factor = 15 -------------------------------------------------------------------------------- /vocoders/vocos/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict 3 | import torch 4 | import torchaudio 5 | from torch.utils.data import Dataset 6 |
7 | from utils.audio import LogMelSpectrogram 8 | from config import MelConfig 9 | 10 | class VocosDataset(Dataset): 11 | def __init__(self, filelist_path, segment_size: int, mel_config: MelConfig): 12 | self.filelist_path = filelist_path 13 | self.segment_size = segment_size 14 | self.sample_rate = mel_config.sample_rate 15 | self.mel_extractor = LogMelSpectrogram(**asdict(mel_config)) 16 | 17 | self.filelist = self._load_filelist(filelist_path) 18 | 19 | def _load_filelist(self, filelist_path): 20 | if os.path.isdir(filelist_path): 21 | print('scanning dir to get audio files') 22 | filelist = find_audio_files(filelist_path) 23 | else: 24 | with open(filelist_path, 'r', encoding='utf-8') as f: 25 | filelist = [line.strip() for line in f if os.path.exists(line.strip())] 26 | return filelist 27 | 28 | def __len__(self): 29 | return len(self.filelist) 30 | 31 | def __getitem__(self, idx): 32 | audio = load_and_pad_audio(self.filelist[idx], self.sample_rate, self.segment_size) 33 | start_index = torch.randint(0, audio.size(-1) - self.segment_size + 1, (1,)).item() 34 | audio = audio[:, start_index:start_index + self.segment_size] # shape: [1, segment_size] 35 | mel = self.mel_extractor(audio).squeeze(0) # shape: [n_mels, segment_size // hop_length] 36 | return audio, mel 37 | 38 | def load_and_pad_audio(audio_path, target_sr, segment_size): 39 | y, sr = torchaudio.load(audio_path) 40 | if y.size(0) > 1: 41 | y = y[0, :].unsqueeze(0) 42 | if sr != target_sr: 43 | y = torchaudio.functional.resample(y, sr, target_sr) 44 | if y.size(-1) < segment_size: 45 | y = torch.nn.functional.pad(y, (0, segment_size - y.size(-1)), "constant", 0) 46 | return y 47 | 48 | def find_audio_files(directory): 49 | audio_files = [] 50 | valid_extensions = ('.wav', '.ogg', '.opus', '.mp3', '.flac') 51 | 52 | for root, dirs, files in os.walk(directory): 53 | for file in files: 54 | if file.endswith(valid_extensions): 55 | audio_files.append(os.path.join(root, file)) 56 | 57 | return audio_files -------------------------------------------------------------------------------- /vocoders/vocos/inference.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torchaudio\n", 11 | "from IPython.display import Audio, display\n", 12 | "\n", 13 | "from models.model import Vocos\n", 14 | "from utils.audio import LogMelSpectrogram\n", 15 | "from config import MelConfig, VocosConfig\n", 16 | "\n", 17 | "from pathlib import Path\n", 18 | "from dataclasses import asdict\n", 19 | "import random\n", 20 | "\n", 21 | "def load_and_resample_audio(audio_path, target_sr):\n", 22 | " y, sr = torchaudio.load(audio_path)\n", 23 | " if y.size(0) > 1:\n", 24 | " y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time]\n", 25 | " if sr != target_sr:\n", 26 | " y = torchaudio.functional.resample(y, sr, target_sr)\n", 27 | " return y\n", 28 | "\n", 29 | "device = 'cpu'\n", 30 | "\n", 31 | "mel_config = MelConfig()\n", 32 | "vocos_config = VocosConfig()\n", 33 | "\n", 34 | "mel_extractor = LogMelSpectrogram(**asdict(mel_config))\n", 35 | "model = Vocos(vocos_config, mel_config).to(device)\n", 36 | "model.load_state_dict(torch.load('./checkpoints/generator_0.pt', map_location='cpu'))\n", 37 | "model.eval()\n", 38 | "\n", 39 | "audio_paths = list(Path('./audios').rglob('*.wav'))" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | 
"execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "audio_path = random.choice(audio_paths)\n", 49 | "with torch.inference_mode():\n", 50 | " audio = load_and_resample_audio(audio_path, mel_config.sample_rate).to(device)\n", 51 | " mel = mel_extractor(audio)\n", 52 | " recon_audio = model(mel)\n", 53 | "display(Audio(audio, rate=mel_config.sample_rate))\n", 54 | "display(Audio(recon_audio, rate=mel_config.sample_rate))" 55 | ] 56 | } 57 | ], 58 | "metadata": { 59 | "kernelspec": { 60 | "display_name": "lxn_vits", 61 | "language": "python", 62 | "name": "python3" 63 | }, 64 | "language_info": { 65 | "codemirror_mode": { 66 | "name": "ipython", 67 | "version": 3 68 | }, 69 | "file_extension": ".py", 70 | "mimetype": "text/x-python", 71 | "name": "python", 72 | "nbconvert_exporter": "python", 73 | "pygments_lexer": "ipython3", 74 | "version": "3.12.4" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 2 79 | } 80 | -------------------------------------------------------------------------------- /vocoders/vocos/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/vocos/models/__init__.py -------------------------------------------------------------------------------- /vocoders/vocos/models/backbone.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from .module import ConvNeXtBlock 7 | 8 | class VocosBackbone(nn.Module): 9 | """ 10 | Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization 11 | 12 | Args: 13 | input_channels (int): Number of input features channels. 14 | dim (int): Hidden dimension of the model. 15 | intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. 16 | num_layers (int): Number of ConvNeXtBlock layers. 17 | layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. 
18 | """ 19 | 20 | def __init__( 21 | self, 22 | input_channels: int, 23 | dim: int, 24 | intermediate_dim: int, 25 | num_layers: int, 26 | layer_scale_init_value: Optional[float] = None, 27 | ): 28 | super().__init__() 29 | self.input_channels = input_channels 30 | self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) 31 | self.norm = nn.LayerNorm(dim, eps=1e-6) 32 | layer_scale_init_value = layer_scale_init_value or 1 / num_layers 33 | self.convnext = nn.ModuleList( 34 | [ 35 | ConvNeXtBlock( 36 | dim=dim, 37 | intermediate_dim=intermediate_dim, 38 | layer_scale_init_value=layer_scale_init_value, 39 | ) 40 | for _ in range(num_layers) 41 | ] 42 | ) 43 | self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) 44 | self.apply(self._init_weights) 45 | 46 | def _init_weights(self, m): 47 | if isinstance(m, (nn.Conv1d, nn.Linear)): 48 | nn.init.trunc_normal_(m.weight, std=0.02) 49 | nn.init.constant_(m.bias, 0) 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | x = self.embed(x) 53 | x = self.norm(x.transpose(1, 2)).transpose(1, 2) 54 | for conv_block in self.convnext: 55 | x = conv_block(x) 56 | x = self.final_layer_norm(x.transpose(1, 2)) 57 | return x -------------------------------------------------------------------------------- /vocoders/vocos/models/discriminator.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | from torch import Tensor 6 | from torch.nn import Conv2d 7 | from torch.nn.utils.parametrizations import weight_norm 8 | from torchaudio.transforms import Spectrogram 9 | 10 | 11 | class MultiPeriodDiscriminator(nn.Module): 12 | def __init__(self, periods: Tuple[int, ...] = (2, 3, 5, 7, 11)): 13 | super().__init__() 14 | self.discriminators = nn.ModuleList([DiscriminatorP(period=p) for p in periods]) 15 | 16 | def forward(self, y: Tensor, y_hat: Tensor): 17 | y_d_rs = [] 18 | y_d_gs = [] 19 | fmap_rs = [] 20 | fmap_gs = [] 21 | for d in self.discriminators: 22 | y_d_r, fmap_r = d(y) 23 | y_d_g, fmap_g = d(y_hat) 24 | y_d_rs.append(y_d_r) 25 | fmap_rs.append(fmap_r) 26 | y_d_gs.append(y_d_g) 27 | fmap_gs.append(fmap_g) 28 | 29 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 30 | 31 | 32 | class DiscriminatorP(nn.Module): 33 | def __init__( 34 | self, 35 | period: int, 36 | in_channels: int = 1, 37 | kernel_size: int = 5, 38 | stride: int = 3, 39 | lrelu_slope: float = 0.1, 40 | ): 41 | super().__init__() 42 | self.period = period 43 | self.convs = nn.ModuleList( 44 | [ 45 | weight_norm(Conv2d(in_channels, 32, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 46 | weight_norm(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 47 | weight_norm(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 48 | weight_norm(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(kernel_size // 2, 0))), 49 | weight_norm(Conv2d(1024, 1024, (kernel_size, 1), (1, 1), padding=(kernel_size // 2, 0))), 50 | ] 51 | ) 52 | 53 | self.conv_post = weight_norm(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) 54 | self.lrelu_slope = lrelu_slope 55 | 56 | def forward(self, x: Tensor) -> Tuple[Tensor, List[Tensor]]: 57 | fmap = [] 58 | # 1d to 2d 59 | b, c, t = x.shape 60 | if t % self.period != 0: # pad first 61 | n_pad = self.period - (t % self.period) 62 | x = torch.nn.functional.pad(x, (0, n_pad), "reflect") 63 | t = t + n_pad 64 | x = x.view(b, c, t // self.period, self.period) 65 | 66 | for i, l 
in enumerate(self.convs): 67 | x = l(x) 68 | x = torch.nn.functional.leaky_relu(x, self.lrelu_slope, inplace=True) 69 | if i > 0: 70 | fmap.append(x) 71 | x = self.conv_post(x) 72 | fmap.append(x) 73 | x = torch.flatten(x, 1, -1) 74 | 75 | return x, fmap 76 | 77 | 78 | class MultiResolutionDiscriminator(nn.Module): 79 | def __init__( 80 | self, 81 | fft_sizes: Tuple[int, ...] = (2048, 1024, 512), 82 | ): 83 | """ 84 | Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. 85 | 86 | Args: 87 | fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). 88 | """ 89 | 90 | super().__init__() 91 | self.discriminators = nn.ModuleList( 92 | [DiscriminatorR(window_length=w) for w in fft_sizes] 93 | ) 94 | 95 | def forward(self, y: Tensor, y_hat: Tensor) -> Tuple[List[Tensor], List[Tensor], List[List[Tensor]], List[List[Tensor]]]: 96 | y_d_rs = [] 97 | y_d_gs = [] 98 | fmap_rs = [] 99 | fmap_gs = [] 100 | 101 | for d in self.discriminators: 102 | y_d_r, fmap_r = d(x=y) 103 | y_d_g, fmap_g = d(x=y_hat) 104 | y_d_rs.append(y_d_r) 105 | fmap_rs.append(fmap_r) 106 | y_d_gs.append(y_d_g) 107 | fmap_gs.append(fmap_g) 108 | 109 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 110 | 111 | 112 | class DiscriminatorR(nn.Module): 113 | def __init__( 114 | self, 115 | window_length: int, 116 | channels: int = 32, 117 | hop_factor: float = 0.25, 118 | bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), 119 | ): 120 | super().__init__() 121 | self.window_length = window_length 122 | self.hop_factor = hop_factor 123 | self.spec_fn = Spectrogram( 124 | n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None 125 | ) 126 | n_fft = window_length // 2 + 1 127 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 128 | self.bands = bands 129 | convs = lambda: nn.ModuleList( 130 | [ 131 | weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), 132 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 133 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 134 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 135 | weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), 136 | ] 137 | ) 138 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 139 | 140 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) 141 | 142 | def spectrogram(self, x): 143 | x = x.squeeze(1) 144 | 145 | # x = x - x.mean(dim=-1, keepdims=True) 146 | # # Peak normalize the volume of input audio 147 | # x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 148 | 149 | x = self.spec_fn(x) 150 | x = torch.view_as_real(x) 151 | x = x.permute(0, 3, 2, 1) # b f t c -> b c t f 152 | # Split into bands 153 | x_bands = [x[..., b[0] : b[1]] for b in self.bands] 154 | return x_bands 155 | 156 | def forward(self, x: Tensor): 157 | x_bands = self.spectrogram(x) 158 | fmap = [] 159 | x = [] 160 | for band, stack in zip(x_bands, self.band_convs): 161 | for i, layer in enumerate(stack): 162 | band = layer(band) 163 | band = torch.nn.functional.leaky_relu(band, 0.1, inplace=True) 164 | if i > 0: 165 | fmap.append(band) 166 | x.append(band) 167 | x = torch.cat(x, dim=-1) 168 | x = self.conv_post(x) 169 | fmap.append(x) 170 | 171 | return x, fmap 
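For orientation, the sketch below shows one conventional way the (logits, feature maps) lists returned by these discriminators are turned into least-squares adversarial and feature-matching losses. It is a hedged illustration only; the losses actually used by this repository live in `models/loss.py` (adapted from descript-audio-codec) and may differ in detail:

import torch

def discriminator_loss(real_outputs, fake_outputs):
    # Least-squares GAN: push logits of real audio toward 1 and of generated audio toward 0.
    loss = 0.0
    for d_real, d_fake in zip(real_outputs, fake_outputs):
        loss = loss + torch.mean((1 - d_real) ** 2) + torch.mean(d_fake ** 2)
    return loss

def generator_adversarial_loss(fake_outputs):
    # The generator tries to push logits of generated audio toward 1.
    loss = 0.0
    for d_fake in fake_outputs:
        loss = loss + torch.mean((1 - d_fake) ** 2)
    return loss

def feature_matching_loss(fmaps_real, fmaps_fake):
    # L1 distance between intermediate feature maps of real and generated audio.
    loss = 0.0
    for fmap_r, fmap_f in zip(fmaps_real, fmaps_fake):
        for r, f in zip(fmap_r, fmap_f):
            loss = loss + torch.mean(torch.abs(r.detach() - f))
    return loss

# Typical usage (illustrative):
#   mpd = MultiPeriodDiscriminator()
#   real_logits, fake_logits, fmaps_r, fmaps_g = mpd(y, y_hat.detach())
#   loss_d = discriminator_loss(real_logits, fake_logits)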
-------------------------------------------------------------------------------- /vocoders/vocos/models/head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ISTFT(nn.Module): 6 | """ 7 | Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with 8 | windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. 9 | See issue: https://github.com/pytorch/pytorch/issues/62323 10 | Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. 11 | The NOLA constraint is met as we trim padded samples anyway. 12 | 13 | Args: 14 | n_fft (int): Size of Fourier transform. 15 | hop_length (int): The distance between neighboring sliding window frames. 16 | win_length (int): The size of window frame and STFT filter. 17 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 18 | """ 19 | 20 | def __init__(self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"): 21 | super().__init__() 22 | if padding not in ["center", "same"]: 23 | raise ValueError("Padding must be 'center' or 'same'.") 24 | self.padding = padding 25 | self.n_fft = n_fft 26 | self.hop_length = hop_length 27 | self.win_length = win_length 28 | window = torch.hann_window(win_length) 29 | self.register_buffer("window", window) 30 | 31 | def forward(self, spec: torch.Tensor) -> torch.Tensor: 32 | """ 33 | Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. 34 | 35 | Args: 36 | spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, 37 | N is the number of frequency bins, and T is the number of time frames. 38 | 39 | Returns: 40 | Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. 41 | """ 42 | if self.padding == "center": 43 | # Fallback to pytorch native implementation 44 | return torch.istft(spec, self.n_fft, self.hop_length, self.win_length, self.window, center=True) 45 | elif self.padding == "same": 46 | pad = (self.win_length - self.hop_length) // 2 47 | else: 48 | raise ValueError("Padding must be 'center' or 'same'.") 49 | 50 | assert spec.dim() == 3, "Expected a 3D tensor as input" 51 | B, N, T = spec.shape 52 | 53 | # Inverse FFT 54 | ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") 55 | ifft = ifft * self.window[None, :, None] 56 | 57 | # Overlap and Add 58 | output_size = (T - 1) * self.hop_length + self.win_length 59 | y = torch.nn.functional.fold( 60 | ifft, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 61 | )[:, 0, 0, pad:-pad] 62 | 63 | # Window envelope 64 | window_sq = self.window.square().expand(1, T, -1).transpose(1, 2) 65 | window_envelope = torch.nn.functional.fold( 66 | window_sq, output_size=(1, output_size), kernel_size=(1, self.win_length), stride=(1, self.hop_length), 67 | ).squeeze()[pad:-pad] 68 | 69 | # Normalize 70 | assert (window_envelope > 1e-11).all() 71 | y = y / window_envelope 72 | 73 | return y 74 | 75 | class ISTFTHead(nn.Module): 76 | """ 77 | ISTFT Head module for predicting STFT complex coefficients. 78 | 79 | Args: 80 | dim (int): Hidden dimension of the model. 81 | n_fft (int): Size of Fourier transform. 82 | hop_length (int): The distance between neighboring sliding window frames, which should align with 83 | the resolution of the input features. 
84 | padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 85 | """ 86 | 87 | def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"): 88 | super().__init__() 89 | out_dim = n_fft + 2 90 | self.out = torch.nn.Linear(dim, out_dim) 91 | self.istft = ISTFT(n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding) 92 | 93 | def forward(self, x: torch.Tensor) -> torch.Tensor: 94 | """ 95 | Forward pass of the ISTFTHead module. 96 | 97 | Args: 98 | x (Tensor): Input tensor of shape (B, L, H), where B is the batch size, 99 | L is the sequence length, and H denotes the model dimension. 100 | 101 | Returns: 102 | Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal. 103 | """ 104 | x = self.out(x).transpose(1, 2) 105 | mag, p = x.chunk(2, dim=1) 106 | mag = torch.exp(mag) 107 | mag = torch.clip(mag, max=1e2) # safeguard to prevent excessively large magnitudes 108 | # wrapping happens here. These two lines produce real and imaginary value 109 | x = torch.cos(p) 110 | y = torch.sin(p) 111 | # recalculating phase here does not produce anything new 112 | # only costs time 113 | # phase = torch.atan2(y, x) 114 | # S = mag * torch.exp(phase * 1j) 115 | # better directly produce the complex value 116 | S = mag * (x + 1j * y) 117 | audio = self.istft(S) 118 | return audio -------------------------------------------------------------------------------- /vocoders/vocos/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from typing import List 4 | from dataclasses import asdict 5 | 6 | from utils.audio import LogMelSpectrogram 7 | from config import MelConfig 8 | 9 | # Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license. 
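# Multi-scale mel loss: sums L1 distances between log-mel spectrograms of the real and generated
# audio computed at several paired (n_mels, window_length) resolutions, so the generator is
# supervised at both coarse and fine time-frequency scales; SingleScaleMelSpectrogramLoss below
# is the single-resolution variant.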
10 | class MultiScaleMelSpectrogramLoss(nn.Module): 11 | def __init__(self, n_mels: List[int] = [5, 10, 20, 40, 80, 160, 320], window_lengths: List[int] = [32, 64, 128, 256, 512, 1024, 2048]): 12 | super().__init__() 13 | assert len(n_mels) == len(window_lengths), "n_mels and window_lengths must have the same length" 14 | self.mel_transforms = nn.ModuleList(self._get_transforms(n_mels, window_lengths)) 15 | self.loss_fn = nn.L1Loss() 16 | 17 | def _get_transforms(self, n_mels, window_lengths): 18 | transforms = [] 19 | for n_mel, win_length in zip(n_mels, window_lengths): 20 | transform = LogMelSpectrogram(**asdict(MelConfig(n_mels=n_mel, n_fft=win_length, win_length=win_length, hop_length=win_length//4))) 21 | transforms.append(transform) 22 | return transforms 23 | 24 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 25 | return sum(self.loss_fn(mel_transform(x), mel_transform(y)) for mel_transform in self.mel_transforms) 26 | 27 | class SingleScaleMelSpectrogramLoss(nn.Module): 28 | def __init__(self): 29 | super().__init__() 30 | self.mel_transform = LogMelSpectrogram(**asdict(MelConfig())) 31 | self.loss_fn = nn.L1Loss() 32 | print('using single mel loss') 33 | 34 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 35 | return self.loss_fn(self.mel_transform(x), self.mel_transform(y)) 36 | 37 | def feature_loss(fmap_r, fmap_g): 38 | loss = 0 39 | for dr, dg in zip(fmap_r, fmap_g): 40 | for rl, gl in zip(dr, dg): 41 | loss += torch.mean(torch.abs(rl - gl)) 42 | 43 | return loss*2 44 | 45 | def discriminator_loss(disc_real_outputs, disc_generated_outputs): 46 | loss = 0 47 | r_losses = [] 48 | g_losses = [] 49 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 50 | r_loss = torch.mean((1-dr)**2) 51 | g_loss = torch.mean(dg**2) 52 | loss += (r_loss + g_loss) 53 | r_losses.append(r_loss.item()) 54 | g_losses.append(g_loss.item()) 55 | 56 | return loss, r_losses, g_losses 57 | 58 | def generator_loss(disc_outputs): 59 | loss = 0 60 | gen_losses = [] 61 | for dg in disc_outputs: 62 | l = torch.mean((1-dg)**2) 63 | gen_losses.append(l) 64 | loss += l 65 | 66 | return loss, gen_losses -------------------------------------------------------------------------------- /vocoders/vocos/models/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, asdict 2 | 3 | import torch 4 | from torch import nn 5 | from torch import Tensor 6 | 7 | from .head import ISTFTHead 8 | from .backbone import VocosBackbone 9 | from config import MelConfig, VocosConfig 10 | 11 | class Vocos(nn.Module): 12 | def __init__(self, vocos_config: VocosConfig, mel_config: MelConfig): 13 | super().__init__() 14 | self.backbone = VocosBackbone(**asdict(vocos_config)) 15 | self.head = ISTFTHead(vocos_config.dim, mel_config.n_fft, mel_config.hop_length) 16 | 17 | def forward(self, x: Tensor) -> Tensor: 18 | x = self.backbone(x) 19 | x = self.head(x) 20 | return x 21 | -------------------------------------------------------------------------------- /vocoders/vocos/models/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ConvNeXtBlock(nn.Module): 6 | """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. 7 | 8 | Args: 9 | dim (int): Number of input channels. 10 | intermediate_dim (int): Dimensionality of the intermediate layer. 
11 | layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. 12 | Defaults to None. 13 | """ 14 | 15 | def __init__( 16 | self, 17 | dim: int, 18 | intermediate_dim: int, 19 | layer_scale_init_value: float, 20 | ): 21 | super().__init__() 22 | self.dwconv = nn.Conv1d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv 23 | self.norm = nn.LayerNorm(dim, eps=1e-6) 24 | self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers 25 | self.act = nn.GELU() 26 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 27 | self.gamma = ( 28 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 29 | if layer_scale_init_value > 0 30 | else None 31 | ) 32 | 33 | def forward(self, x: torch.Tensor) -> torch.Tensor: 34 | residual = x 35 | x = self.dwconv(x) 36 | x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) 37 | x = self.norm(x) 38 | x = self.pwconv1(x) 39 | x = self.act(x) 40 | x = self.pwconv2(x) 41 | if self.gamma is not None: 42 | x = self.gamma * x 43 | x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) 44 | 45 | x = residual + x 46 | return x 47 | -------------------------------------------------------------------------------- /vocoders/vocos/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import concurrent.futures 3 | 4 | from tqdm import tqdm 5 | from dataclasses import dataclass 6 | 7 | @dataclass 8 | class DataConfig: 9 | audio_dirs = ['./datasets'] # paths to audios 10 | filelist_path = './filelists/filelist.txt' # path to save filelist 11 | audio_formats = ('.wav', '.ogg', '.opus', '.mp3', '.flac') 12 | 13 | data_config = DataConfig() 14 | 15 | filelist_path = data_config.filelist_path 16 | 17 | os.makedirs(os.path.dirname(filelist_path), exist_ok=True) 18 | 19 | def find_audio_files(directory) -> list: 20 | audio_files = [] 21 | valid_extensions = data_config.audio_formats 22 | 23 | for root, dirs, files in tqdm(os.walk(directory)): 24 | audio_files.extend(os.path.join(root, file) for file in files if file.endswith(valid_extensions)) 25 | 26 | return audio_files 27 | 28 | 29 | def main(): 30 | results = [] 31 | 32 | with concurrent.futures.ProcessPoolExecutor(max_workers=4) as executor: 33 | futures = [executor.submit(find_audio_files, audio_dir) for audio_dir in data_config.audio_dirs] 34 | 35 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 36 | results.extend(future.result()) 37 | 38 | # save filelist 39 | with open(filelist_path, 'w', encoding='utf-8') as f: 40 | f.writelines(f"{result}\n" for result in results) 41 | 42 | print(f"filelist has been saved to {filelist_path}") 43 | 44 | if __name__ == '__main__': 45 | main() -------------------------------------------------------------------------------- /vocoders/vocos/requirements.txt: -------------------------------------------------------------------------------- 1 | tensorboard 2 | tqdm -------------------------------------------------------------------------------- /vocoders/vocos/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1' 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torch.distributed as dist 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from torch.utils.data import DataLoader 10 | from torch.utils.data.distributed import DistributedSampler 11 | 
from torch.utils.tensorboard import SummaryWriter 12 | 13 | from tqdm import tqdm 14 | import itertools 15 | from dataclasses import asdict 16 | 17 | from models.model import Vocos 18 | from dataset import VocosDataset 19 | from models.discriminator import MultiPeriodDiscriminator, MultiResolutionDiscriminator 20 | from models.loss import feature_loss, generator_loss, discriminator_loss, MultiScaleMelSpectrogramLoss, SingleScaleMelSpectrogramLoss 21 | from config import MelConfig, VocosConfig, TrainConfig 22 | from utils.scheduler import get_cosine_schedule_with_warmup 23 | from utils.load import continue_training 24 | 25 | torch.backends.cudnn.benchmark = True 26 | 27 | def setup(rank, world_size): 28 | os.environ['MASTER_ADDR'] = 'localhost' 29 | os.environ['MASTER_PORT'] = '12345' 30 | dist.init_process_group("gloo" if os.name == "nt" else "nccl", rank=rank, world_size=world_size) 31 | 32 | def cleanup(): 33 | dist.destroy_process_group() 34 | 35 | def _init_config(vocos_config: VocosConfig, mel_config: MelConfig, train_config: TrainConfig): 36 | if vocos_config.input_channels != mel_config.n_mels: 37 | raise ValueError("input_channels and n_mels must be equal.") 38 | 39 | if not os.path.exists(train_config.model_save_path): 40 | print(f'Creating {train_config.model_save_path}') 41 | os.makedirs(train_config.model_save_path, exist_ok=True) 42 | 43 | def train(rank, world_size): 44 | setup(rank, world_size) 45 | torch.cuda.set_device(rank) 46 | 47 | vocos_config = VocosConfig() 48 | mel_config = MelConfig() 49 | train_config = TrainConfig() 50 | 51 | _init_config(vocos_config, mel_config, train_config) 52 | 53 | generator = Vocos(vocos_config, mel_config).to(rank) 54 | mpd = MultiPeriodDiscriminator().to(rank) 55 | mrd = MultiResolutionDiscriminator().to(rank) 56 | loss_fn = MultiScaleMelSpectrogramLoss().to(rank) 57 | if rank == 0: 58 | print(f"Generator params: {sum(p.numel() for p in generator.parameters()) / 1e6}") 59 | print(f"Discriminator mpd params: {sum(p.numel() for p in mpd.parameters()) / 1e6}") 60 | print(f"Discriminator mrd params: {sum(p.numel() for p in mrd.parameters()) / 1e6}") 61 | 62 | generator = DDP(generator, device_ids=[rank]) 63 | mpd = DDP(mpd, device_ids=[rank]) 64 | mrd = DDP(mrd, device_ids=[rank]) 65 | 66 | train_dataset = VocosDataset(train_config.train_dataset_path, train_config.segment_size, mel_config) 67 | train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank) 68 | train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=train_config.batch_size, num_workers=4, pin_memory=False, persistent_workers=True) 69 | 70 | if rank == 0: 71 | writer = SummaryWriter(train_config.log_dir) 72 | 73 | optimizer_g = optim.AdamW(generator.parameters(), lr=train_config.learning_rate) 74 | optimizer_d = optim.AdamW(itertools.chain(mpd.parameters(), mrd.parameters()), lr=train_config.learning_rate) 75 | scheduler_g = get_cosine_schedule_with_warmup(optimizer_g, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader)) 76 | scheduler_d = get_cosine_schedule_with_warmup(optimizer_d, num_warmup_steps=int(train_config.warmup_steps), num_training_steps=train_config.num_epochs * len(train_dataloader)) 77 | 78 | # load latest checkpoints if possible 79 | current_epoch = continue_training(train_config.model_save_path, generator, mpd, mrd, optimizer_d, optimizer_g) 80 | 81 | generator.train() 82 | mpd.train() 83 | mrd.train() 84 | for epoch in range(current_epoch, 
train_config.num_epochs): # loop over the train_dataset multiple times 85 | train_dataloader.sampler.set_epoch(epoch) 86 | if rank == 0: 87 | dataloader = tqdm(train_dataloader) 88 | else: 89 | dataloader = train_dataloader 90 | 91 | for batch_idx, datas in enumerate(dataloader): 92 | datas = [data.to(rank, non_blocking=True) for data in datas] 93 | audios, mels = datas 94 | audios_fake = generator(mels).unsqueeze(1) # shape: [batch_size, 1, segment_size] 95 | optimizer_d.zero_grad() 96 | 97 | # MPD 98 | y_df_hat_r, y_df_hat_g, _, _ = mpd(audios,audios_fake.detach()) 99 | loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g) 100 | 101 | # MRD 102 | y_ds_hat_r, y_ds_hat_g, _, _ = mrd(audios,audios_fake.detach()) 103 | loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g) 104 | 105 | loss_disc_all = loss_disc_s + loss_disc_f 106 | loss_disc_all.backward() 107 | 108 | grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), 1000) 109 | grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), 1000) 110 | optimizer_d.step() 111 | scheduler_d.step() 112 | 113 | # generator 114 | optimizer_g.zero_grad() 115 | loss_mel = loss_fn(audios, audios_fake) * train_config.mel_loss_factor 116 | 117 | # MPD loss 118 | y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(audios,audios_fake) 119 | loss_fm_f = feature_loss(fmap_f_r, fmap_f_g) 120 | loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g) 121 | 122 | # MRD loss 123 | y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = mrd(audios,audios_fake) 124 | loss_fm_s = feature_loss(fmap_s_r, fmap_s_g) 125 | loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g) 126 | 127 | loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel 128 | loss_gen_all.backward() 129 | 130 | grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), 1000) 131 | optimizer_g.step() 132 | scheduler_g.step() 133 | 134 | if rank == 0 and batch_idx % train_config.log_interval == 0: 135 | steps = epoch * len(dataloader) + batch_idx 136 | writer.add_scalar("training/gen_loss_total", loss_gen_all, steps) 137 | writer.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps) 138 | writer.add_scalar("training/gen_loss_mpd", loss_gen_f.item(), steps) 139 | writer.add_scalar("training/disc_loss_mpd", loss_disc_f.item(), steps) 140 | writer.add_scalar("training/fm_loss_mrd", loss_fm_s.item(), steps) 141 | writer.add_scalar("training/gen_loss_mrd", loss_gen_s.item(), steps) 142 | writer.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps) 143 | writer.add_scalar("training/mel_loss", loss_mel.item(), steps) 144 | writer.add_scalar("grad_norm/grad_norm_mpd", grad_norm_mpd, steps) 145 | writer.add_scalar("grad_norm/grad_norm_mrd", grad_norm_mrd, steps) 146 | writer.add_scalar("grad_norm/grad_norm_g", grad_norm_g, steps) 147 | writer.add_scalar("learning_rate/learning_rate_d", scheduler_d.get_last_lr()[0], steps) 148 | writer.add_scalar("learning_rate/learning_rate_g", scheduler_g.get_last_lr()[0], steps) 149 | 150 | if rank == 0: 151 | torch.save(generator.module.state_dict(), os.path.join(train_config.model_save_path, f'generator_{epoch}.pt')) 152 | torch.save(mpd.module.state_dict(), os.path.join(train_config.model_save_path, f'mpd_{epoch}.pt')) 153 | torch.save(mrd.module.state_dict(), os.path.join(train_config.model_save_path, f'mrd_{epoch}.pt')) 154 | torch.save(optimizer_d.state_dict(), os.path.join(train_config.model_save_path, f'optimizerd_{epoch}.pt')) 155 | 
torch.save(optimizer_g.state_dict(), os.path.join(train_config.model_save_path, f'optimizerg_{epoch}.pt')) 156 | print(f"Rank {rank}, Epoch {epoch}, Loss {loss_gen_all.item()}") 157 | 158 | cleanup() 159 | 160 | torch.set_num_threads(1) 161 | torch.set_num_interop_threads(1) 162 | 163 | if __name__ == "__main__": 164 | world_size = torch.cuda.device_count() 165 | torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size) -------------------------------------------------------------------------------- /vocoders/vocos/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KdaiP/StableTTS/71dfa4138c511df8e0aedf444df98c6baa44cad4/vocoders/vocos/utils/__init__.py -------------------------------------------------------------------------------- /vocoders/vocos/utils/audio.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | import torchaudio 5 | 6 | class LinearSpectrogram(nn.Module): 7 | def __init__(self, n_fft, win_length, hop_length, pad, center, pad_mode): 8 | super().__init__() 9 | 10 | self.n_fft = n_fft 11 | self.win_length = win_length 12 | self.hop_length = hop_length 13 | self.pad = pad 14 | self.center = center 15 | self.pad_mode = pad_mode 16 | 17 | self.register_buffer("window", torch.hann_window(win_length)) 18 | 19 | def forward(self, waveform: Tensor) -> Tensor: 20 | if waveform.ndim == 3: 21 | waveform = waveform.squeeze(1) 22 | waveform = torch.nn.functional.pad(waveform.unsqueeze(1), (self.pad, self.pad), self.pad_mode).squeeze(1) 23 | spec = torch.stft(waveform, self.n_fft, self.hop_length, self.win_length, self.window, self.center, self.pad_mode, False, True, True) 24 | spec = torch.view_as_real(spec) 25 | spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) 26 | return spec 27 | 28 | 29 | class LogMelSpectrogram(nn.Module): 30 | def __init__(self, sample_rate, n_fft, win_length, hop_length, f_min, f_max, pad, n_mels, center, pad_mode, mel_scale): 31 | super().__init__() 32 | self.sample_rate = sample_rate 33 | self.n_fft = n_fft 34 | self.win_length = win_length 35 | self.hop_length = hop_length 36 | self.f_min = f_min 37 | self.f_max = f_max 38 | self.pad = pad 39 | self.n_mels = n_mels 40 | self.center = center 41 | self.pad_mode = pad_mode 42 | self.mel_scale = mel_scale 43 | 44 | self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, pad, center, pad_mode) 45 | self.mel_scale = torchaudio.transforms.MelScale(n_mels, sample_rate, f_min, f_max, (n_fft//2)+1, mel_scale, mel_scale) 46 | 47 | def compress(self, x: Tensor) -> Tensor: 48 | return torch.log(torch.clamp(x, min=1e-5)) 49 | 50 | def decompress(self, x: Tensor) -> Tensor: 51 | return torch.exp(x) 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | linear_spec = self.spectrogram(x) 55 | x = self.mel_scale(linear_spec) 56 | x = self.compress(x) 57 | return x 58 | 59 | def load_and_resample_audio(audio_path, target_sr, device='cpu') -> Tensor: 60 | try: 61 | y, sr = torchaudio.load(audio_path) 62 | except Exception as e: 63 | print(str(e)) 64 | return None 65 | 66 | y = y.to(device) # Tensor.to is not in-place; keep the returned tensor 67 | # Convert to mono 68 | if y.size(0) > 1: 69 | y = y[0, :].unsqueeze(0) # shape: [2, time] -> [time] -> [1, time] 70 | 71 | # resample audio to target sample_rate 72 | if sr != target_sr: 73 | y = torchaudio.functional.resample(y, sr, target_sr) 74 | return y
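Elsewhere in the repository, LogMelSpectrogram is built from the MelConfig dataclass via asdict, so standalone feature extraction looks roughly like the sketch below. It assumes the script runs from the vocoders/vocos directory, that MelConfig (defined in config.py, not shown here) carries exactly the fields LogMelSpectrogram expects, and that './example.wav' is a placeholder path.

import torch
from dataclasses import asdict

from config import MelConfig
from utils.audio import LogMelSpectrogram, load_and_resample_audio

mel_config = MelConfig()
mel_extractor = LogMelSpectrogram(**asdict(mel_config))

audio = load_and_resample_audio('./example.wav', target_sr=mel_config.sample_rate)
if audio is not None:                  # load_and_resample_audio returns None on failure
    with torch.no_grad():
        mel = mel_extractor(audio)     # expected shape: [1, n_mels, frames]
    print(mel.shape)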
-------------------------------------------------------------------------------- /vocoders/vocos/utils/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | 7 | def continue_training(checkpoint_path, generator: DDP, mpd: DDP, mrd: DDP, optimizer_d: optim.Optimizer, optimizer_g: optim.Optimizer) -> int: 8 | """load the latest model and optimizer checkpoints and return the next epoch index""" 9 | generator_dict = {} 10 | mpd_dict = {} 11 | mrd_dict = {} 12 | optimizer_d_dict = {} 13 | optimizer_g_dict = {} 14 | 15 | # glob all the checkpoints in the directory 16 | for file in os.listdir(checkpoint_path): 17 | if file.endswith(".pt"): 18 | name, epoch_str = file.rsplit('_', 1) 19 | epoch = int(epoch_str.split('.')[0]) 20 | 21 | if name.startswith("generator"): 22 | generator_dict[epoch] = file 23 | elif name.startswith("mpd"): 24 | mpd_dict[epoch] = file 25 | elif name.startswith("mrd"): 26 | mrd_dict[epoch] = file 27 | elif name.startswith("optimizerd"): 28 | optimizer_d_dict[epoch] = file 29 | elif name.startswith("optimizerg"): 30 | optimizer_g_dict[epoch] = file 31 | 32 | # pick the largest epoch for which all five checkpoint files exist 33 | common_epochs = set(generator_dict.keys()) & set(mpd_dict.keys()) & set(mrd_dict.keys()) & set(optimizer_d_dict.keys()) & set(optimizer_g_dict.keys()) 34 | if common_epochs: 35 | max_epoch = max(common_epochs) 36 | generator_path = os.path.join(checkpoint_path, generator_dict[max_epoch]) 37 | mpd_path = os.path.join(checkpoint_path, mpd_dict[max_epoch]) 38 | mrd_path = os.path.join(checkpoint_path, mrd_dict[max_epoch]) 39 | optimizer_d_path = os.path.join(checkpoint_path, optimizer_d_dict[max_epoch]) 40 | optimizer_g_path = os.path.join(checkpoint_path, optimizer_g_dict[max_epoch]) 41 | 42 | # load model and optimizer 43 | generator.module.load_state_dict(torch.load(generator_path, map_location='cpu')) 44 | mpd.module.load_state_dict(torch.load(mpd_path, map_location='cpu')) 45 | mrd.module.load_state_dict(torch.load(mrd_path, map_location='cpu')) 46 | optimizer_d.load_state_dict(torch.load(optimizer_d_path, map_location='cpu')) 47 | optimizer_g.load_state_dict(torch.load(optimizer_g_path, map_location='cpu')) 48 | 49 | print(f'resuming model and optimizers from epoch {max_epoch}') 50 | return max_epoch + 1 51 | 52 | else: 53 | return 0 -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TMPDIR'] = './temps' # avoid permission issues with the system default temp folder 3 | # os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' # use the huggingface mirror for users who cannot log in to huggingface 4 | 5 | import re 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | 9 | import torch 10 | import gradio as gr 11 | 12 | from api import StableTTSAPI 13 | 14 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 15 | 16 | tts_model_path = './checkpoints/checkpoint_0.pt' 17 | vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt' 18 | vocoder_type = 'ffgan' 19 | 20 | model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type).to(device) 21 | 22 | @torch.inference_mode() 23 | def inference(text, ref_audio, language, step, temperature, length_scale, solver, cfg): 24 | text = remove_newlines_after_punctuation(text) 25 | 26 | if language ==
'chinese': 27 | text = text.replace(' ', '') 28 | 29 | audio, mel = model.inference(text, ref_audio, language, step, temperature, length_scale, solver, cfg) 30 | 31 | max_val = torch.max(torch.abs(audio)) 32 | if max_val > 1: 33 | audio = audio / max_val 34 | 35 | audio_output = (model.mel_config.sample_rate, (audio.cpu().squeeze(0).numpy() * 32767).astype(np.int16)) # (samplerate, int16 audio) for gr.Audio 36 | mel_output = plot_mel_spectrogram(mel.cpu().squeeze(0).numpy()) # get the plot of mel 37 | 38 | return audio_output, mel_output 39 | 40 | def plot_mel_spectrogram(mel_spectrogram): 41 | plt.close() # prevent memory leak 42 | fig, ax = plt.subplots(figsize=(20, 8)) 43 | ax.imshow(mel_spectrogram, aspect='auto', origin='lower') 44 | plt.axis('off') 45 | fig.subplots_adjust(left=0, right=1, top=1, bottom=0) # remove white edges 46 | return fig 47 | 48 | def remove_newlines_after_punctuation(text): 49 | pattern = r'([,。!?、“”‘’《》【】;:,.!?\'\"<>()\[\]{}])\n' 50 | return re.sub(pattern, r'\1', text) 51 | 52 | def main(): 53 | 54 | # gradio wabui, reference: https://huggingface.co/spaces/fishaudio/fish-speech-1 55 | gui_title = 'StableTTS' 56 | gui_description = """Next-generation TTS model using flow-matching and DiT, inspired by Stable Diffusion 3.""" 57 | example_text = """你指尖跳动的电光,是我永恒不变的信仰。唯我超电磁炮永世长存!""" 58 | 59 | with gr.Blocks(theme=gr.themes.Base()) as demo: 60 | demo.load(None, None, js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', 'light');window.location.search = params.toString();}}") 61 | 62 | with gr.Row(): 63 | with gr.Column(): 64 | gr.Markdown(f"# {gui_title}") 65 | gr.Markdown(gui_description) 66 | 67 | with gr.Row(): 68 | with gr.Column(): 69 | input_text_gr = gr.Textbox( 70 | label="Input Text", 71 | info="Put your text here", 72 | value=example_text, 73 | ) 74 | 75 | ref_audio_gr = gr.Audio( 76 | label="Reference Audio", 77 | type="filepath" 78 | ) 79 | 80 | language_gr = gr.Dropdown( 81 | label='Language', 82 | choices=list(model.supported_languages), 83 | value = 'chinese' 84 | ) 85 | 86 | step_gr = gr.Slider( 87 | label='Step', 88 | minimum=1, 89 | maximum=100, 90 | value=25, 91 | step=1 92 | ) 93 | 94 | temperature_gr = gr.Slider( 95 | label='Temperature', 96 | minimum=0, 97 | maximum=2, 98 | value=1, 99 | ) 100 | 101 | length_scale_gr = gr.Slider( 102 | label='Length_Scale', 103 | minimum=0, 104 | maximum=5, 105 | value=1, 106 | ) 107 | 108 | solver_gr = gr.Dropdown( 109 | label='ODE Solver', 110 | choices=['euler', 'midpoint', 'dopri5', 'rk4', 'implicit_adams', 'bosh3', 'fehlberg2', 'adaptive_heun'], 111 | value = 'dopri5' 112 | ) 113 | 114 | cfg_gr = gr.Slider( 115 | label='CFG', 116 | minimum=0, 117 | maximum=10, 118 | value=3, 119 | ) 120 | 121 | with gr.Column(): 122 | mel_gr = gr.Plot(label="Mel Visual") 123 | audio_gr = gr.Audio(label="Synthesised Audio", autoplay=True) 124 | tts_button = gr.Button("\U0001F3A7 Generate / 合成", elem_id="send-btn", visible=True, variant="primary") 125 | 126 | tts_button.click(inference, [input_text_gr, ref_audio_gr, language_gr, step_gr, temperature_gr, length_scale_gr, solver_gr, cfg_gr], outputs=[audio_gr, mel_gr]) 127 | 128 | demo.queue() 129 | demo.launch(debug=True, show_api=True) 130 | 131 | 132 | if __name__ == '__main__': 133 | main() --------------------------------------------------------------------------------
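The same StableTTSAPI calls that webui.py wires into Gradio can also be driven from a plain script. The sketch below is a hypothetical example built only from the calls visible above: the checkpoint paths and reference wav are placeholders, the argument order mirrors the Gradio callback, and the output is assumed to be a [1, T] waveform tensor (as the squeeze(0) in webui.py suggests).

import torch
import torchaudio

from api import StableTTSAPI

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# placeholder checkpoint paths, mirroring webui.py
model = StableTTSAPI('./checkpoints/checkpoint_0.pt', './vocoders/pretrained/firefly-gan-base-generator.ckpt', 'ffgan').to(device)

@torch.inference_mode()
def synthesize(text, ref_wav, out_path='output.wav'):
    # argument order follows the Gradio callback: text, ref_audio, language, step, temperature, length_scale, solver, cfg
    audio, mel = model.inference(text, ref_wav, 'chinese', 25, 1.0, 1.0, 'dopri5', 3.0)
    audio = audio / audio.abs().max().clamp(min=1.0)  # normalize only when the peak exceeds 1, as in webui.py
    torchaudio.save(out_path, audio.cpu(), model.mel_config.sample_rate)  # assumes audio has shape [1, T]
    return out_path

# synthesize('你指尖跳动的电光,是我永恒不变的信仰。', './reference.wav')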