├── .github └── workflows │ └── publish.yml ├── .gitignore ├── README.md ├── __init__.py ├── assets ├── LoadSpeakerModel.png ├── PuncSegment.png ├── SaveSpeakerModel.png ├── SenseVoice.png └── Workflow_FunAudioLLM.png ├── cosyvoice ├── __init__.py ├── bin │ ├── export_jit.py │ ├── export_onnx.py │ ├── inference.py │ └── train.py ├── cli │ ├── __init__.py │ ├── cosyvoice.py │ ├── frontend.py │ └── model.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ └── processor.py ├── flow │ ├── decoder.py │ ├── flow.py │ ├── flow_matching.py │ └── length_regulator.py ├── hifigan │ ├── f0_predictor.py │ └── generator.py ├── llm │ └── llm.py ├── tokenizer │ ├── assets │ │ └── multilingual_zh_ja_yue_char_del.tiktoken │ └── tokenizer.py ├── transformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── decoder.py │ ├── decoder_layer.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ └── subsampling.py └── utils │ ├── __init__.py │ ├── class_utils.py │ ├── common.py │ ├── executor.py │ ├── file_utils.py │ ├── frontend_utils.py │ ├── mask.py │ ├── scheduler.py │ └── train_utils.py ├── funaudio_utils ├── __init__.py ├── cosyvoice_plus.py ├── download_models.py └── pre.py ├── matcha ├── VERSION ├── __init__.py ├── app.py ├── cli.py ├── data │ ├── __init__.py │ ├── components │ │ └── __init__.py │ └── text_mel_datamodule.py ├── hifigan │ ├── LICENSE │ ├── README.md │ ├── __init__.py │ ├── config.py │ ├── denoiser.py │ ├── env.py │ ├── meldataset.py │ ├── models.py │ └── xutils.py ├── models │ ├── __init__.py │ ├── baselightningmodule.py │ ├── components │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── flow_matching.py │ │ ├── text_encoder.py │ │ └── transformer.py │ └── matcha_tts.py ├── onnx │ ├── __init__.py │ ├── export.py │ └── infer.py ├── text │ ├── __init__.py │ ├── cleaners.py │ ├── numbers.py │ └── symbols.py ├── train.py └── utils │ ├── __init__.py │ ├── audio.py │ ├── generate_data_statistics.py │ ├── instantiators.py │ ├── logging_utils.py │ ├── model.py │ ├── monotonic_align │ ├── __init__.py │ ├── core.pyx │ └── setup.py │ ├── pylogger.py │ ├── rich_utils.py │ └── utils.py ├── nodes ├── __init__.py ├── cosyvoice_nodes.py └── sensevoice_nodes.py ├── pyproject.toml ├── requirements.txt ├── sensevoice ├── __init__.py ├── model.py └── utils │ ├── __init__.py │ ├── export_utils.py │ ├── frontend.py │ ├── infer_utils.py │ └── model_bin.py ├── web └── PUT_WEB_JS_HERE └── workflow └── FunAudioLLM.json /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | # if this is a forked repository. Skipping the workflow. 16 | if: github.event.repository.fork == false 17 | steps: 18 | - name: Check out code 19 | uses: actions/checkout@v4 20 | - name: Publish Custom Node 21 | uses: Comfy-Org/publish-node-action@main 22 | with: 23 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
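## For example: under the repository's Settings -> Secrets and variables -> Actions, create a secret
## named REGISTRY_ACCESS_TOKEN (the name assumed by the reference below) containing your Comfy Registry personal access token.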
24 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 9 | # ComfyUI-FunAudioLLM 10 | ComfyUI custom nodes for [FunAudioLLM](https://funaudiollm.github.io/), including [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) and [SenseVoice](https://github.com/FunAudioLLM/SenseVoice) 11 | 12 | ## Features 13 | 14 | ### CosyVoice 15 | - CosyVoice Version: 2024-10-04 16 | - Support SFT, Zero-shot, Cross-lingual and Instruct inference 17 | - Support CosyVoice-300M-25Hz in zero-shot and cross-lingual 18 | - Support SFT's 25Hz (unofficial) 19 | -
20 | Save and load speaker model in zero-shot 21 | (screenshots: assets/SaveSpeakerModel.png, assets/LoadSpeakerModel.png) 22 | 23 |
24 | 25 | ### SenseVoice 26 | - SenseVoice Version: 2024-10-04 27 | - Support SenseVoice-Small 28 | -
29 | Support punctuation segmentation (requires turning off use_fast_mode) 30 | (screenshots: assets/PuncSegment.png, assets/SenseVoice.png) 31 | 32 |
33 | 34 | ## How to use 35 | ```bash 36 | apt update 37 | apt install ffmpeg 38 | 39 | ## in ComfyUI/custom_nodes 40 | git clone https://github.com/SpenserCai/ComfyUI-FunAudioLLM 41 | cd ComfyUI-FunAudioLLM 42 | pip install -r requirements.txt 43 | 44 | ``` 45 | 46 | ### Windows 47 | On Windows you need to use conda to install pynini: 48 | ```bash 49 | conda install -c conda-forge pynini=2.1.6 50 | pip install -r requirements.txt 51 | 52 | ``` 53 | 54 | ### MacOS 55 | If you meet an error during installation, try: 56 | ```bash 57 | brew install openfst 58 | export CPPFLAGS="-I/opt/homebrew/include" 59 | export LDFLAGS="-L/opt/homebrew/lib" 60 | pip install -r requirements.txt 61 | ``` 62 | 63 | If your network is unstable, you can pre-download the models from the following sources and place them in the corresponding directories. 64 | 65 | - [CosyVoice-300M](https://modelscope.cn/models/iic/CosyVoice-300M) -> `ComfyUI/models/CosyVoice/CosyVoice-300M` 66 | - [CosyVoice-300M-25Hz](https://modelscope.cn/models/iic/CosyVoice-300M-25Hz) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-25Hz` 67 | - [CosyVoice-300M-SFT](https://modelscope.cn/models/iic/CosyVoice-300M-SFT) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-SFT` 68 | - [CosyVoice-300M-SFT-25Hz](https://modelscope.cn/models/MachineS/CosyVoice-300M-SFT-25Hz) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-SFT-25Hz` 69 | - [CosyVoice-300M-Instruct](https://modelscope.cn/models/iic/CosyVoice-300M-Instruct) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-Instruct` 70 | - [SenseVoiceSmall](https://modelscope.cn/models/iic/SenseVoiceSmall) -> `ComfyUI/models/SenseVoice/SenseVoiceSmall` 71 | 72 | ## WorkFlow 73 | 74 | (example workflow: workflow/FunAudioLLM.json; preview: assets/Workflow_FunAudioLLM.png) 75 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: SpenserCai 3 | Date: 2024-10-04 12:14:22 4 | version: 5 | LastEditors: SpenserCai 6 | LastEditTime: 2024-10-04 22:29:33 7 | Description: file content 8 | ''' 9 | from .nodes.cosyvoice_nodes import * 10 | from .nodes.sensevoice_nodes import * 11 | 12 | NODE_CONFIG = { 13 | "CosyVoiceZeroShotNode": { 14 | "class": CosyVoiceZeroShotNode, 15 | "name": "CosyVoice 3s极速克隆" 16 | }, 17 | "CosyVoiceSFTNode": { 18 | "class": CosyVoiceSFTNode, 19 | "name": "CosyVoice 预训练音色" 20 | }, 21 | "CosyVoiceCrossLingualNode": { 22 | "class": CosyVoiceCrossLingualNode, 23 | "name": "CosyVoice 跨语言克隆" 24 | }, 25 | "CosyVoiceInstructNode": { 26 | "class": CosyVoiceInstructNode, 27 | "name": "CosyVoice 自然语言控制" 28 | }, 29 | "CosyVoiceSaveSpeakerModelNode": { 30 | "class": CosyVoiceSaveSpeakerModelNode, 31 | "name": "CosyVoice 保存说话人模型" 32 | }, 33 | "CosyVoiceLoadSpeakerModelNode": { 34 | "class": CosyVoiceLoadSpeakerModelNode, 35 | "name": "CosyVoice 加载说话人模型" 36 | }, 37 | "CosyVoiceLoadSpeakerModelFromUrlNode": { 38 | "class": CosyVoiceLoadSpeakerModelFromUrlNode, 39 | "name": "CosyVoice 从URL加载说话人模型" 40 | }, 41 | "SenseVoiceNode": { 42 | "class": SenseVoiceNode, 43 | "name": "SenseVoice 语音识别" 44 | } 45 | } 46 | 47 | def generate_node_mappings(node_config): 48 | node_class_mappings = {} 49 | node_display_name_mappings = {} 50 | 51 | for node_name, node_info in node_config.items(): 52 | node_class_mappings[node_name] = node_info["class"] 53 | node_display_name_mappings[node_name] = node_info.get("name", node_info["class"].__name__) 54 | 55 | return node_class_mappings, node_display_name_mappings 56 | 57 | NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG) 58 | 59 |
WEB_DIRECTORY = "./web" 60 | 61 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"] 62 | 63 | 64 | -------------------------------------------------------------------------------- /assets/LoadSpeakerModel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/LoadSpeakerModel.png -------------------------------------------------------------------------------- /assets/PuncSegment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/PuncSegment.png -------------------------------------------------------------------------------- /assets/SaveSpeakerModel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/SaveSpeakerModel.png -------------------------------------------------------------------------------- /assets/SenseVoice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/SenseVoice.png -------------------------------------------------------------------------------- /assets/Workflow_FunAudioLLM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/Workflow_FunAudioLLM.png -------------------------------------------------------------------------------- /cosyvoice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/__init__.py -------------------------------------------------------------------------------- /cosyvoice/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
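# Illustrative usage sketch (not part of the upstream script; the path is just the argparse default):
#
#     python cosyvoice/bin/export_jit.py --model_dir pretrained_models/CosyVoice-300M
#
# The script TorchScript-exports llm.text_encoder.fp16.zip, llm.llm.fp16.zip and flow.encoder.fp32.zip
# next to the checkpoint; these are the files CosyVoice(model_dir, load_jit=True) later loads.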
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import sys 22 | import torch 23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.append('{}/../..'.format(ROOT_DIR)) 25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 26 | from cosyvoice.cli.cosyvoice import CosyVoice 27 | 28 | 29 | def get_args(): 30 | parser = argparse.ArgumentParser(description='export your model for deployment') 31 | parser.add_argument('--model_dir', 32 | type=str, 33 | default='pretrained_models/CosyVoice-300M', 34 | help='local path') 35 | args = parser.parse_args() 36 | print(args) 37 | return args 38 | 39 | 40 | def main(): 41 | args = get_args() 42 | logging.basicConfig(level=logging.DEBUG, 43 | format='%(asctime)s %(levelname)s %(message)s') 44 | 45 | torch._C._jit_set_fusion_strategy([('STATIC', 1)]) 46 | torch._C._jit_set_profiling_mode(False) 47 | torch._C._jit_set_profiling_executor(False) 48 | 49 | cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False) 50 | 51 | # 1. export llm text_encoder 52 | llm_text_encoder = cosyvoice.model.llm.text_encoder.half() 53 | script = torch.jit.script(llm_text_encoder) 54 | script = torch.jit.freeze(script) 55 | script = torch.jit.optimize_for_inference(script) 56 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) 57 | 58 | # 2. export llm llm 59 | llm_llm = cosyvoice.model.llm.llm.half() 60 | script = torch.jit.script(llm_llm) 61 | script = torch.jit.freeze(script, preserved_attrs=['forward_chunk']) 62 | script = torch.jit.optimize_for_inference(script) 63 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) 64 | 65 | # 3. export flow encoder 66 | flow_encoder = cosyvoice.model.flow.encoder 67 | script = torch.jit.script(flow_encoder) 68 | script = torch.jit.freeze(script) 69 | script = torch.jit.optimize_for_inference(script) 70 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
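# Illustrative usage sketch (not part of the upstream script; the path is just the argparse default):
#
#     python cosyvoice/bin/export_onnx.py --model_dir pretrained_models/CosyVoice-300M
#
# This exports flow.decoder.estimator.fp32.onnx next to the checkpoint and then re-runs the estimator
# through onnxruntime on random inputs to confirm the ONNX and PyTorch outputs agree; the exported
# file is what CosyVoice(model_dir, load_onnx=True) picks up.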
15 | 16 | from __future__ import print_function 17 | 18 | import argparse 19 | import logging 20 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 21 | import os 22 | import sys 23 | import onnxruntime 24 | import random 25 | import torch 26 | from tqdm import tqdm 27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | sys.path.append('{}/../..'.format(ROOT_DIR)) 29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 30 | from cosyvoice.cli.cosyvoice import CosyVoice 31 | 32 | 33 | def get_dummy_input(batch_size, seq_len, out_channels, device): 34 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 35 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) 36 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 37 | t = torch.rand((batch_size), dtype=torch.float32, device=device) 38 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) 39 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 40 | return x, mask, mu, t, spks, cond 41 | 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser(description='export your model for deployment') 45 | parser.add_argument('--model_dir', 46 | type=str, 47 | default='pretrained_models/CosyVoice-300M', 48 | help='local path') 49 | args = parser.parse_args() 50 | print(args) 51 | return args 52 | 53 | 54 | def main(): 55 | args = get_args() 56 | logging.basicConfig(level=logging.DEBUG, 57 | format='%(asctime)s %(levelname)s %(message)s') 58 | 59 | cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False) 60 | 61 | # 1. export flow decoder estimator 62 | estimator = cosyvoice.model.flow.decoder.estimator 63 | 64 | device = cosyvoice.model.device 65 | batch_size, seq_len = 1, 256 66 | out_channels = cosyvoice.model.flow.decoder.estimator.out_channels 67 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) 68 | torch.onnx.export( 69 | estimator, 70 | (x, mask, mu, t, spks, cond), 71 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 72 | export_params=True, 73 | opset_version=18, 74 | do_constant_folding=True, 75 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], 76 | output_names=['estimator_out'], 77 | dynamic_axes={ 78 | 'x': {0: 'batch_size', 2: 'seq_len'}, 79 | 'mask': {0: 'batch_size', 2: 'seq_len'}, 80 | 'mu': {0: 'batch_size', 2: 'seq_len'}, 81 | 'cond': {0: 'batch_size', 2: 'seq_len'}, 82 | 't': {0: 'batch_size'}, 83 | 'spks': {0: 'batch_size'}, 84 | 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, 85 | } 86 | ) 87 | 88 | # 2. 
test computation consistency 89 | option = onnxruntime.SessionOptions() 90 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 91 | option.intra_op_num_threads = 1 92 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] 93 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 94 | sess_options=option, providers=providers) 95 | 96 | for _ in tqdm(range(10)): 97 | x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device) 98 | output_pytorch = estimator(x, mask, mu, t, spks, cond) 99 | ort_inputs = { 100 | 'x': x.cpu().numpy(), 101 | 'mask': mask.cpu().numpy(), 102 | 'mu': mu.cpu().numpy(), 103 | 't': t.cpu().numpy(), 104 | 'spks': spks.cpu().numpy(), 105 | 'cond': cond.cpu().numpy() 106 | } 107 | output_onnx = estimator_onnx.run(None, ort_inputs)[0] 108 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /cosyvoice/bin/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
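# Illustrative invocation (a sketch; all data paths below are placeholders, only the checkpoint file
# names match what a CosyVoice model directory actually contains):
#
#     python cosyvoice/bin/inference.py --mode zero_shot \
#         --config pretrained_models/CosyVoice-300M/cosyvoice.yaml \
#         --llm_model pretrained_models/CosyVoice-300M/llm.pt \
#         --flow_model pretrained_models/CosyVoice-300M/flow.pt \
#         --hifigan_model pretrained_models/CosyVoice-300M/hift.pt \
#         --prompt_data data/prompt.data --prompt_utt2data data/utt2data.list --tts_text data/tts_text.json \
#         --result_dir exp/results
#
# One <utt>_<index>.wav is written per request at 22050 Hz, plus a wav.scp index in --result_dir.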
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import torch 22 | from torch.utils.data import DataLoader 23 | import torchaudio 24 | from hyperpyyaml import load_hyperpyyaml 25 | from tqdm import tqdm 26 | from cosyvoice.cli.model import CosyVoiceModel 27 | from cosyvoice.dataset.dataset import Dataset 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='inference with your model') 32 | parser.add_argument('--config', required=True, help='config file') 33 | parser.add_argument('--prompt_data', required=True, help='prompt data file') 34 | parser.add_argument('--prompt_utt2data', required=True, help='prompt data file') 35 | parser.add_argument('--tts_text', required=True, help='tts input file') 36 | parser.add_argument('--llm_model', required=True, help='llm model file') 37 | parser.add_argument('--flow_model', required=True, help='flow model file') 38 | parser.add_argument('--hifigan_model', required=True, help='hifigan model file') 39 | parser.add_argument('--gpu', 40 | type=int, 41 | default=-1, 42 | help='gpu id for this rank, -1 for cpu') 43 | parser.add_argument('--mode', 44 | default='sft', 45 | choices=['sft', 'zero_shot'], 46 | help='inference mode') 47 | parser.add_argument('--result_dir', required=True, help='asr result file') 48 | args = parser.parse_args() 49 | print(args) 50 | return args 51 | 52 | 53 | def main(): 54 | args = get_args() 55 | logging.basicConfig(level=logging.DEBUG, 56 | format='%(asctime)s %(levelname)s %(message)s') 57 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 58 | 59 | # Init cosyvoice models from configs 60 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 61 | device = torch.device('cuda' if use_cuda else 'cpu') 62 | with open(args.config, 'r') as f: 63 | configs = load_hyperpyyaml(f) 64 | 65 | model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) 66 | model.load(args.llm_model, args.flow_model, args.hifigan_model) 67 | 68 | test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, 69 | tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data) 70 | test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) 71 | 72 | del configs 73 | os.makedirs(args.result_dir, exist_ok=True) 74 | fn = os.path.join(args.result_dir, 'wav.scp') 75 | f = open(fn, 'w') 76 | with torch.no_grad(): 77 | for _, batch in tqdm(enumerate(test_data_loader)): 78 | utts = batch["utts"] 79 | assert len(utts) == 1, "inference mode only support batchsize 1" 80 | text_token = batch["text_token"].to(device) 81 | text_token_len = batch["text_token_len"].to(device) 82 | tts_index = batch["tts_index"] 83 | tts_text_token = batch["tts_text_token"].to(device) 84 | tts_text_token_len = batch["tts_text_token_len"].to(device) 85 | speech_token = batch["speech_token"].to(device) 86 | speech_token_len = batch["speech_token_len"].to(device) 87 | speech_feat = batch["speech_feat"].to(device) 88 | speech_feat_len = batch["speech_feat_len"].to(device) 89 | utt_embedding = batch["utt_embedding"].to(device) 90 | spk_embedding = batch["spk_embedding"].to(device) 91 | if args.mode == 'sft': 92 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 93 | 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding} 94 | else: 95 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 96 | 
'prompt_text': text_token, 'prompt_text_len': text_token_len, 97 | 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len, 98 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len, 99 | 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len, 100 | 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding} 101 | tts_speeches = [] 102 | for model_output in model.inference(**model_input): 103 | tts_speeches.append(model_output['tts_speech']) 104 | tts_speeches = torch.concat(tts_speeches, dim=1) 105 | tts_key = '{}_{}'.format(utts[0], tts_index[0]) 106 | tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key)) 107 | torchaudio.save(tts_fn, tts_speeches, sample_rate=22050) 108 | f.write('{} {}\n'.format(tts_key, tts_fn)) 109 | f.flush() 110 | f.close() 111 | logging.info('Result wav.scp saved in {}'.format(fn)) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /cosyvoice/bin/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
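# Illustrative launch (a sketch; the paths are placeholders and the torchrun command itself is an
# assumption — the script only requires some torch.distributed-compatible launcher):
#
#     torchrun --nnodes=1 --nproc_per_node=1 cosyvoice/bin/train.py \
#         --train_engine torch_ddp --model llm \
#         --config conf/cosyvoice.yaml \
#         --train_data data/train.data.list --cv_data data/dev.data.list \
#         --model_dir exp/cosyvoice/llm --tensorboard_dir tensorboard/llm
#
# --model picks which of the 'llm' / 'flow' / 'hift' config sections is trained; the other two are
# overridden to None before the hyperpyyaml config is loaded.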
14 | 15 | from __future__ import print_function 16 | import argparse 17 | import datetime 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | from copy import deepcopy 21 | import torch 22 | import torch.distributed as dist 23 | import deepspeed 24 | 25 | from hyperpyyaml import load_hyperpyyaml 26 | 27 | from torch.distributed.elastic.multiprocessing.errors import record 28 | 29 | from cosyvoice.utils.executor import Executor 30 | from cosyvoice.utils.train_utils import ( 31 | init_distributed, 32 | init_dataset_and_dataloader, 33 | init_optimizer_and_scheduler, 34 | init_summarywriter, save_model, 35 | wrap_cuda_model, check_modify_and_save_config) 36 | 37 | 38 | def get_args(): 39 | parser = argparse.ArgumentParser(description='training your network') 40 | parser.add_argument('--train_engine', 41 | default='torch_ddp', 42 | choices=['torch_ddp', 'deepspeed'], 43 | help='Engine for paralleled training') 44 | parser.add_argument('--model', required=True, help='model which will be trained') 45 | parser.add_argument('--config', required=True, help='config file') 46 | parser.add_argument('--train_data', required=True, help='train data file') 47 | parser.add_argument('--cv_data', required=True, help='cv data file') 48 | parser.add_argument('--checkpoint', help='checkpoint model') 49 | parser.add_argument('--model_dir', required=True, help='save model dir') 50 | parser.add_argument('--tensorboard_dir', 51 | default='tensorboard', 52 | help='tensorboard log dir') 53 | parser.add_argument('--ddp.dist_backend', 54 | dest='dist_backend', 55 | default='nccl', 56 | choices=['nccl', 'gloo'], 57 | help='distributed backend') 58 | parser.add_argument('--num_workers', 59 | default=0, 60 | type=int, 61 | help='num of subprocess workers for reading') 62 | parser.add_argument('--prefetch', 63 | default=100, 64 | type=int, 65 | help='prefetch number') 66 | parser.add_argument('--pin_memory', 67 | action='store_true', 68 | default=False, 69 | help='Use pinned memory buffers used for reading') 70 | parser.add_argument('--deepspeed.save_states', 71 | dest='save_states', 72 | default='model_only', 73 | choices=['model_only', 'model+optimizer'], 74 | help='save model/optimizer states') 75 | parser.add_argument('--timeout', 76 | default=30, 77 | type=int, 78 | help='timeout (in seconds) of cosyvoice_join.') 79 | parser = deepspeed.add_config_arguments(parser) 80 | args = parser.parse_args() 81 | return args 82 | 83 | 84 | @record 85 | def main(): 86 | args = get_args() 87 | logging.basicConfig(level=logging.DEBUG, 88 | format='%(asctime)s %(levelname)s %(message)s') 89 | 90 | override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model} 91 | with open(args.config, 'r') as f: 92 | configs = load_hyperpyyaml(f, overrides=override_dict) 93 | configs['train_conf'].update(vars(args)) 94 | 95 | # Init env for ddp 96 | init_distributed(args) 97 | 98 | # Get dataset & dataloader 99 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ 100 | init_dataset_and_dataloader(args, configs) 101 | 102 | # Do some sanity checks and save config to arsg.model_dir 103 | configs = check_modify_and_save_config(args, configs) 104 | 105 | # Tensorboard summary 106 | writer = init_summarywriter(args) 107 | 108 | # load checkpoint 109 | model = configs[args.model] 110 | if args.checkpoint is not None: 111 | model.load_state_dict(torch.load(args.checkpoint, map_location='cpu')) 112 | 113 | # Dispatch model from cpu to gpu 114 | model = wrap_cuda_model(args, model) 115 | 116 | 
# Get optimizer & scheduler 117 | model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model) 118 | 119 | # Save init checkpoints 120 | info_dict = deepcopy(configs['train_conf']) 121 | save_model(model, 'init', info_dict) 122 | 123 | # Get executor 124 | executor = Executor() 125 | 126 | # Start training loop 127 | for epoch in range(info_dict['max_epoch']): 128 | executor.epoch = epoch 129 | train_dataset.set_epoch(epoch) 130 | dist.barrier() 131 | group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) 132 | executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join) 133 | dist.destroy_process_group(group_join) 134 | 135 | 136 | if __name__ == '__main__': 137 | main() 138 | -------------------------------------------------------------------------------- /cosyvoice/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/cli/__init__.py -------------------------------------------------------------------------------- /cosyvoice/cli/cosyvoice.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: SpenserCai 3 | Date: 2024-10-04 11:30:15 4 | version: 5 | LastEditors: SpenserCai 6 | LastEditTime: 2024-10-04 14:31:23 7 | Description: file content 8 | ''' 9 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 10 | # 11 | # Licensed under the Apache License, Version 2.0 (the "License"); 12 | # you may not use this file except in compliance with the License. 13 | # You may obtain a copy of the License at 14 | # 15 | # http://www.apache.org/licenses/LICENSE-2.0 16 | # 17 | # Unless required by applicable law or agreed to in writing, software 18 | # distributed under the License is distributed on an "AS IS" BASIS, 19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | # See the License for the specific language governing permissions and 21 | # limitations under the License. 
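# Minimal usage sketch (illustrative, not part of the upstream file; the model directory is whatever
# checkpoint you downloaded, e.g. one of the paths listed in this repo's README):
#
#     import torchaudio
#     cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT', load_jit=False, load_onnx=False)
#     spk = cosyvoice.list_avaliable_spks()[0]
#     for i, out in enumerate(cosyvoice.inference_sft('Hello, this is a CosyVoice test.', spk)):
#         torchaudio.save('sft_{}.wav'.format(i), out['tts_speech'], 22050)
#
# Every inference_* method below is a generator yielding dicts whose 'tts_speech' tensor is a
# 22050 Hz waveform, so audio can be saved or streamed chunk by chunk.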
22 | import os 23 | import time 24 | from tqdm import tqdm 25 | from hyperpyyaml import load_hyperpyyaml 26 | from modelscope import snapshot_download 27 | from cosyvoice.cli.frontend import CosyVoiceFrontEnd 28 | from cosyvoice.cli.model import CosyVoiceModel 29 | from cosyvoice.utils.file_utils import logging 30 | 31 | 32 | class CosyVoice: 33 | 34 | def __init__(self, model_dir, load_jit=True, load_onnx=False): 35 | instruct = True if '-Instruct' in model_dir else False 36 | self.model_dir = model_dir 37 | if not os.path.exists(model_dir): 38 | model_dir = snapshot_download(model_dir) 39 | with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f: 40 | configs = load_hyperpyyaml(f) 41 | self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'], 42 | configs['feat_extractor'], 43 | '{}/campplus.onnx'.format(model_dir), 44 | '{}/speech_tokenizer_v1.onnx'.format(model_dir), 45 | '{}/spk2info.pt'.format(model_dir), 46 | instruct, 47 | configs['allowed_special']) 48 | self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) 49 | self.model.load('{}/llm.pt'.format(model_dir), 50 | '{}/flow.pt'.format(model_dir), 51 | '{}/hift.pt'.format(model_dir)) 52 | if load_jit: 53 | self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir), 54 | '{}/llm.llm.fp16.zip'.format(model_dir), 55 | '{}/flow.encoder.fp32.zip'.format(model_dir)) 56 | if load_onnx: 57 | self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir)) 58 | del configs 59 | 60 | def list_avaliable_spks(self): 61 | spks = list(self.frontend.spk2info.keys()) 62 | return spks 63 | 64 | def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0): 65 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): 66 | model_input = self.frontend.frontend_sft(i, spk_id) 67 | start_time = time.time() 68 | logging.info('synthesis text {}'.format(i)) 69 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed): 70 | speech_len = model_output['tts_speech'].shape[1] / 22050 71 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) 72 | yield model_output 73 | start_time = time.time() 74 | 75 | def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0): 76 | prompt_text = self.frontend.text_normalize(prompt_text, split=False) 77 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): 78 | model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k) 79 | start_time = time.time() 80 | logging.info('synthesis text {}'.format(i)) 81 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed): 82 | speech_len = model_output['tts_speech'].shape[1] / 22050 83 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) 84 | yield model_output 85 | start_time = time.time() 86 | 87 | def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0): 88 | if self.frontend.instruct is True: 89 | raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir)) 90 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): 91 | model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k) 92 | start_time = time.time() 93 | logging.info('synthesis text {}'.format(i)) 94 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed): 95 | speech_len = model_output['tts_speech'].shape[1] / 22050 96 | logging.info('yield 
speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) 97 | yield model_output 98 | start_time = time.time() 99 | 100 | def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0): 101 | if self.frontend.instruct is False: 102 | raise ValueError('{} do not support instruct inference'.format(self.model_dir)) 103 | instruct_text = self.frontend.text_normalize(instruct_text, split=False) 104 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): 105 | model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text) 106 | start_time = time.time() 107 | logging.info('synthesis text {}'.format(i)) 108 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed): 109 | speech_len = model_output['tts_speech'].shape[1] / 22050 110 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) 111 | yield model_output 112 | start_time = time.time() 113 | 114 | def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0): 115 | model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k) 116 | start_time = time.time() 117 | for model_output in self.model.vc(**model_input, stream=stream, speed=speed): 118 | speech_len = model_output['tts_speech'].shape[1] / 22050 119 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) 120 | yield model_output 121 | start_time = time.time() 122 | -------------------------------------------------------------------------------- /cosyvoice/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/dataset/__init__.py -------------------------------------------------------------------------------- /cosyvoice/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import json 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch.utils.data import IterableDataset 24 | from cosyvoice.utils.file_utils import read_lists, read_json_lists 25 | 26 | 27 | class Processor(IterableDataset): 28 | 29 | def __init__(self, source, f, *args, **kw): 30 | assert callable(f) 31 | self.source = source 32 | self.f = f 33 | self.args = args 34 | self.kw = kw 35 | 36 | def set_epoch(self, epoch): 37 | self.source.set_epoch(epoch) 38 | 39 | def __iter__(self): 40 | """ Return an iterator over the source dataset processed by the 41 | given processor. 
42 | """ 43 | assert self.source is not None 44 | assert callable(self.f) 45 | return self.f(iter(self.source), *self.args, **self.kw) 46 | 47 | def apply(self, f): 48 | assert callable(f) 49 | return Processor(self, f, *self.args, **self.kw) 50 | 51 | 52 | class DistributedSampler: 53 | 54 | def __init__(self, shuffle=True, partition=True): 55 | self.epoch = -1 56 | self.update() 57 | self.shuffle = shuffle 58 | self.partition = partition 59 | 60 | def update(self): 61 | assert dist.is_available() 62 | if dist.is_initialized(): 63 | self.rank = dist.get_rank() 64 | self.world_size = dist.get_world_size() 65 | else: 66 | self.rank = 0 67 | self.world_size = 1 68 | worker_info = torch.utils.data.get_worker_info() 69 | if worker_info is None: 70 | self.worker_id = 0 71 | self.num_workers = 1 72 | else: 73 | self.worker_id = worker_info.id 74 | self.num_workers = worker_info.num_workers 75 | return dict(rank=self.rank, 76 | world_size=self.world_size, 77 | worker_id=self.worker_id, 78 | num_workers=self.num_workers) 79 | 80 | def set_epoch(self, epoch): 81 | self.epoch = epoch 82 | 83 | def sample(self, data): 84 | """ Sample data according to rank/world_size/num_workers 85 | 86 | Args: 87 | data(List): input data list 88 | 89 | Returns: 90 | List: data list after sample 91 | """ 92 | data = list(range(len(data))) 93 | # force datalist even 94 | if self.partition: 95 | if self.shuffle: 96 | random.Random(self.epoch).shuffle(data) 97 | if len(data) < self.world_size: 98 | data = data * math.ceil(self.world_size / len(data)) 99 | data = data[:self.world_size] 100 | data = data[self.rank::self.world_size] 101 | if len(data) < self.num_workers: 102 | data = data * math.ceil(self.num_workers / len(data)) 103 | data = data[:self.num_workers] 104 | data = data[self.worker_id::self.num_workers] 105 | return data 106 | 107 | 108 | class DataList(IterableDataset): 109 | 110 | def __init__(self, lists, shuffle=True, partition=True): 111 | self.lists = lists 112 | self.sampler = DistributedSampler(shuffle, partition) 113 | 114 | def set_epoch(self, epoch): 115 | self.sampler.set_epoch(epoch) 116 | 117 | def __iter__(self): 118 | sampler_info = self.sampler.update() 119 | indexes = self.sampler.sample(self.lists) 120 | for index in indexes: 121 | data = dict(src=self.lists[index]) 122 | data.update(sampler_info) 123 | yield data 124 | 125 | 126 | def Dataset(data_list_file, 127 | data_pipeline, 128 | mode='train', 129 | shuffle=True, 130 | partition=True, 131 | tts_file='', 132 | prompt_utt2data=''): 133 | """ Construct dataset from arguments 134 | 135 | We have two shuffle stage in the Dataset. The first is global 136 | shuffle at shards tar/raw file level. The second is global shuffle 137 | at training samples level. 
138 | 139 | Args: 140 | data_type(str): raw/shard 141 | tokenizer (BaseTokenizer): tokenizer to tokenize 142 | partition(bool): whether to do data partition in terms of rank 143 | """ 144 | assert mode in ['train', 'inference'] 145 | lists = read_lists(data_list_file) 146 | if mode == 'inference': 147 | with open(tts_file) as f: 148 | tts_data = json.load(f) 149 | utt2lists = read_json_lists(prompt_utt2data) 150 | # filter unnecessary file in inference mode 151 | lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists}) 152 | dataset = DataList(lists, 153 | shuffle=shuffle, 154 | partition=partition) 155 | if mode == 'inference': 156 | # map partial arg tts_data in inference mode 157 | data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data) 158 | for func in data_pipeline: 159 | dataset = Processor(dataset, func, mode=mode) 160 | return dataset 161 | -------------------------------------------------------------------------------- /cosyvoice/flow/flow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import logging 15 | import random 16 | from typing import Dict, Optional 17 | import torch 18 | import torch.nn as nn 19 | from torch.nn import functional as F 20 | from omegaconf import DictConfig 21 | from cosyvoice.utils.mask import make_pad_mask 22 | 23 | 24 | class MaskedDiffWithXvec(torch.nn.Module): 25 | def __init__(self, 26 | input_size: int = 512, 27 | output_size: int = 80, 28 | spk_embed_dim: int = 192, 29 | output_type: str = "mel", 30 | vocab_size: int = 4096, 31 | input_frame_rate: int = 50, 32 | only_mask_loss: bool = True, 33 | encoder: torch.nn.Module = None, 34 | length_regulator: torch.nn.Module = None, 35 | decoder: torch.nn.Module = None, 36 | decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 37 | 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 38 | 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 39 | 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 40 | 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}, 41 | mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 42 | 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}): 43 | super().__init__() 44 | self.input_size = input_size 45 | self.output_size = output_size 46 | self.decoder_conf = decoder_conf 47 | self.mel_feat_conf = mel_feat_conf 48 | self.vocab_size = vocab_size 49 | self.output_type = output_type 50 | self.input_frame_rate = input_frame_rate 51 | logging.info(f"input frame rate={self.input_frame_rate}") 52 | self.input_embedding = nn.Embedding(vocab_size, input_size) 53 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size) 54 | self.encoder = encoder 55 | 
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size) 56 | self.decoder = decoder 57 | self.length_regulator = length_regulator 58 | self.only_mask_loss = only_mask_loss 59 | 60 | def forward( 61 | self, 62 | batch: dict, 63 | device: torch.device, 64 | ) -> Dict[str, Optional[torch.Tensor]]: 65 | token = batch['speech_token'].to(device) 66 | token_len = batch['speech_token_len'].to(device) 67 | feat = batch['speech_feat'].to(device) 68 | feat_len = batch['speech_feat_len'].to(device) 69 | embedding = batch['embedding'].to(device) 70 | 71 | # xvec projection 72 | embedding = F.normalize(embedding, dim=1) 73 | embedding = self.spk_embed_affine_layer(embedding) 74 | 75 | # concat text and prompt_text 76 | mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device) 77 | token = self.input_embedding(torch.clamp(token, min=0)) * mask 78 | 79 | # text encode 80 | h, h_lengths = self.encoder(token, token_len) 81 | h = self.encoder_proj(h) 82 | h, h_lengths = self.length_regulator(h, feat_len) 83 | 84 | # get conditions 85 | conds = torch.zeros(feat.shape, device=token.device) 86 | for i, j in enumerate(feat_len): 87 | if random.random() < 0.5: 88 | continue 89 | index = random.randint(0, int(0.3 * j)) 90 | conds[i, :index] = feat[i, :index] 91 | conds = conds.transpose(1, 2) 92 | 93 | mask = (~make_pad_mask(feat_len)).to(h) 94 | feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1) 95 | loss, _ = self.decoder.compute_loss( 96 | feat.transpose(1, 2).contiguous(), 97 | mask.unsqueeze(1), 98 | h.transpose(1, 2).contiguous(), 99 | embedding, 100 | cond=conds 101 | ) 102 | return {'loss': loss} 103 | 104 | @torch.inference_mode() 105 | def inference(self, 106 | token, 107 | token_len, 108 | prompt_token, 109 | prompt_token_len, 110 | prompt_feat, 111 | prompt_feat_len, 112 | embedding): 113 | assert token.shape[0] == 1 114 | # xvec projection 115 | embedding = F.normalize(embedding, dim=1) 116 | embedding = self.spk_embed_affine_layer(embedding) 117 | 118 | # concat text and prompt_text 119 | token_len1, token_len2 = prompt_token.shape[1], token.shape[1] 120 | token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len 121 | mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding) 122 | token = self.input_embedding(torch.clamp(token, min=0)) * mask 123 | 124 | # text encode 125 | h, h_lengths = self.encoder(token, token_len) 126 | h = self.encoder_proj(h) 127 | mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256) 128 | h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate) 129 | 130 | # get conditions 131 | conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device) 132 | conds[:, :mel_len1] = prompt_feat 133 | conds = conds.transpose(1, 2) 134 | 135 | mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h) 136 | feat = self.decoder( 137 | mu=h.transpose(1, 2).contiguous(), 138 | mask=mask.unsqueeze(1), 139 | spks=embedding, 140 | cond=conds, 141 | n_timesteps=10 142 | ) 143 | feat = feat[:, :, mel_len1:] 144 | assert feat.shape[2] == mel_len2 145 | return feat 146 | -------------------------------------------------------------------------------- /cosyvoice/flow/flow_matching.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # 
Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import torch.nn.functional as F 16 | from matcha.models.components.flow_matching import BASECFM 17 | 18 | 19 | class ConditionalCFM(BASECFM): 20 | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None): 21 | super().__init__( 22 | n_feats=in_channels, 23 | cfm_params=cfm_params, 24 | n_spks=n_spks, 25 | spk_emb_dim=spk_emb_dim, 26 | ) 27 | self.t_scheduler = cfm_params.t_scheduler 28 | self.training_cfg_rate = cfm_params.training_cfg_rate 29 | self.inference_cfg_rate = cfm_params.inference_cfg_rate 30 | in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0) 31 | # Just change the architecture of the estimator here 32 | self.estimator = estimator 33 | 34 | @torch.inference_mode() 35 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 36 | """Forward diffusion 37 | 38 | Args: 39 | mu (torch.Tensor): output of encoder 40 | shape: (batch_size, n_feats, mel_timesteps) 41 | mask (torch.Tensor): output_mask 42 | shape: (batch_size, 1, mel_timesteps) 43 | n_timesteps (int): number of diffusion steps 44 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 45 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 46 | shape: (batch_size, spk_emb_dim) 47 | cond: Not used but kept for future purposes 48 | 49 | Returns: 50 | sample: generated mel-spectrogram 51 | shape: (batch_size, n_feats, mel_timesteps) 52 | """ 53 | z = torch.randn_like(mu) * temperature 54 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype) 55 | if self.t_scheduler == 'cosine': 56 | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi) 57 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 58 | 59 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 60 | """ 61 | Fixed euler solver for ODEs. 62 | Args: 63 | x (torch.Tensor): random noise 64 | t_span (torch.Tensor): n_timesteps interpolated 65 | shape: (n_timesteps + 1,) 66 | mu (torch.Tensor): output of encoder 67 | shape: (batch_size, n_feats, mel_timesteps) 68 | mask (torch.Tensor): output_mask 69 | shape: (batch_size, 1, mel_timesteps) 70 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
71 | shape: (batch_size, spk_emb_dim) 72 | cond: Not used but kept for future purposes 73 | """ 74 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 75 | t = t.unsqueeze(dim=0) 76 | 77 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 78 | # Or in future might add like a return_all_steps flag 79 | sol = [] 80 | 81 | for step in range(1, len(t_span)): 82 | dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond) 83 | # Classifier-Free Guidance inference introduced in VoiceBox 84 | if self.inference_cfg_rate > 0: 85 | cfg_dphi_dt = self.forward_estimator( 86 | x, mask, 87 | torch.zeros_like(mu), t, 88 | torch.zeros_like(spks) if spks is not None else None, 89 | torch.zeros_like(cond) 90 | ) 91 | dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - 92 | self.inference_cfg_rate * cfg_dphi_dt) 93 | x = x + dt * dphi_dt 94 | t = t + dt 95 | sol.append(x) 96 | if step < len(t_span) - 1: 97 | dt = t_span[step + 1] - t 98 | 99 | return sol[-1] 100 | 101 | def forward_estimator(self, x, mask, mu, t, spks, cond): 102 | if isinstance(self.estimator, torch.nn.Module): 103 | return self.estimator.forward(x, mask, mu, t, spks, cond) 104 | else: 105 | ort_inputs = { 106 | 'x': x.cpu().numpy(), 107 | 'mask': mask.cpu().numpy(), 108 | 'mu': mu.cpu().numpy(), 109 | 't': t.cpu().numpy(), 110 | 'spks': spks.cpu().numpy(), 111 | 'cond': cond.cpu().numpy() 112 | } 113 | output = self.estimator.run(None, ort_inputs)[0] 114 | return torch.tensor(output, dtype=x.dtype, device=x.device) 115 | 116 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 117 | """Computes diffusion loss 118 | 119 | Args: 120 | x1 (torch.Tensor): Target 121 | shape: (batch_size, n_feats, mel_timesteps) 122 | mask (torch.Tensor): target mask 123 | shape: (batch_size, 1, mel_timesteps) 124 | mu (torch.Tensor): output of encoder 125 | shape: (batch_size, n_feats, mel_timesteps) 126 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 127 | shape: (batch_size, spk_emb_dim) 128 | 129 | Returns: 130 | loss: conditional flow matching loss 131 | y: conditional flow 132 | shape: (batch_size, n_feats, mel_timesteps) 133 | """ 134 | b, _, t = mu.shape 135 | 136 | # random timestep 137 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 138 | if self.t_scheduler == 'cosine': 139 | t = 1 - torch.cos(t * 0.5 * torch.pi) 140 | # sample noise p(x_0) 141 | z = torch.randn_like(x1) 142 | 143 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 144 | u = x1 - (1 - self.sigma_min) * z 145 | 146 | # during training, we randomly drop condition to trade off mode coverage and sample fidelity 147 | if self.training_cfg_rate > 0: 148 | cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate 149 | mu = mu * cfg_mask.view(-1, 1, 1) 150 | spks = spks * cfg_mask.view(-1, 1) 151 | cond = cond * cfg_mask.view(-1, 1, 1) 152 | 153 | pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond) 154 | loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1]) 155 | return loss, y 156 | -------------------------------------------------------------------------------- /cosyvoice/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Tuple 15 | import torch.nn as nn 16 | import torch 17 | from torch.nn import functional as F 18 | from cosyvoice.utils.mask import make_pad_mask 19 | 20 | 21 | class InterpolateRegulator(nn.Module): 22 | def __init__( 23 | self, 24 | channels: int, 25 | sampling_ratios: Tuple, 26 | out_channels: int = None, 27 | groups: int = 1, 28 | ): 29 | super().__init__() 30 | self.sampling_ratios = sampling_ratios 31 | out_channels = out_channels or channels 32 | model = nn.ModuleList([]) 33 | if len(sampling_ratios) > 0: 34 | for _ in sampling_ratios: 35 | module = nn.Conv1d(channels, channels, 3, 1, 1) 36 | norm = nn.GroupNorm(groups, channels) 37 | act = nn.Mish() 38 | model.extend([module, norm, act]) 39 | model.append( 40 | nn.Conv1d(channels, out_channels, 1, 1) 41 | ) 42 | self.model = nn.Sequential(*model) 43 | 44 | def forward(self, x, ylens=None): 45 | # x in (B, T, D) 46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') 48 | out = self.model(x).transpose(1, 2).contiguous() 49 | olens = ylens 50 | return out * mask, olens 51 | 52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): 53 | # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel 54 | # x in (B, T, D) 55 | if x2.shape[1] > 40: 56 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 57 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, 58 | mode='linear') 59 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 60 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) 61 | else: 62 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') 63 | if x1.shape[1] != 0: 64 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') 65 | x = torch.concat([x1, x2], dim=2) 66 | else: 67 | x = x2 68 | out = self.model(x).transpose(1, 2).contiguous() 69 | return out, mel_len1 + mel_len2 70 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn.utils import weight_norm 17 | 18 | 19 | class ConvRNNF0Predictor(nn.Module): 20 | def __init__(self, 21 | num_class: int = 1, 22 | in_channels: int = 80, 23 | cond_channels: int = 512 24 | ): 25 | super().__init__() 26 | 27 | self.num_class = num_class 28 | self.condnet = nn.Sequential( 29 | weight_norm( 30 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 31 | ), 32 | nn.ELU(), 33 | weight_norm( 34 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 35 | ), 36 | nn.ELU(), 37 | weight_norm( 38 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 39 | ), 40 | nn.ELU(), 41 | weight_norm( 42 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 43 | ), 44 | nn.ELU(), 45 | weight_norm( 46 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 47 | ), 48 | nn.ELU(), 49 | ) 50 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.condnet(x) 54 | x = x.transpose(1, 2) 55 | return torch.abs(self.classifier(x).squeeze(-1)) 56 | -------------------------------------------------------------------------------- /cosyvoice/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | from functools import lru_cache 4 | from typing import Optional 5 | from whisper.tokenizer import Tokenizer 6 | 7 | import tiktoken 8 | 9 | LANGUAGES = { 10 | "en": "english", 11 | "zh": "chinese", 12 | "de": "german", 13 | "es": "spanish", 14 | "ru": "russian", 15 | "ko": "korean", 16 | "fr": "french", 17 | "ja": "japanese", 18 | "pt": "portuguese", 19 | "tr": "turkish", 20 | "pl": "polish", 21 | "ca": "catalan", 22 | "nl": "dutch", 23 | "ar": "arabic", 24 | "sv": "swedish", 25 | "it": "italian", 26 | "id": "indonesian", 27 | "hi": "hindi", 28 | "fi": "finnish", 29 | "vi": "vietnamese", 30 | "he": "hebrew", 31 | "uk": "ukrainian", 32 | "el": "greek", 33 | "ms": "malay", 34 | "cs": "czech", 35 | "ro": "romanian", 36 | "da": "danish", 37 | "hu": "hungarian", 38 | "ta": "tamil", 39 | "no": "norwegian", 40 | "th": "thai", 41 | "ur": "urdu", 42 | "hr": "croatian", 43 | "bg": "bulgarian", 44 | "lt": "lithuanian", 45 | "la": "latin", 46 | "mi": "maori", 47 | "ml": "malayalam", 48 | "cy": "welsh", 49 | "sk": "slovak", 50 | "te": "telugu", 51 | "fa": "persian", 52 | "lv": "latvian", 53 | "bn": "bengali", 54 | "sr": "serbian", 55 | "az": "azerbaijani", 56 | "sl": "slovenian", 57 | "kn": "kannada", 58 | "et": "estonian", 59 | "mk": "macedonian", 60 | "br": "breton", 61 | "eu": "basque", 62 | "is": "icelandic", 63 | "hy": "armenian", 64 | "ne": "nepali", 65 | "mn": "mongolian", 66 | "bs": "bosnian", 67 | "kk": "kazakh", 68 | "sq": "albanian", 69 | "sw": "swahili", 70 | "gl": "galician", 71 | "mr": "marathi", 72 | "pa": "punjabi", 73 | "si": "sinhala", 74 | "km": "khmer", 75 | "sn": "shona", 76 | "yo": "yoruba", 77 | "so": "somali", 78 | "af": "afrikaans", 79 | "oc": "occitan", 80 | "ka": "georgian", 81 | "be": "belarusian", 82 | "tg": "tajik", 83 | "sd": "sindhi", 84 | "gu": "gujarati", 85 | "am": "amharic", 86 | "yi": "yiddish", 87 | "lo": "lao", 88 | "uz": "uzbek", 89 | "fo": "faroese", 90 | "ht": "haitian creole", 91 | "ps": "pashto", 92 | "tk": "turkmen", 93 | "nn": "nynorsk", 94 | "mt": "maltese", 95 
| "sa": "sanskrit", 96 | "lb": "luxembourgish", 97 | "my": "myanmar", 98 | "bo": "tibetan", 99 | "tl": "tagalog", 100 | "mg": "malagasy", 101 | "as": "assamese", 102 | "tt": "tatar", 103 | "haw": "hawaiian", 104 | "ln": "lingala", 105 | "ha": "hausa", 106 | "ba": "bashkir", 107 | "jw": "javanese", 108 | "su": "sundanese", 109 | "yue": "cantonese", 110 | "minnan": "minnan", 111 | "wuyu": "wuyu", 112 | "dialect": "dialect", 113 | "zh/en": "zh/en", 114 | "en/zh": "en/zh", 115 | } 116 | 117 | # language code lookup by name, with a few language aliases 118 | TO_LANGUAGE_CODE = { 119 | **{language: code for code, language in LANGUAGES.items()}, 120 | "burmese": "my", 121 | "valencian": "ca", 122 | "flemish": "nl", 123 | "haitian": "ht", 124 | "letzeburgesch": "lb", 125 | "pushto": "ps", 126 | "panjabi": "pa", 127 | "moldavian": "ro", 128 | "moldovan": "ro", 129 | "sinhalese": "si", 130 | "castilian": "es", 131 | "mandarin": "zh", 132 | } 133 | 134 | AUDIO_EVENT = { 135 | "ASR": "ASR", 136 | "AED": "AED", 137 | "SER": "SER", 138 | "Speech": "Speech", 139 | "/Speech": "/Speech", 140 | "BGM": "BGM", 141 | "/BGM": "/BGM", 142 | "Laughter": "Laughter", 143 | "/Laughter": "/Laughter", 144 | "Applause": "Applause", 145 | "/Applause": "/Applause", 146 | } 147 | 148 | EMOTION = { 149 | "HAPPY": "HAPPY", 150 | "SAD": "SAD", 151 | "ANGRY": "ANGRY", 152 | "NEUTRAL": "NEUTRAL", 153 | } 154 | 155 | TTS_Vocal_Token = { 156 | "TTS/B": "TTS/B", 157 | "TTS/O": "TTS/O", 158 | "TTS/Q": "TTS/Q", 159 | "TTS/A": "TTS/A", 160 | "TTS/CO": "TTS/CO", 161 | "TTS/CL": "TTS/CL", 162 | "TTS/H": "TTS/H", 163 | **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)} 164 | } 165 | 166 | 167 | @lru_cache(maxsize=None) 168 | def get_encoding(name: str = "gpt2", num_languages: int = 99): 169 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") 170 | ranks = { 171 | base64.b64decode(token): int(rank) 172 | for token, rank in (line.split() for line in open(vocab_path) if line) 173 | } 174 | n_vocab = len(ranks) 175 | special_tokens = {} 176 | 177 | specials = [ 178 | "<|endoftext|>", 179 | "<|startoftranscript|>", 180 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], 181 | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())], 182 | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())], 183 | "<|translate|>", 184 | "<|transcribe|>", 185 | "<|startoflm|>", 186 | "<|startofprev|>", 187 | "<|nospeech|>", 188 | "<|notimestamps|>", 189 | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR 190 | *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS 191 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], 192 | ] 193 | 194 | for token in specials: 195 | special_tokens[token] = n_vocab 196 | n_vocab += 1 197 | 198 | return tiktoken.Encoding( 199 | name=os.path.basename(vocab_path), 200 | explicit_n_vocab=n_vocab, 201 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", 202 | mergeable_ranks=ranks, 203 | special_tokens=special_tokens, 204 | ) 205 | 206 | 207 | @lru_cache(maxsize=None) 208 | def get_tokenizer( 209 | multilingual: bool, 210 | *, 211 | num_languages: int = 99, 212 | language: Optional[str] = None, 213 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 214 | ) -> Tokenizer: 215 | if language is not None: 216 | language = language.lower() 217 | if language not in LANGUAGES: 218 | if language in TO_LANGUAGE_CODE: 219 | language = 
TO_LANGUAGE_CODE[language] 220 | else: 221 | raise ValueError(f"Unsupported language: {language}") 222 | 223 | if multilingual: 224 | encoding_name = "multilingual_zh_ja_yue_char_del" 225 | language = language or "en" 226 | task = task or "transcribe" 227 | else: 228 | encoding_name = "gpt2" 229 | language = None 230 | task = None 231 | 232 | encoding = get_encoding(name=encoding_name, num_languages=num_languages) 233 | 234 | return Tokenizer( 235 | encoding=encoding, num_languages=num_languages, language=language, task=task 236 | ) 237 | -------------------------------------------------------------------------------- /cosyvoice/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/transformer/__init__.py -------------------------------------------------------------------------------- /cosyvoice/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct an Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | ''' 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | ''' 50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 51 | ''' 52 | Initialization. 53 | INPUT: 54 | - in_features: shape of the input 55 | - alpha: trainable parameter 56 | alpha is initialized to 1 by default, higher values = higher-frequency. 57 | alpha will be trained along with the rest of your model. 
58 | ''' 59 | super(Snake, self).__init__() 60 | self.in_features = in_features 61 | 62 | # initialize alpha 63 | self.alpha_logscale = alpha_logscale 64 | if self.alpha_logscale: # log scale alphas initialized to zeros 65 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 66 | else: # linear scale alphas initialized to ones 67 | self.alpha = Parameter(torch.ones(in_features) * alpha) 68 | 69 | self.alpha.requires_grad = alpha_trainable 70 | 71 | self.no_div_by_zero = 0.000000001 72 | 73 | def forward(self, x): 74 | ''' 75 | Forward pass of the function. 76 | Applies the function to the input elementwise. 77 | Snake ∶= x + 1/a * sin^2 (xa) 78 | ''' 79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 80 | if self.alpha_logscale: 81 | alpha = torch.exp(alpha) 82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /cosyvoice/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ConvolutionModule(nn.Module): 25 | """ConvolutionModule in Conformer model.""" 26 | 27 | def __init__(self, 28 | channels: int, 29 | kernel_size: int = 15, 30 | activation: nn.Module = nn.ReLU(), 31 | norm: str = "batch_norm", 32 | causal: bool = False, 33 | bias: bool = True): 34 | """Construct an ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | """ 40 | super().__init__() 41 | 42 | self.pointwise_conv1 = nn.Conv1d( 43 | channels, 44 | 2 * channels, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=bias, 49 | ) 50 | # self.lorder is used to distinguish if it's a causal convolution, 51 | # if self.lorder > 0: it's a causal convolution, the input will be 52 | # padded with self.lorder frames on the left in forward. 
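#
# A quick sketch exercising the Snake activation from activation.py above on a (B, C, T)
# tensor and checking the forward pass against the stated formula x + (1/a) * sin^2(a * x).
# Assumes Snake is importable from that module.
import torch

snake = Snake(in_features=4, alpha=1.0)
x = torch.randn(2, 4, 16)                                  # (B, C, T)
y = snake(x)
a = snake.alpha.view(1, 4, 1)                              # line alpha up with (B, C, T)
manual = x + (1.0 / (a + 1e-9)) * torch.sin(x * a) ** 2    # same epsilon as no_div_by_zero
assert torch.allclose(y, manual)
#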
53 | # else: it's a symmetrical convolution 54 | if causal: 55 | padding = 0 56 | self.lorder = kernel_size - 1 57 | else: 58 | # kernel_size should be an odd number for none causal convolution 59 | assert (kernel_size - 1) % 2 == 0 60 | padding = (kernel_size - 1) // 2 61 | self.lorder = 0 62 | self.depthwise_conv = nn.Conv1d( 63 | channels, 64 | channels, 65 | kernel_size, 66 | stride=1, 67 | padding=padding, 68 | groups=channels, 69 | bias=bias, 70 | ) 71 | 72 | assert norm in ['batch_norm', 'layer_norm'] 73 | if norm == "batch_norm": 74 | self.use_layer_norm = False 75 | self.norm = nn.BatchNorm1d(channels) 76 | else: 77 | self.use_layer_norm = True 78 | self.norm = nn.LayerNorm(channels) 79 | 80 | self.pointwise_conv2 = nn.Conv1d( 81 | channels, 82 | channels, 83 | kernel_size=1, 84 | stride=1, 85 | padding=0, 86 | bias=bias, 87 | ) 88 | self.activation = activation 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 94 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | """Compute convolution module. 97 | Args: 98 | x (torch.Tensor): Input tensor (#batch, time, channels). 99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 100 | (0, 0, 0) means fake mask. 101 | cache (torch.Tensor): left context cache, it is only 102 | used in causal convolution (#batch, channels, cache_t), 103 | (0, 0, 0) meas fake cache. 104 | Returns: 105 | torch.Tensor: Output tensor (#batch, time, channels). 106 | """ 107 | # exchange the temporal dimension and the feature dimension 108 | x = x.transpose(1, 2) # (#batch, channels, time) 109 | 110 | # mask batch padding 111 | if mask_pad.size(2) > 0: # time > 0 112 | x.masked_fill_(~mask_pad, 0.0) 113 | 114 | if self.lorder > 0: 115 | if cache.size(2) == 0: # cache_t == 0 116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 117 | else: 118 | assert cache.size(0) == x.size(0) # equal batch 119 | assert cache.size(1) == x.size(1) # equal channel 120 | x = torch.cat((cache, x), dim=2) 121 | assert (x.size(2) > self.lorder) 122 | new_cache = x[:, :, -self.lorder:] 123 | else: 124 | # It's better we just return None if no cache is required, 125 | # However, for JIT export, here we just fake one tensor instead of 126 | # None. 127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 128 | 129 | # GLU mechanism 130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 132 | 133 | # 1D Depthwise Conv 134 | x = self.depthwise_conv(x) 135 | if self.use_layer_norm: 136 | x = x.transpose(1, 2) 137 | x = self.activation(self.norm(x)) 138 | if self.use_layer_norm: 139 | x = x.transpose(1, 2) 140 | x = self.pointwise_conv2(x) 141 | # mask batch padding 142 | if mask_pad.size(2) > 0: # time > 0 143 | x.masked_fill_(~mask_pad, 0.0) 144 | 145 | return x.transpose(1, 2), new_cache 146 | -------------------------------------------------------------------------------- /cosyvoice/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
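#
# A minimal sketch of the streaming path in ConvolutionModule.forward above: process an
# utterance chunk by chunk, feeding the returned left-context cache back in, and check the
# result matches one-shot processing. Eval mode is assumed so BatchNorm uses running stats;
# the 5-frame chunk size is an arbitrary choice for the sketch.
import torch

conv = ConvolutionModule(channels=8, kernel_size=5, causal=True)
conv.eval()
x = torch.randn(1, 20, 8)                        # (B, T, C)
with torch.no_grad():
    full, _ = conv(x)                            # whole utterance at once
    cache = torch.zeros(0, 0, 0)
    outputs = []
    for chunk in x.split(5, dim=1):              # four 5-frame chunks
        y, cache = conv(chunk, cache=cache)      # cache carries the last lorder input frames
        outputs.append(y)
    streamed = torch.cat(outputs, dim=1)
assert torch.allclose(full, streamed, atol=1e-5)
#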
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Decoder self-attention layer definition.""" 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | from torch import nn 20 | 21 | 22 | class DecoderLayer(nn.Module): 23 | """Single decoder layer module. 24 | 25 | Args: 26 | size (int): Input dimension. 27 | self_attn (torch.nn.Module): Self-attention module instance. 28 | `MultiHeadedAttention` instance can be used as the argument. 29 | src_attn (torch.nn.Module): Inter-attention module instance. 30 | `MultiHeadedAttention` instance can be used as the argument. 31 | If `None` is passed, Inter-attention is not used, such as 32 | CIF, GPT, and other decoder only model. 33 | feed_forward (torch.nn.Module): Feed-forward module instance. 34 | `PositionwiseFeedForward` instance can be used as the argument. 35 | dropout_rate (float): Dropout rate. 36 | normalize_before (bool): 37 | True: use layer_norm before each sub-block. 38 | False: to use layer_norm after each sub-block. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | size: int, 44 | self_attn: nn.Module, 45 | src_attn: Optional[nn.Module], 46 | feed_forward: nn.Module, 47 | dropout_rate: float, 48 | normalize_before: bool = True, 49 | ): 50 | """Construct an DecoderLayer object.""" 51 | super().__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.src_attn = src_attn 55 | self.feed_forward = feed_forward 56 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 57 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 58 | self.norm3 = nn.LayerNorm(size, eps=1e-5) 59 | self.dropout = nn.Dropout(dropout_rate) 60 | self.normalize_before = normalize_before 61 | 62 | def forward( 63 | self, 64 | tgt: torch.Tensor, 65 | tgt_mask: torch.Tensor, 66 | memory: torch.Tensor, 67 | memory_mask: torch.Tensor, 68 | cache: Optional[torch.Tensor] = None 69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 70 | """Compute decoded features. 71 | 72 | Args: 73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 74 | tgt_mask (torch.Tensor): Mask for input tensor 75 | (#batch, maxlen_out). 76 | memory (torch.Tensor): Encoded memory 77 | (#batch, maxlen_in, size). 78 | memory_mask (torch.Tensor): Encoded memory mask 79 | (#batch, maxlen_in). 80 | cache (torch.Tensor): cached tensors. 81 | (#batch, maxlen_out - 1, size). 82 | 83 | Returns: 84 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
88 | 89 | """ 90 | residual = tgt 91 | if self.normalize_before: 92 | tgt = self.norm1(tgt) 93 | 94 | if cache is None: 95 | tgt_q = tgt 96 | tgt_q_mask = tgt_mask 97 | else: 98 | # compute only the last frame query keeping dim: max_time_out -> 1 99 | assert cache.shape == ( 100 | tgt.shape[0], 101 | tgt.shape[1] - 1, 102 | self.size, 103 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 104 | tgt_q = tgt[:, -1:, :] 105 | residual = residual[:, -1:, :] 106 | tgt_q_mask = tgt_mask[:, -1:, :] 107 | 108 | x = residual + self.dropout( 109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) 110 | if not self.normalize_before: 111 | x = self.norm1(x) 112 | 113 | if self.src_attn is not None: 114 | residual = x 115 | if self.normalize_before: 116 | x = self.norm2(x) 117 | x = residual + self.dropout( 118 | self.src_attn(x, memory, memory, memory_mask)[0]) 119 | if not self.normalize_before: 120 | x = self.norm2(x) 121 | 122 | residual = x 123 | if self.normalize_before: 124 | x = self.norm3(x) 125 | x = residual + self.dropout(self.feed_forward(x)) 126 | if not self.normalize_before: 127 | x = self.norm3(x) 128 | 129 | if cache is not None: 130 | x = torch.cat([cache, x], dim=1) 131 | 132 | return x, tgt_mask, memory, memory_mask 133 | -------------------------------------------------------------------------------- /cosyvoice/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothing version CE Loss,some probabilities 33 | are taken from the true label prob (1.0) and are divided 34 | among other labels. 35 | 36 | e.g. 
37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of class 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /cosyvoice/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward are appied on each position of the sequence. 24 | The output dim is same with the input dim. 25 | 26 | Args: 27 | idim (int): Input dimenstion. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 
30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of expert with Positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is same with the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of expert. 68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimenstion. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate. 72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Foward function. 
93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /cosyvoice/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/utils/__init__.py -------------------------------------------------------------------------------- /cosyvoice/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
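#
# A compact sketch of the routing math used by MoEFFNLayer.forward above: top-k gating,
# softmax over the selected logits, and a weighted sum of expert outputs. Plain Linear
# layers stand in for the PositionwiseFeedForward experts so the snippet is self-contained.
import torch

B, L, D, n_expert, k = 2, 3, 4, 4, 2
xs = torch.randn(B, L, D).view(-1, D)            # (B*L, D)
gate = torch.nn.Linear(D, n_expert, bias=False)
experts = torch.nn.ModuleList(torch.nn.Linear(D, D) for _ in range(n_expert))

logits, indices = torch.topk(gate(xs), k)        # both (B*L, k)
weights = torch.softmax(logits, dim=1)           # normalize over the k chosen experts only

output = torch.zeros_like(xs)
for i, expert in enumerate(experts):
    batch_idx, ith = torch.where(indices == i)   # tokens routed to expert i
    output[batch_idx] += weights[batch_idx, ith, None] * expert(xs[batch_idx])
output = output.view(B, L, D)
#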
15 | import torch 16 | 17 | from cosyvoice.transformer.activation import Swish 18 | from cosyvoice.transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from cosyvoice.transformer.embedding import (PositionalEncoding, 27 | RelPositionalEncoding, 28 | WhisperPositionalEncoding, 29 | LearnablePositionalEncoding, 30 | NoPositionalEncoding) 31 | from cosyvoice.transformer.attention import (MultiHeadedAttention, 32 | RelPositionMultiHeadedAttention) 33 | from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding 34 | from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling 35 | 36 | 37 | COSYVOICE_ACTIVATION_CLASSES = { 38 | "hardtanh": torch.nn.Hardtanh, 39 | "tanh": torch.nn.Tanh, 40 | "relu": torch.nn.ReLU, 41 | "selu": torch.nn.SELU, 42 | "swish": getattr(torch.nn, "SiLU", Swish), 43 | "gelu": torch.nn.GELU, 44 | } 45 | 46 | COSYVOICE_SUBSAMPLE_CLASSES = { 47 | "linear": LinearNoSubsampling, 48 | "linear_legacy": LegacyLinearNoSubsampling, 49 | "embed": EmbedinigNoSubsampling, 50 | "conv1d2": Conv1dSubsampling2, 51 | "conv2d": Conv2dSubsampling4, 52 | "conv2d6": Conv2dSubsampling6, 53 | "conv2d8": Conv2dSubsampling8, 54 | 'paraformer_dummy': torch.nn.Identity 55 | } 56 | 57 | COSYVOICE_EMB_CLASSES = { 58 | "embed": PositionalEncoding, 59 | "abs_pos": PositionalEncoding, 60 | "rel_pos": RelPositionalEncoding, 61 | "rel_pos_espnet": EspnetRelPositionalEncoding, 62 | "no_pos": NoPositionalEncoding, 63 | "abs_pos_whisper": WhisperPositionalEncoding, 64 | "embed_learnable_pe": LearnablePositionalEncoding, 65 | } 66 | 67 | COSYVOICE_ATTENTION_CLASSES = { 68 | "selfattn": MultiHeadedAttention, 69 | "rel_selfattn": RelPositionMultiHeadedAttention, 70 | } 71 | -------------------------------------------------------------------------------- /cosyvoice/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """Unility functions for Transformer.""" 17 | 18 | import random 19 | from typing import List 20 | 21 | import numpy as np 22 | import torch 23 | 24 | IGNORE_ID = -1 25 | 26 | 27 | def pad_list(xs: List[torch.Tensor], pad_value: int): 28 | """Perform padding for the list of tensors. 29 | 30 | Args: 31 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 32 | pad_value (float): Value for padding. 33 | 34 | Returns: 35 | Tensor: Padded tensor (B, Tmax, `*`). 
36 | 37 | Examples: 38 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 39 | >>> x 40 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 41 | >>> pad_list(x, 0) 42 | tensor([[1., 1., 1., 1.], 43 | [1., 1., 0., 0.], 44 | [1., 0., 0., 0.]]) 45 | 46 | """ 47 | max_len = max([len(item) for item in xs]) 48 | batchs = len(xs) 49 | ndim = xs[0].ndim 50 | if ndim == 1: 51 | pad_res = torch.zeros(batchs, 52 | max_len, 53 | dtype=xs[0].dtype, 54 | device=xs[0].device) 55 | elif ndim == 2: 56 | pad_res = torch.zeros(batchs, 57 | max_len, 58 | xs[0].shape[1], 59 | dtype=xs[0].dtype, 60 | device=xs[0].device) 61 | elif ndim == 3: 62 | pad_res = torch.zeros(batchs, 63 | max_len, 64 | xs[0].shape[1], 65 | xs[0].shape[2], 66 | dtype=xs[0].dtype, 67 | device=xs[0].device) 68 | else: 69 | raise ValueError(f"Unsupported ndim: {ndim}") 70 | pad_res.fill_(pad_value) 71 | for i in range(batchs): 72 | pad_res[i, :len(xs[i])] = xs[i] 73 | return pad_res 74 | 75 | 76 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, 77 | ignore_label: int) -> torch.Tensor: 78 | """Calculate accuracy. 79 | 80 | Args: 81 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 82 | pad_targets (LongTensor): Target label tensors (B, Lmax). 83 | ignore_label (int): Ignore label id. 84 | 85 | Returns: 86 | torch.Tensor: Accuracy value (0.0 - 1.0). 87 | 88 | """ 89 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), 90 | pad_outputs.size(1)).argmax(2) 91 | mask = pad_targets != ignore_label 92 | numerator = torch.sum( 93 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) 94 | denominator = torch.sum(mask) 95 | return (numerator / denominator).detach() 96 | 97 | 98 | def get_padding(kernel_size, dilation=1): 99 | return int((kernel_size * dilation - dilation) / 2) 100 | 101 | 102 | def init_weights(m, mean=0.0, std=0.01): 103 | classname = m.__class__.__name__ 104 | if classname.find("Conv") != -1: 105 | m.weight.data.normal_(mean, std) 106 | 107 | 108 | # Repetition Aware Sampling in VALL-E 2 109 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): 110 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) 111 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() 112 | if rep_num >= win_size * tau_r: 113 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) 114 | return top_ids 115 | 116 | 117 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): 118 | prob, indices = [], [] 119 | cum_prob = 0.0 120 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) 121 | for i in range(len(sorted_idx)): 122 | # sampling both top-p and numbers. 
123 | if cum_prob < top_p and len(prob) < top_k: 124 | cum_prob += sorted_value[i] 125 | prob.append(sorted_value[i]) 126 | indices.append(sorted_idx[i]) 127 | else: 128 | break 129 | prob = torch.tensor(prob).to(weighted_scores) 130 | indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) 131 | top_ids = indices[prob.multinomial(1, replacement=True)] 132 | return top_ids 133 | 134 | 135 | def random_sampling(weighted_scores, decoded_tokens, sampling): 136 | top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) 137 | return top_ids 138 | 139 | 140 | def fade_in_out(fade_in_mel, fade_out_mel, window): 141 | device = fade_in_mel.device 142 | fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() 143 | mel_overlap_len = int(window.shape[0] / 2) 144 | fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ 145 | fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] 146 | return fade_in_mel.to(device) 147 | 148 | 149 | def set_all_random_seed(seed): 150 | random.seed(seed) 151 | np.random.seed(seed) 152 | torch.manual_seed(seed) 153 | torch.cuda.manual_seed_all(seed) 154 | -------------------------------------------------------------------------------- /cosyvoice/utils/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import logging 17 | from contextlib import nullcontext 18 | import os 19 | 20 | import torch 21 | import torch.distributed as dist 22 | 23 | from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join 24 | 25 | 26 | class Executor: 27 | 28 | def __init__(self): 29 | self.step = 0 30 | self.epoch = 0 31 | self.rank = int(os.environ.get('RANK', 0)) 32 | self.device = torch.device('cuda:{}'.format(self.rank)) 33 | 34 | def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join): 35 | ''' Train one epoch 36 | ''' 37 | 38 | lr = optimizer.param_groups[0]['lr'] 39 | logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) 40 | logging.info('using accumulate grad, new batch size is {} times' 41 | ' larger than before'.format(info_dict['accum_grad'])) 42 | # A context manager to be used in conjunction with an instance of 43 | # torch.nn.parallel.DistributedDataParallel to be able to train 44 | # with uneven inputs across participating processes. 
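#
# The ras_sampling helper in common.py above falls back from nucleus sampling to plain
# random sampling whenever the chosen token already dominates the recent decoding window.
# A toy sketch of that decision rule; the plain multinomial draw below is a simplified
# stand-in for the full nucleus_sampling top-p/top-k filtering.
import torch

win_size, tau_r = 10, 0.1
weighted_scores = torch.randn(50)                # unnormalized scores over 50 speech tokens
decoded_tokens = [7, 7, 3, 7, 9, 7, 7, 2, 7, 7]  # toy recent history

probs = weighted_scores.softmax(dim=0)
top_id = probs.multinomial(1)                    # candidate token

window = torch.tensor(decoded_tokens[-win_size:])
rep_num = (window == top_id).sum().item()        # how often the candidate appears in the window
if rep_num >= win_size * tau_r:                  # too repetitive -> resample from the full distribution
    top_id = probs.multinomial(1, replacement=True)
#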
45 | model.train() 46 | model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext 47 | with model_context(): 48 | for batch_idx, batch_dict in enumerate(train_data_loader): 49 | info_dict["tag"] = "TRAIN" 50 | info_dict["step"] = self.step 51 | info_dict["epoch"] = self.epoch 52 | info_dict["batch_idx"] = batch_idx 53 | if cosyvoice_join(group_join, info_dict): 54 | break 55 | 56 | # Disable gradient synchronizations across DDP processes. 57 | # Within this context, gradients will be accumulated on module 58 | # variables, which will later be synchronized. 59 | if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: 60 | context = model.no_sync 61 | # Used for single gpu training and DDP gradient synchronization 62 | # processes. 63 | else: 64 | context = nullcontext 65 | 66 | with context(): 67 | info_dict = batch_forward(model, batch_dict, info_dict) 68 | info_dict = batch_backward(model, info_dict) 69 | 70 | info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict) 71 | log_per_step(writer, info_dict) 72 | # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save 73 | if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ 74 | (batch_idx + 1) % info_dict["accum_grad"] == 0: 75 | dist.barrier() 76 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) 77 | model.train() 78 | if (batch_idx + 1) % info_dict["accum_grad"] == 0: 79 | self.step += 1 80 | dist.barrier() 81 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) 82 | 83 | @torch.inference_mode() 84 | def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True): 85 | ''' Cross validation on 86 | ''' 87 | logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) 88 | model.eval() 89 | total_num_utts, total_loss_dict = 0, {} # avoid division by 0 90 | for batch_idx, batch_dict in enumerate(cv_data_loader): 91 | info_dict["tag"] = "CV" 92 | info_dict["step"] = self.step 93 | info_dict["epoch"] = self.epoch 94 | info_dict["batch_idx"] = batch_idx 95 | 96 | num_utts = len(batch_dict["utts"]) 97 | total_num_utts += num_utts 98 | 99 | info_dict = batch_forward(model, batch_dict, info_dict) 100 | 101 | for k, v in info_dict['loss_dict'].items(): 102 | if k not in total_loss_dict: 103 | total_loss_dict[k] = [] 104 | total_loss_dict[k].append(v.item() * num_utts) 105 | log_per_step(None, info_dict) 106 | for k, v in total_loss_dict.items(): 107 | total_loss_dict[k] = sum(v) / total_num_utts 108 | info_dict['loss_dict'] = total_loss_dict 109 | log_per_save(writer, info_dict) 110 | model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) 111 | save_model(model, model_name, info_dict) 112 | -------------------------------------------------------------------------------- /cosyvoice/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
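#
# train_one_epoc above only steps the optimizer on every accum_grad-th micro-batch and,
# under torch_ddp, wraps the other micro-batches in model.no_sync to skip the gradient
# all-reduce. A single-process sketch of that accumulation pattern, with nullcontext
# standing in for no_sync.
import torch
from contextlib import nullcontext

accum_grad = 4
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

for batch_idx in range(16):
    x, target = torch.randn(2, 8), torch.randn(2, 1)
    context = nullcontext                        # would be model.no_sync off the accumulation boundary
    with context():
        loss = torch.nn.functional.mse_loss(model(x), target) / accum_grad
        loss.backward()                          # gradients accumulate across micro-batches
    if (batch_idx + 1) % accum_grad == 0:
        optimizer.step()                         # one real update per accum_grad micro-batches
        optimizer.zero_grad()
#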
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import json 17 | import torchaudio 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | logging.basicConfig(level=logging.DEBUG, 21 | format='%(asctime)s %(levelname)s %(message)s') 22 | 23 | 24 | def read_lists(list_file): 25 | lists = [] 26 | with open(list_file, 'r', encoding='utf8') as fin: 27 | for line in fin: 28 | lists.append(line.strip()) 29 | return lists 30 | 31 | 32 | def read_json_lists(list_file): 33 | lists = read_lists(list_file) 34 | results = {} 35 | for fn in lists: 36 | with open(fn, 'r', encoding='utf8') as fin: 37 | results.update(json.load(fin)) 38 | return results 39 | 40 | 41 | def load_wav(wav, target_sr): 42 | speech, sample_rate = torchaudio.load(wav) 43 | speech = speech.mean(dim=0, keepdim=True) 44 | if sample_rate != target_sr: 45 | assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) 46 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) 47 | return speech 48 | -------------------------------------------------------------------------------- /cosyvoice/utils/frontend_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
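#
# load_wav above downmixes to mono and resamples to the requested rate, asserting that the
# source rate is higher than the target. A stand-alone sketch of the same steps, with a
# synthetic waveform in place of torchaudio.load so it runs without an audio file on disk.
import torch
import torchaudio

waveform, sample_rate = torch.randn(2, 44100), 44100       # pretend result of torchaudio.load(...)
speech = waveform.mean(dim=0, keepdim=True)                 # downmix to mono, keep (1, T)
if sample_rate != 16000:
    speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(speech)
#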
14 | 15 | import re 16 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') 17 | 18 | 19 | # whether contain chinese character 20 | def contains_chinese(text): 21 | return bool(chinese_char_pattern.search(text)) 22 | 23 | 24 | # replace special symbol 25 | def replace_corner_mark(text): 26 | text = text.replace('²', '平方') 27 | text = text.replace('³', '立方') 28 | return text 29 | 30 | 31 | # remove meaningless symbol 32 | def remove_bracket(text): 33 | text = text.replace('(', '').replace(')', '') 34 | text = text.replace('【', '').replace('】', '') 35 | text = text.replace('`', '').replace('`', '') 36 | text = text.replace("——", " ") 37 | return text 38 | 39 | 40 | # spell Arabic numerals 41 | def spell_out_number(text: str, inflect_parser): 42 | new_text = [] 43 | st = None 44 | for i, c in enumerate(text): 45 | if not c.isdigit(): 46 | if st is not None: 47 | num_str = inflect_parser.number_to_words(text[st: i]) 48 | new_text.append(num_str) 49 | st = None 50 | new_text.append(c) 51 | else: 52 | if st is None: 53 | st = i 54 | if st is not None and st < len(text): 55 | num_str = inflect_parser.number_to_words(text[st:]) 56 | new_text.append(num_str) 57 | return ''.join(new_text) 58 | 59 | 60 | # split paragrah logic: 61 | # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len 62 | # 2. cal sentence len according to lang 63 | # 3. split sentence according to puncatation 64 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): 65 | def calc_utt_length(_text: str): 66 | if lang == "zh": 67 | return len(_text) 68 | else: 69 | return len(tokenize(_text)) 70 | 71 | def should_merge(_text: str): 72 | if lang == "zh": 73 | return len(_text) < merge_len 74 | else: 75 | return len(tokenize(_text)) < merge_len 76 | 77 | if lang == "zh": 78 | pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] 79 | else: 80 | pounc = ['.', '?', '!', ';', ':'] 81 | if comma_split: 82 | pounc.extend([',', ',']) 83 | 84 | if text[-1] not in pounc: 85 | if lang == "zh": 86 | text += "。" 87 | else: 88 | text += "." 
89 | 90 | st = 0 91 | utts = [] 92 | for i, c in enumerate(text): 93 | if c in pounc: 94 | if len(text[st: i]) > 0: 95 | utts.append(text[st: i] + c) 96 | if i + 1 < len(text) and text[i + 1] in ['"', '”']: 97 | tmp = utts.pop(-1) 98 | utts.append(tmp + text[i + 1]) 99 | st = i + 2 100 | else: 101 | st = i + 1 102 | 103 | final_utts = [] 104 | cur_utt = "" 105 | for utt in utts: 106 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: 107 | final_utts.append(cur_utt) 108 | cur_utt = "" 109 | cur_utt = cur_utt + utt 110 | if len(cur_utt) > 0: 111 | if should_merge(cur_utt) and len(final_utts) != 0: 112 | final_utts[-1] = final_utts[-1] + cur_utt 113 | else: 114 | final_utts.append(cur_utt) 115 | 116 | return final_utts 117 | 118 | 119 | # remove blank between chinese character 120 | def replace_blank(text: str): 121 | out_str = [] 122 | for i, c in enumerate(text): 123 | if c == " ": 124 | if ((text[i + 1].isascii() and text[i + 1] != " ") and 125 | (text[i - 1].isascii() and text[i - 1] != " ")): 126 | out_str.append(c) 127 | else: 128 | out_str.append(c) 129 | return "".join(out_str) 130 | -------------------------------------------------------------------------------- /funaudio_utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/funaudio_utils/__init__.py -------------------------------------------------------------------------------- /funaudio_utils/cosyvoice_plus.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: SpenserCai 3 | Date: 2024-10-04 14:21:08 4 | version: 5 | LastEditors: SpenserCai 6 | LastEditTime: 2024-10-04 16:07:20 7 | Description: file content 8 | ''' 9 | from cosyvoice.cli.cosyvoice import CosyVoice 10 | from cosyvoice.utils.file_utils import logging 11 | from tqdm import tqdm 12 | import time 13 | 14 | class CosyVoicePlus(CosyVoice): 15 | 16 | def inference_zero_shot_with_spkmodel(self,tts_text, spkmodel,stream=False, speed=1.0): 17 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)): 18 | tts_text_token, tts_text_token_len = self.frontend._extract_text_token(tts_text) 19 | spkmodel["text"] = tts_text_token 20 | spkmodel["text_len"] = tts_text_token_len 21 | start_time = time.time() 22 | logging.info('synthesis text {}'.format(i)) 23 | for model_output in self.model.tts(**spkmodel, stream=stream, speed=speed): 24 | speech_len = model_output['tts_speech'].shape[1] / 22050 25 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len)) 26 | yield model_output 27 | start_time = time.time() -------------------------------------------------------------------------------- /funaudio_utils/download_models.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: SpenserCai 3 | Date: 2024-10-04 13:54:01 4 | version: 5 | LastEditors: SpenserCai 6 | LastEditTime: 2024-10-04 22:26:22 7 | Description: file content 8 | ''' 9 | import modelscope 10 | import os 11 | import folder_paths 12 | from modelscope import snapshot_download 13 | 14 | # Download the model 15 | base_cosyvoice_model_path = os.path.join(folder_paths.models_dir, "CosyVoice") 16 | base_sensevoice_model_path = os.path.join(folder_paths.models_dir, "SenseVoice") 17 | 18 | def download_cosyvoice_300m(is_25hz=False): 19 | model_name = "CosyVoice-300M" 20 | model_id = 
"iic/CosyVoice-300M" 21 | if is_25hz: 22 | model_name = "CosyVoice-300M-25Hz" 23 | model_id = "iic/CosyVoice-300M-25Hz" 24 | model_dir = os.path.join(base_cosyvoice_model_path, model_name) 25 | snapshot_download(model_id=model_id, local_dir=model_dir) 26 | return model_name, model_dir 27 | 28 | def download_cosyvoice_300m_sft(is_25hz=False): 29 | model_name = "CosyVoice-300M-SFT" 30 | model_id = "iic/CosyVoice-300M-SFT" 31 | if is_25hz: 32 | model_name = "CosyVoice-300M-SFT-25Hz" 33 | model_id = "MachineS/CosyVoice-300M-SFT-25Hz" 34 | model_dir = os.path.join(base_cosyvoice_model_path, model_name) 35 | snapshot_download(model_id=model_id, local_dir=model_dir) 36 | return model_name, model_dir 37 | 38 | def download_sensevoice_small(): 39 | model_name = "SenseVoiceSmall" 40 | model_id = "iic/SenseVoiceSmall" 41 | model_dir = os.path.join(base_sensevoice_model_path, model_name) 42 | snapshot_download(model_id=model_id, local_dir=model_dir) 43 | return model_name, model_dir 44 | 45 | def download_cosyvoice_300m_instruct(): 46 | model_name = "CosyVoice-300M-Instruct" 47 | model_id = "iic/CosyVoice-300M-Instruct" 48 | model_dir = os.path.join(base_cosyvoice_model_path, model_name) 49 | snapshot_download(model_id=model_id, local_dir=model_dir) 50 | return model_name, model_dir 51 | 52 | def get_speaker_default_path(): 53 | return os.path.join(base_cosyvoice_model_path, "Speaker") -------------------------------------------------------------------------------- /funaudio_utils/pre.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: SpenserCai 3 | Date: 2024-10-04 13:22:31 4 | version: 5 | LastEditors: SpenserCai 6 | LastEditTime: 2024-10-04 14:09:40 7 | Description: file content 8 | ''' 9 | import librosa 10 | import torch 11 | import torchaudio 12 | 13 | class FunAudioLLMTool: 14 | def __init__(self): 15 | self.max_val = 0.8 16 | self.prompt_sr, self.target_sr = 16000, 22050 17 | 18 | def postprocess(self,speech, top_db=60, hop_length=220, win_length=440): 19 | speech, _ = librosa.effects.trim( 20 | speech, top_db=top_db, 21 | frame_length=win_length, 22 | hop_length=hop_length 23 | ) 24 | if speech.abs().max() > self.max_val: 25 | speech = speech / speech.abs().max() * self.max_val 26 | speech = torch.concat([speech, torch.zeros(1, int(self.target_sr * 0.2))], dim=1) 27 | return speech 28 | 29 | def audio_resample(self, waveform, source_sr): 30 | waveform = waveform.squeeze(0) 31 | speech = waveform.mean(dim=0,keepdim=True) 32 | if source_sr != self.prompt_sr: 33 | speech = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=self.prompt_sr)(speech) 34 | return speech 35 | -------------------------------------------------------------------------------- /matcha/VERSION: -------------------------------------------------------------------------------- 1 | 0.0.5.1 2 | -------------------------------------------------------------------------------- /matcha/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/__init__.py -------------------------------------------------------------------------------- /matcha/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/data/__init__.py 
-------------------------------------------------------------------------------- /matcha/data/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/data/components/__init__.py -------------------------------------------------------------------------------- /matcha/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /matcha/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | ## Pre-requisites 25 | 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | ## Training 33 | 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | 38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
40 | You can change the path by adding the `--checkpoint_path` option. 41 | 42 | Validation loss during training with the V1 generator:
43 | ![validation loss](./validation_loss.png) 44 | 45 | ## Pretrained Model 46 | 47 | You can also use the pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as follows: 50 | 51 | | Folder Name | Generator | Dataset | Fine-Tuned | 52 | | ------------ | --------- | --------- | ------------------------------------------------------ | 53 | | LJ_V1 | V1 | LJSpeech | No | 54 | | LJ_V2 | V2 | LJSpeech | No | 55 | | LJ_V3 | V3 | LJSpeech | No | 56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 59 | | VCTK_V1 | V1 | VCTK | No | 60 | | VCTK_V2 | V2 | VCTK | No | 61 | | VCTK_V3 | V3 | VCTK | No | 62 | | UNIVERSAL_V1 | V1 | Universal | No | 63 | 64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 65 | 66 | ## Fine-Tuning 67 | 68 | 1. Generate mel-spectrograms in NumPy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
69 | The file name of each generated mel-spectrogram should match that of the corresponding audio file, and the extension should be `.npy`.
70 | Example: 71 | ` Audio File : LJ001-0001.wav 72 | Mel-Spectrogram File : LJ001-0001.npy` 73 | 2. Create a `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command. 75 | ``` 76 | python train.py --fine_tuning True --config config_v1.json 77 | ``` 78 | For other command-line options, please refer to the training section. 79 | 80 | ## Inference from wav file 81 | 82 | 1. Make a `test_files` directory and copy wav files into it. 83 | 2. Run the following command. 84 | ` python inference.py --checkpoint_file [generator checkpoint file path]` 85 | Generated wav files are saved in `generated_files` by default.
86 | You can change the path by adding the `--output_dir` option. 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 90 | 1. Make a `test_mel_files` directory and copy the generated mel-spectrogram files into it.
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 93 | 2. Run the following command. 94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` 95 | Generated wav files are saved in `generated_files_from_mel` by default.
96 | You can change the path by adding `--output_dir` option. 97 | 98 | ## Acknowledgements 99 | 100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 102 | -------------------------------------------------------------------------------- /matcha/hifigan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/hifigan/__init__.py -------------------------------------------------------------------------------- /matcha/hifigan/config.py: -------------------------------------------------------------------------------- 1 | v1 = { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0004, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | "upsample_rates": [8, 8, 2, 2], 11 | "upsample_kernel_sizes": [16, 16, 4, 4], 12 | "upsample_initial_channel": 512, 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 15 | "resblock_initial_channel": 256, 16 | "segment_size": 8192, 17 | "num_mels": 80, 18 | "num_freq": 1025, 19 | "n_fft": 1024, 20 | "hop_size": 256, 21 | "win_size": 1024, 22 | "sampling_rate": 22050, 23 | "fmin": 0, 24 | "fmax": 8000, 25 | "fmax_loss": None, 26 | "num_workers": 4, 27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, 28 | } 29 | -------------------------------------------------------------------------------- /matcha/hifigan/denoiser.py: -------------------------------------------------------------------------------- 1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py 2 | 3 | """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" 4 | import torch 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """Removes model bias from audio produced with waveglow""" 9 | 10 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): 11 | super().__init__() 12 | self.filter_length = filter_length 13 | self.hop_length = int(filter_length / n_overlap) 14 | self.win_length = win_length 15 | 16 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device 17 | self.device = device 18 | if mode == "zeros": 19 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) 20 | elif mode == "normal": 21 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) 22 | else: 23 | raise Exception(f"Mode {mode} if not supported") 24 | 25 | def stft_fn(audio, n_fft, hop_length, win_length, window): 26 | spec = torch.stft( 27 | audio, 28 | n_fft=n_fft, 29 | hop_length=hop_length, 30 | win_length=win_length, 31 | window=window, 32 | return_complex=True, 33 | ) 34 | spec = torch.view_as_real(spec) 35 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) 36 | 37 | self.stft = lambda x: stft_fn( 38 | audio=x, 39 | n_fft=self.filter_length, 40 | hop_length=self.hop_length, 41 | win_length=self.win_length, 42 | window=torch.hann_window(self.win_length, device=device), 43 | ) 44 | self.istft = lambda x, y: torch.istft( 45 | torch.complex(x * torch.cos(y), x * torch.sin(y)), 46 | n_fft=self.filter_length, 47 | 
hop_length=self.hop_length, 48 | win_length=self.win_length, 49 | window=torch.hann_window(self.win_length, device=device), 50 | ) 51 | 52 | with torch.no_grad(): 53 | bias_audio = vocoder(mel_input).float().squeeze(0) 54 | bias_spec, _ = self.stft(bias_audio) 55 | 56 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) 57 | 58 | @torch.inference_mode() 59 | def forward(self, audio, strength=0.0005): 60 | audio_spec, audio_angles = self.stft(audio) 61 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength 62 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 63 | audio_denoised = self.istft(audio_spec_denoised, audio_angles) 64 | return audio_denoised 65 | -------------------------------------------------------------------------------- /matcha/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import os 4 | import shutil 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def build_env(config, config_name, path): 14 | t_path = os.path.join(path, config_name) 15 | if config != t_path: 16 | os.makedirs(path, exist_ok=True) 17 | shutil.copyfile(config, os.path.join(path, config_name)) 18 | -------------------------------------------------------------------------------- /matcha/hifigan/meldataset.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import math 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from librosa.filters import mel as librosa_mel_fn 11 | from librosa.util import normalize 12 | from scipy.io.wavfile import read 13 | 14 | MAX_WAV_VALUE = 32768.0 15 | 16 | 17 | def load_wav(full_path): 18 | sampling_rate, data = read(full_path) 19 | return data, sampling_rate 20 | 21 | 22 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 24 | 25 | 26 | def dynamic_range_decompression(x, C=1): 27 | return np.exp(x) / C 28 | 29 | 30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 31 | return torch.log(torch.clamp(x, min=clip_val) * C) 32 | 33 | 34 | def dynamic_range_decompression_torch(x, C=1): 35 | return torch.exp(x) / C 36 | 37 | 38 | def spectral_normalize_torch(magnitudes): 39 | output = dynamic_range_compression_torch(magnitudes) 40 | return output 41 | 42 | 43 | def spectral_de_normalize_torch(magnitudes): 44 | output = dynamic_range_decompression_torch(magnitudes) 45 | return output 46 | 47 | 48 | mel_basis = {} 49 | hann_window = {} 50 | 51 | 52 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 53 | if torch.min(y) < -1.0: 54 | print("min value is ", torch.min(y)) 55 | if torch.max(y) > 1.0: 56 | print("max value is ", torch.max(y)) 57 | 58 | global mel_basis, hann_window # pylint: disable=global-statement 59 | if fmax not in mel_basis: 60 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 61 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 62 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 63 | 64 | y = torch.nn.functional.pad( 65 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 66 | ) 67 | y = y.squeeze(1) 68 | 69 | 
spec = torch.view_as_real( 70 | torch.stft( 71 | y, 72 | n_fft, 73 | hop_length=hop_size, 74 | win_length=win_size, 75 | window=hann_window[str(y.device)], 76 | center=center, 77 | pad_mode="reflect", 78 | normalized=False, 79 | onesided=True, 80 | return_complex=True, 81 | ) 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 85 | 86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 87 | spec = spectral_normalize_torch(spec) 88 | 89 | return spec 90 | 91 | 92 | def get_dataset_filelist(a): 93 | with open(a.input_training_file, encoding="utf-8") as fi: 94 | training_files = [ 95 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 96 | ] 97 | 98 | with open(a.input_validation_file, encoding="utf-8") as fi: 99 | validation_files = [ 100 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 101 | ] 102 | return training_files, validation_files 103 | 104 | 105 | class MelDataset(torch.utils.data.Dataset): 106 | def __init__( 107 | self, 108 | training_files, 109 | segment_size, 110 | n_fft, 111 | num_mels, 112 | hop_size, 113 | win_size, 114 | sampling_rate, 115 | fmin, 116 | fmax, 117 | split=True, 118 | shuffle=True, 119 | n_cache_reuse=1, 120 | device=None, 121 | fmax_loss=None, 122 | fine_tuning=False, 123 | base_mels_path=None, 124 | ): 125 | self.audio_files = training_files 126 | random.seed(1234) 127 | if shuffle: 128 | random.shuffle(self.audio_files) 129 | self.segment_size = segment_size 130 | self.sampling_rate = sampling_rate 131 | self.split = split 132 | self.n_fft = n_fft 133 | self.num_mels = num_mels 134 | self.hop_size = hop_size 135 | self.win_size = win_size 136 | self.fmin = fmin 137 | self.fmax = fmax 138 | self.fmax_loss = fmax_loss 139 | self.cached_wav = None 140 | self.n_cache_reuse = n_cache_reuse 141 | self._cache_ref_count = 0 142 | self.device = device 143 | self.fine_tuning = fine_tuning 144 | self.base_mels_path = base_mels_path 145 | 146 | def __getitem__(self, index): 147 | filename = self.audio_files[index] 148 | if self._cache_ref_count == 0: 149 | audio, sampling_rate = load_wav(filename) 150 | audio = audio / MAX_WAV_VALUE 151 | if not self.fine_tuning: 152 | audio = normalize(audio) * 0.95 153 | self.cached_wav = audio 154 | if sampling_rate != self.sampling_rate: 155 | raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") 156 | self._cache_ref_count = self.n_cache_reuse 157 | else: 158 | audio = self.cached_wav 159 | self._cache_ref_count -= 1 160 | 161 | audio = torch.FloatTensor(audio) 162 | audio = audio.unsqueeze(0) 163 | 164 | if not self.fine_tuning: 165 | if self.split: 166 | if audio.size(1) >= self.segment_size: 167 | max_audio_start = audio.size(1) - self.segment_size 168 | audio_start = random.randint(0, max_audio_start) 169 | audio = audio[:, audio_start : audio_start + self.segment_size] 170 | else: 171 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 172 | 173 | mel = mel_spectrogram( 174 | audio, 175 | self.n_fft, 176 | self.num_mels, 177 | self.sampling_rate, 178 | self.hop_size, 179 | self.win_size, 180 | self.fmin, 181 | self.fmax, 182 | center=False, 183 | ) 184 | else: 185 | mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) 186 | mel = torch.from_numpy(mel) 187 | 188 | if len(mel.shape) < 3: 189 | mel = mel.unsqueeze(0) 190 | 191 | if self.split: 192 | frames_per_seg = 
math.ceil(self.segment_size / self.hop_size) 193 | 194 | if audio.size(1) >= self.segment_size: 195 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 196 | mel = mel[:, :, mel_start : mel_start + frames_per_seg] 197 | audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] 198 | else: 199 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") 200 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 201 | 202 | mel_loss = mel_spectrogram( 203 | audio, 204 | self.n_fft, 205 | self.num_mels, 206 | self.sampling_rate, 207 | self.hop_size, 208 | self.win_size, 209 | self.fmin, 210 | self.fmax_loss, 211 | center=False, 212 | ) 213 | 214 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 215 | 216 | def __len__(self): 217 | return len(self.audio_files) 218 | -------------------------------------------------------------------------------- /matcha/hifigan/xutils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import glob 4 | import os 5 | 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 17 | plt.colorbar(im, ax=ax) 18 | 19 | fig.canvas.draw() 20 | plt.close() 21 | 22 | return fig 23 | 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def apply_weight_norm(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Conv") != -1: 34 | weight_norm(m) 35 | 36 | 37 | def get_padding(kernel_size, dilation=1): 38 | return int((kernel_size * dilation - dilation) / 2) 39 | 40 | 41 | def load_checkpoint(filepath, device): 42 | assert os.path.isfile(filepath) 43 | print(f"Loading '{filepath}'") 44 | checkpoint_dict = torch.load(filepath, map_location=device) 45 | print("Complete.") 46 | return checkpoint_dict 47 | 48 | 49 | def save_checkpoint(filepath, obj): 50 | print(f"Saving checkpoint to {filepath}") 51 | torch.save(obj, filepath) 52 | print("Complete.") 53 | 54 | 55 | def scan_checkpoint(cp_dir, prefix): 56 | pattern = os.path.join(cp_dir, prefix + "????????") 57 | cp_list = glob.glob(pattern) 58 | if len(cp_list) == 0: 59 | return None 60 | return sorted(cp_list)[-1] 61 | -------------------------------------------------------------------------------- /matcha/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/models/__init__.py -------------------------------------------------------------------------------- /matcha/models/baselightningmodule.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a base lightning module that can be used to train a model. 3 | The benefit of this abstraction is that all the logic outside of model definition can be reused for different models. 
4 | """ 5 | import inspect 6 | from abc import ABC 7 | from typing import Any, Dict 8 | 9 | import torch 10 | from lightning import LightningModule 11 | from lightning.pytorch.utilities import grad_norm 12 | 13 | from matcha import utils 14 | from matcha.utils.utils import plot_tensor 15 | 16 | log = utils.get_pylogger(__name__) 17 | 18 | 19 | class BaseLightningClass(LightningModule, ABC): 20 | def update_data_statistics(self, data_statistics): 21 | if data_statistics is None: 22 | data_statistics = { 23 | "mel_mean": 0.0, 24 | "mel_std": 1.0, 25 | } 26 | 27 | self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"])) 28 | self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"])) 29 | 30 | def configure_optimizers(self) -> Any: 31 | optimizer = self.hparams.optimizer(params=self.parameters()) 32 | if self.hparams.scheduler not in (None, {}): 33 | scheduler_args = {} 34 | # Manage last epoch for exponential schedulers 35 | if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters: 36 | if hasattr(self, "ckpt_loaded_epoch"): 37 | current_epoch = self.ckpt_loaded_epoch - 1 38 | else: 39 | current_epoch = -1 40 | 41 | scheduler_args.update({"optimizer": optimizer}) 42 | scheduler = self.hparams.scheduler.scheduler(**scheduler_args) 43 | scheduler.last_epoch = current_epoch 44 | return { 45 | "optimizer": optimizer, 46 | "lr_scheduler": { 47 | "scheduler": scheduler, 48 | "interval": self.hparams.scheduler.lightning_args.interval, 49 | "frequency": self.hparams.scheduler.lightning_args.frequency, 50 | "name": "learning_rate", 51 | }, 52 | } 53 | 54 | return {"optimizer": optimizer} 55 | 56 | def get_losses(self, batch): 57 | x, x_lengths = batch["x"], batch["x_lengths"] 58 | y, y_lengths = batch["y"], batch["y_lengths"] 59 | spks = batch["spks"] 60 | 61 | dur_loss, prior_loss, diff_loss = self( 62 | x=x, 63 | x_lengths=x_lengths, 64 | y=y, 65 | y_lengths=y_lengths, 66 | spks=spks, 67 | out_size=self.out_size, 68 | ) 69 | return { 70 | "dur_loss": dur_loss, 71 | "prior_loss": prior_loss, 72 | "diff_loss": diff_loss, 73 | } 74 | 75 | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: 76 | self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init 77 | 78 | def training_step(self, batch: Any, batch_idx: int): 79 | loss_dict = self.get_losses(batch) 80 | self.log( 81 | "step", 82 | float(self.global_step), 83 | on_step=True, 84 | prog_bar=True, 85 | logger=True, 86 | sync_dist=True, 87 | ) 88 | 89 | self.log( 90 | "sub_loss/train_dur_loss", 91 | loss_dict["dur_loss"], 92 | on_step=True, 93 | on_epoch=True, 94 | logger=True, 95 | sync_dist=True, 96 | ) 97 | self.log( 98 | "sub_loss/train_prior_loss", 99 | loss_dict["prior_loss"], 100 | on_step=True, 101 | on_epoch=True, 102 | logger=True, 103 | sync_dist=True, 104 | ) 105 | self.log( 106 | "sub_loss/train_diff_loss", 107 | loss_dict["diff_loss"], 108 | on_step=True, 109 | on_epoch=True, 110 | logger=True, 111 | sync_dist=True, 112 | ) 113 | 114 | total_loss = sum(loss_dict.values()) 115 | self.log( 116 | "loss/train", 117 | total_loss, 118 | on_step=True, 119 | on_epoch=True, 120 | logger=True, 121 | prog_bar=True, 122 | sync_dist=True, 123 | ) 124 | 125 | return {"loss": total_loss, "log": loss_dict} 126 | 127 | def validation_step(self, batch: Any, batch_idx: int): 128 | loss_dict = self.get_losses(batch) 129 | self.log( 130 | "sub_loss/val_dur_loss", 131 | loss_dict["dur_loss"], 132 | on_step=True, 133 | on_epoch=True, 134 | logger=True, 
135 | sync_dist=True, 136 | ) 137 | self.log( 138 | "sub_loss/val_prior_loss", 139 | loss_dict["prior_loss"], 140 | on_step=True, 141 | on_epoch=True, 142 | logger=True, 143 | sync_dist=True, 144 | ) 145 | self.log( 146 | "sub_loss/val_diff_loss", 147 | loss_dict["diff_loss"], 148 | on_step=True, 149 | on_epoch=True, 150 | logger=True, 151 | sync_dist=True, 152 | ) 153 | 154 | total_loss = sum(loss_dict.values()) 155 | self.log( 156 | "loss/val", 157 | total_loss, 158 | on_step=True, 159 | on_epoch=True, 160 | logger=True, 161 | prog_bar=True, 162 | sync_dist=True, 163 | ) 164 | 165 | return total_loss 166 | 167 | def on_validation_end(self) -> None: 168 | if self.trainer.is_global_zero: 169 | one_batch = next(iter(self.trainer.val_dataloaders)) 170 | if self.current_epoch == 0: 171 | log.debug("Plotting original samples") 172 | for i in range(2): 173 | y = one_batch["y"][i].unsqueeze(0).to(self.device) 174 | self.logger.experiment.add_image( 175 | f"original/{i}", 176 | plot_tensor(y.squeeze().cpu()), 177 | self.current_epoch, 178 | dataformats="HWC", 179 | ) 180 | 181 | log.debug("Synthesising...") 182 | for i in range(2): 183 | x = one_batch["x"][i].unsqueeze(0).to(self.device) 184 | x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device) 185 | spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None 186 | output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks) 187 | y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"] 188 | attn = output["attn"] 189 | self.logger.experiment.add_image( 190 | f"generated_enc/{i}", 191 | plot_tensor(y_enc.squeeze().cpu()), 192 | self.current_epoch, 193 | dataformats="HWC", 194 | ) 195 | self.logger.experiment.add_image( 196 | f"generated_dec/{i}", 197 | plot_tensor(y_dec.squeeze().cpu()), 198 | self.current_epoch, 199 | dataformats="HWC", 200 | ) 201 | self.logger.experiment.add_image( 202 | f"alignment/{i}", 203 | plot_tensor(attn.squeeze().cpu()), 204 | self.current_epoch, 205 | dataformats="HWC", 206 | ) 207 | 208 | def on_before_optimizer_step(self, optimizer): 209 | self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()}) 210 | -------------------------------------------------------------------------------- /matcha/models/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/models/components/__init__.py -------------------------------------------------------------------------------- /matcha/models/components/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from matcha.models.components.decoder import Decoder 7 | from matcha.utils.pylogger import get_pylogger 8 | 9 | log = get_pylogger(__name__) 10 | 11 | 12 | class BASECFM(torch.nn.Module, ABC): 13 | def __init__( 14 | self, 15 | n_feats, 16 | cfm_params, 17 | n_spks=1, 18 | spk_emb_dim=128, 19 | ): 20 | super().__init__() 21 | self.n_feats = n_feats 22 | self.n_spks = n_spks 23 | self.spk_emb_dim = spk_emb_dim 24 | self.solver = cfm_params.solver 25 | if hasattr(cfm_params, "sigma_min"): 26 | self.sigma_min = cfm_params.sigma_min 27 | else: 28 | self.sigma_min = 1e-4 29 | 30 | self.estimator = None 31 | 32 | @torch.inference_mode() 33 | def forward(self, mu, mask, n_timesteps, 
temperature=1.0, spks=None, cond=None): 34 | """Forward diffusion 35 | 36 | Args: 37 | mu (torch.Tensor): output of encoder 38 | shape: (batch_size, n_feats, mel_timesteps) 39 | mask (torch.Tensor): output_mask 40 | shape: (batch_size, 1, mel_timesteps) 41 | n_timesteps (int): number of diffusion steps 42 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 43 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 44 | shape: (batch_size, spk_emb_dim) 45 | cond: Not used but kept for future purposes 46 | 47 | Returns: 48 | sample: generated mel-spectrogram 49 | shape: (batch_size, n_feats, mel_timesteps) 50 | """ 51 | z = torch.randn_like(mu) * temperature 52 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 53 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 54 | 55 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 56 | """ 57 | Fixed euler solver for ODEs. 58 | Args: 59 | x (torch.Tensor): random noise 60 | t_span (torch.Tensor): n_timesteps interpolated 61 | shape: (n_timesteps + 1,) 62 | mu (torch.Tensor): output of encoder 63 | shape: (batch_size, n_feats, mel_timesteps) 64 | mask (torch.Tensor): output_mask 65 | shape: (batch_size, 1, mel_timesteps) 66 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 67 | shape: (batch_size, spk_emb_dim) 68 | cond: Not used but kept for future purposes 69 | """ 70 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 71 | 72 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 73 | # Or in future might add like a return_all_steps flag 74 | sol = [] 75 | 76 | for step in range(1, len(t_span)): 77 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) 78 | 79 | x = x + dt * dphi_dt 80 | t = t + dt 81 | sol.append(x) 82 | if step < len(t_span) - 1: 83 | dt = t_span[step + 1] - t 84 | 85 | return sol[-1] 86 | 87 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 88 | """Computes diffusion loss 89 | 90 | Args: 91 | x1 (torch.Tensor): Target 92 | shape: (batch_size, n_feats, mel_timesteps) 93 | mask (torch.Tensor): target mask 94 | shape: (batch_size, 1, mel_timesteps) 95 | mu (torch.Tensor): output of encoder 96 | shape: (batch_size, n_feats, mel_timesteps) 97 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
98 | shape: (batch_size, spk_emb_dim) 99 | 100 | Returns: 101 | loss: conditional flow matching loss 102 | y: conditional flow 103 | shape: (batch_size, n_feats, mel_timesteps) 104 | """ 105 | b, _, t = mu.shape 106 | 107 | # random timestep 108 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 109 | # sample noise p(x_0) 110 | z = torch.randn_like(x1) 111 | 112 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 113 | u = x1 - (1 - self.sigma_min) * z 114 | 115 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( 116 | torch.sum(mask) * u.shape[1] 117 | ) 118 | return loss, y 119 | 120 | 121 | class CFM(BASECFM): 122 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 123 | super().__init__( 124 | n_feats=in_channels, 125 | cfm_params=cfm_params, 126 | n_spks=n_spks, 127 | spk_emb_dim=spk_emb_dim, 128 | ) 129 | 130 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 131 | # Just change the architecture of the estimator here 132 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 133 | -------------------------------------------------------------------------------- /matcha/onnx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/onnx/__init__.py -------------------------------------------------------------------------------- /matcha/onnx/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from lightning import LightningModule 8 | 9 | from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder 10 | 11 | DEFAULT_OPSET = 15 12 | 13 | SEED = 1234 14 | random.seed(SEED) 15 | np.random.seed(SEED) 16 | torch.manual_seed(SEED) 17 | torch.cuda.manual_seed(SEED) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | class MatchaWithVocoder(LightningModule): 23 | def __init__(self, matcha, vocoder): 24 | super().__init__() 25 | self.matcha = matcha 26 | self.vocoder = vocoder 27 | 28 | def forward(self, x, x_lengths, scales, spks=None): 29 | mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) 30 | wavs = self.vocoder(mel).clamp(-1, 1) 31 | lengths = mel_lengths * 256 32 | return wavs.squeeze(1), lengths 33 | 34 | 35 | def get_exportable_module(matcha, vocoder, n_timesteps): 36 | """ 37 | Return an appropriate `LighteningModule` and output-node names 38 | based on whether the vocoder is embedded in the final graph 39 | """ 40 | 41 | def onnx_forward_func(x, x_lengths, scales, spks=None): 42 | """ 43 | Custom forward function for accepting 44 | scaler parameters as tensors 45 | """ 46 | # Extract scaler parameters from tensors 47 | temperature = scales[0] 48 | length_scale = scales[1] 49 | output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) 50 | return output["mel"], output["mel_lengths"] 51 | 52 | # Monkey-patch Matcha's forward function 53 | matcha.forward = onnx_forward_func 54 | 55 | if vocoder is None: 56 | model, output_names = matcha, ["mel", "mel_lengths"] 57 | else: 58 | model = MatchaWithVocoder(matcha, vocoder) 59 | output_names = ["wav", "wav_lengths"] 60 | return model, output_names 61 | 62 | 63 | def get_inputs(is_multi_speaker): 64 | 
""" 65 | Create dummy inputs for tracing 66 | """ 67 | dummy_input_length = 50 68 | x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) 69 | x_lengths = torch.LongTensor([dummy_input_length]) 70 | 71 | # Scales 72 | temperature = 0.667 73 | length_scale = 1.0 74 | scales = torch.Tensor([temperature, length_scale]) 75 | 76 | model_inputs = [x, x_lengths, scales] 77 | input_names = [ 78 | "x", 79 | "x_lengths", 80 | "scales", 81 | ] 82 | 83 | if is_multi_speaker: 84 | spks = torch.LongTensor([1]) 85 | model_inputs.append(spks) 86 | input_names.append("spks") 87 | 88 | return tuple(model_inputs), input_names 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") 93 | 94 | parser.add_argument( 95 | "checkpoint_path", 96 | type=str, 97 | help="Path to the model checkpoint", 98 | ) 99 | parser.add_argument("output", type=str, help="Path to output `.onnx` file") 100 | parser.add_argument( 101 | "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" 102 | ) 103 | parser.add_argument( 104 | "--vocoder-name", 105 | type=str, 106 | choices=list(VOCODER_URLS.keys()), 107 | default=None, 108 | help="Name of the vocoder to embed in the ONNX graph", 109 | ) 110 | parser.add_argument( 111 | "--vocoder-checkpoint-path", 112 | type=str, 113 | default=None, 114 | help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", 115 | ) 116 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") 117 | 118 | args = parser.parse_args() 119 | 120 | print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") 121 | print(f"Setting n_timesteps to {args.n_timesteps}") 122 | 123 | checkpoint_path = Path(args.checkpoint_path) 124 | matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu") 125 | 126 | if args.vocoder_name or args.vocoder_checkpoint_path: 127 | assert ( 128 | args.vocoder_name and args.vocoder_checkpoint_path 129 | ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." 
130 | vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") 131 | else: 132 | vocoder = None 133 | 134 | is_multi_speaker = matcha.n_spks > 1 135 | 136 | dummy_input, input_names = get_inputs(is_multi_speaker) 137 | model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) 138 | 139 | # Set dynamic shape for inputs/outputs 140 | dynamic_axes = { 141 | "x": {0: "batch_size", 1: "time"}, 142 | "x_lengths": {0: "batch_size"}, 143 | } 144 | 145 | if vocoder is None: 146 | dynamic_axes.update( 147 | { 148 | "mel": {0: "batch_size", 2: "time"}, 149 | "mel_lengths": {0: "batch_size"}, 150 | } 151 | ) 152 | else: 153 | print("Embedding the vocoder in the ONNX graph") 154 | dynamic_axes.update( 155 | { 156 | "wav": {0: "batch_size", 1: "time"}, 157 | "wav_lengths": {0: "batch_size"}, 158 | } 159 | ) 160 | 161 | if is_multi_speaker: 162 | dynamic_axes["spks"] = {0: "batch_size"} 163 | 164 | # Create the output directory (if not exists) 165 | Path(args.output).parent.mkdir(parents=True, exist_ok=True) 166 | 167 | model.to_onnx( 168 | args.output, 169 | dummy_input, 170 | input_names=input_names, 171 | output_names=output_names, 172 | dynamic_axes=dynamic_axes, 173 | opset_version=args.opset, 174 | export_params=True, 175 | do_constant_folding=True, 176 | ) 177 | print(f"[🍵] ONNX model exported to {args.output}") 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /matcha/onnx/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | from time import perf_counter 6 | 7 | import numpy as np 8 | import onnxruntime as ort 9 | import soundfile as sf 10 | import torch 11 | 12 | from matcha.cli import plot_spectrogram_to_numpy, process_text 13 | 14 | 15 | def validate_args(args): 16 | assert ( 17 | args.text or args.file 18 | ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." 
19 | assert args.temperature >= 0, "Sampling temperature cannot be negative" 20 | assert args.speaking_rate >= 0, "Speaking rate must be greater than 0" 21 | return args 22 | 23 | 24 | def write_wavs(model, inputs, output_dir, external_vocoder=None): 25 | if external_vocoder is None: 26 | print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") 27 | t0 = perf_counter() 28 | wavs, wav_lengths = model.run(None, inputs) 29 | infer_secs = perf_counter() - t0 30 | mel_infer_secs = vocoder_infer_secs = None 31 | else: 32 | print("[🍵] Generating mel using Matcha") 33 | mel_t0 = perf_counter() 34 | mels, mel_lengths = model.run(None, inputs) 35 | mel_infer_secs = perf_counter() - mel_t0 36 | print("Generating waveform from mel using external vocoder") 37 | vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} 38 | vocoder_t0 = perf_counter() 39 | wavs = external_vocoder.run(None, vocoder_inputs)[0] 40 | vocoder_infer_secs = perf_counter() - vocoder_t0 41 | wavs = wavs.squeeze(1) 42 | wav_lengths = mel_lengths * 256 43 | infer_secs = mel_infer_secs + vocoder_infer_secs 44 | 45 | output_dir = Path(output_dir) 46 | output_dir.mkdir(parents=True, exist_ok=True) 47 | for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): 48 | output_filename = output_dir.joinpath(f"output_{i + 1}.wav") 49 | audio = wav[:wav_length] 50 | print(f"Writing audio to {output_filename}") 51 | sf.write(output_filename, audio, 22050, "PCM_24") 52 | 53 | wav_secs = wav_lengths.sum() / 22050 54 | print(f"Inference seconds: {infer_secs}") 55 | print(f"Generated wav seconds: {wav_secs}") 56 | rtf = infer_secs / wav_secs 57 | if mel_infer_secs is not None: 58 | mel_rtf = mel_infer_secs / wav_secs 59 | print(f"Matcha RTF: {mel_rtf}") 60 | if vocoder_infer_secs is not None: 61 | vocoder_rtf = vocoder_infer_secs / wav_secs 62 | print(f"Vocoder RTF: {vocoder_rtf}") 63 | print(f"Overall RTF: {rtf}") 64 | 65 | 66 | def write_mels(model, inputs, output_dir): 67 | t0 = perf_counter() 68 | mels, mel_lengths = model.run(None, inputs) 69 | infer_secs = perf_counter() - t0 70 | 71 | output_dir = Path(output_dir) 72 | output_dir.mkdir(parents=True, exist_ok=True) 73 | for i, mel in enumerate(mels): 74 | output_stem = output_dir.joinpath(f"output_{i + 1}") 75 | plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) 76 | np.save(output_stem.with_suffix(".numpy"), mel) 77 | 78 | wav_secs = (mel_lengths * 256).sum() / 22050 79 | print(f"Inference seconds: {infer_secs}") 80 | print(f"Generated wav seconds: {wav_secs}") 81 | rtf = infer_secs / wav_secs 82 | print(f"RTF: {rtf}") 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser( 87 | description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" 88 | ) 89 | parser.add_argument( 90 | "model", 91 | type=str, 92 | help="ONNX model to use", 93 | ) 94 | parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") 95 | parser.add_argument("--text", type=str, default=None, help="Text to synthesize") 96 | parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") 97 | parser.add_argument("--spk", type=int, default=None, help="Speaker ID") 98 | parser.add_argument( 99 | "--temperature", 100 | type=float, 101 | default=0.667, 102 | help="Variance of the x0 noise (default: 0.667)", 103 | ) 104 | parser.add_argument( 105 | "--speaking-rate", 106 | type=float, 107 | default=1.0, 108 | help="change the speaking rate, a higher value means 
slower speaking rate (default: 1.0)", 109 | ) 110 | parser.add_argument("--gpu", action="store_true", help="Use CPU for inference (default: use GPU if available)") 111 | parser.add_argument( 112 | "--output-dir", 113 | type=str, 114 | default=os.getcwd(), 115 | help="Output folder to save results (default: current dir)", 116 | ) 117 | 118 | args = parser.parse_args() 119 | args = validate_args(args) 120 | 121 | if args.gpu: 122 | providers = ["GPUExecutionProvider"] 123 | else: 124 | providers = ["CPUExecutionProvider"] 125 | model = ort.InferenceSession(args.model, providers=providers) 126 | 127 | model_inputs = model.get_inputs() 128 | model_outputs = list(model.get_outputs()) 129 | 130 | if args.text: 131 | text_lines = args.text.splitlines() 132 | else: 133 | with open(args.file, encoding="utf-8") as file: 134 | text_lines = file.read().splitlines() 135 | 136 | processed_lines = [process_text(0, line, "cpu") for line in text_lines] 137 | x = [line["x"].squeeze() for line in processed_lines] 138 | # Pad 139 | x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 140 | x = x.detach().cpu().numpy() 141 | x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) 142 | inputs = { 143 | "x": x, 144 | "x_lengths": x_lengths, 145 | "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), 146 | } 147 | is_multi_speaker = len(model_inputs) == 4 148 | if is_multi_speaker: 149 | if args.spk is None: 150 | args.spk = 0 151 | warn = "[!] Speaker ID not provided! Using speaker ID 0" 152 | warnings.warn(warn, UserWarning) 153 | inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64) 154 | 155 | has_vocoder_embedded = model_outputs[0].name == "wav" 156 | if has_vocoder_embedded: 157 | write_wavs(model, inputs, args.output_dir) 158 | elif args.vocoder: 159 | external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) 160 | write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) 161 | else: 162 | warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory" 163 | warnings.warn(warn, UserWarning) 164 | write_mels(model, inputs, args.output_dir) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /matcha/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from matcha.text import cleaners 3 | from matcha.text.symbols import symbols 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension 8 | 9 | 10 | def text_to_sequence(text, cleaner_names): 11 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
12 | Args: 13 | text: string to convert to a sequence 14 | cleaner_names: names of the cleaner functions to run the text through 15 | Returns: 16 | List of integers corresponding to the symbols in the text 17 | """ 18 | sequence = [] 19 | 20 | clean_text = _clean_text(text, cleaner_names) 21 | for symbol in clean_text: 22 | symbol_id = _symbol_to_id[symbol] 23 | sequence += [symbol_id] 24 | return sequence 25 | 26 | 27 | def cleaned_text_to_sequence(cleaned_text): 28 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 29 | Args: 30 | text: string to convert to a sequence 31 | Returns: 32 | List of integers corresponding to the symbols in the text 33 | """ 34 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 35 | return sequence 36 | 37 | 38 | def sequence_to_text(sequence): 39 | """Converts a sequence of IDs back to a string""" 40 | result = "" 41 | for symbol_id in sequence: 42 | s = _id_to_symbol[symbol_id] 43 | result += s 44 | return result 45 | 46 | 47 | def _clean_text(text, cleaner_names): 48 | for name in cleaner_names: 49 | cleaner = getattr(cleaners, name) 50 | if not cleaner: 51 | raise Exception("Unknown cleaner: %s" % name) 52 | text = cleaner(text) 53 | return text 54 | -------------------------------------------------------------------------------- /matcha/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 12 | """ 13 | 14 | import logging 15 | import re 16 | 17 | import phonemizer 18 | import piper_phonemize 19 | from unidecode import unidecode 20 | 21 | # To avoid excessive logging we set the log level of the phonemizer package to Critical 22 | critical_logger = logging.getLogger("phonemizer") 23 | critical_logger.setLevel(logging.CRITICAL) 24 | 25 | # Intializing the phonemizer globally significantly reduces the speed 26 | # now the phonemizer is not initialising at every call 27 | # Might be less flexible, but it is much-much faster 28 | global_phonemizer = phonemizer.backend.EspeakBackend( 29 | language="en-us", 30 | preserve_punctuation=True, 31 | with_stress=True, 32 | language_switch="remove-flags", 33 | logger=critical_logger, 34 | ) 35 | 36 | 37 | # Regular expression matching whitespace: 38 | _whitespace_re = re.compile(r"\s+") 39 | 40 | # List of (regular expression, replacement) pairs for abbreviations: 41 | _abbreviations = [ 42 | (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) 43 | for x in [ 44 | ("mrs", "misess"), 45 | ("mr", "mister"), 46 | ("dr", "doctor"), 47 | ("st", "saint"), 48 | ("co", "company"), 49 | ("jr", "junior"), 50 | ("maj", "major"), 51 | ("gen", "general"), 52 | ("drs", "doctors"), 53 | ("rev", "reverend"), 54 | ("lt", "lieutenant"), 55 | ("hon", "honorable"), 56 | ("sgt", "sergeant"), 57 | ("capt", "captain"), 58 | ("esq", "esquire"), 59 | ("ltd", "limited"), 60 | ("col", "colonel"), 61 | ("ft", "fort"), 62 | ] 63 | ] 64 | 65 | 66 | def expand_abbreviations(text): 67 | for regex, replacement in _abbreviations: 68 | text = re.sub(regex, replacement, text) 69 | return text 70 | 71 | 72 | def lowercase(text): 73 | return text.lower() 74 | 75 | 76 | def collapse_whitespace(text): 77 | return re.sub(_whitespace_re, " ", text) 78 | 79 | 80 | def convert_to_ascii(text): 81 | return unidecode(text) 82 | 83 | 84 | def basic_cleaners(text): 85 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 86 | text = lowercase(text) 87 | text = collapse_whitespace(text) 88 | return text 89 | 90 | 91 | def transliteration_cleaners(text): 92 | """Pipeline for non-English text that transliterates to ASCII.""" 93 | text = convert_to_ascii(text) 94 | text = lowercase(text) 95 | text = collapse_whitespace(text) 96 | return text 97 | 98 | 99 | def english_cleaners2(text): 100 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 101 | text = convert_to_ascii(text) 102 | text = lowercase(text) 103 | text = expand_abbreviations(text) 104 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] 105 | phonemes = collapse_whitespace(phonemes) 106 | return phonemes 107 | 108 | 109 | def english_cleaners_piper(text): 110 | """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" 111 | text = convert_to_ascii(text) 112 | text = lowercase(text) 113 | text = expand_abbreviations(text) 114 | phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) 115 | phonemes = collapse_whitespace(phonemes) 116 | return phonemes 117 | -------------------------------------------------------------------------------- /matcha/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | import inflect 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}" 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return f"{dollars} {dollar_unit}" 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return f"{cents} {cent_unit}" 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 60 | else: 61 | return _inflect.number_to_words(num, andword="") 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r"\1 pounds", text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /matcha/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Defines the set of symbols used in text input to the model. 
4 | """ 5 | _pad = "_" 6 | _punctuation = ';:,.!?¡¿—…"«»“” ' 7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 8 | _letters_ipa = ( 9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | ) 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") 18 | -------------------------------------------------------------------------------- /matcha/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 7 | from lightning.pytorch.loggers import Logger 8 | from omegaconf import DictConfig 9 | 10 | from matcha import utils 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | 31 | log = utils.get_pylogger(__name__) 32 | 33 | 34 | @utils.task_wrapper 35 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 36 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 37 | training. 38 | 39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 40 | failure. Useful for multiruns, saving info about the crash, etc. 41 | 42 | :param cfg: A DictConfig configuration composed by Hydra. 43 | :return: A tuple with metrics and dict with all instantiated objects. 
44 | """ 45 | # set seed for random number generators in pytorch, numpy and python.random 46 | if cfg.get("seed"): 47 | L.seed_everything(cfg.seed, workers=True) 48 | 49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access 50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 51 | 52 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access 53 | model: LightningModule = hydra.utils.instantiate(cfg.model) 54 | 55 | log.info("Instantiating callbacks...") 56 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 57 | 58 | log.info("Instantiating loggers...") 59 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 60 | 61 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access 62 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 63 | 64 | object_dict = { 65 | "cfg": cfg, 66 | "datamodule": datamodule, 67 | "model": model, 68 | "callbacks": callbacks, 69 | "logger": logger, 70 | "trainer": trainer, 71 | } 72 | 73 | if logger: 74 | log.info("Logging hyperparameters!") 75 | utils.log_hyperparameters(object_dict) 76 | 77 | if cfg.get("train"): 78 | log.info("Starting training!") 79 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 80 | 81 | train_metrics = trainer.callback_metrics 82 | 83 | if cfg.get("test"): 84 | log.info("Starting testing!") 85 | ckpt_path = trainer.checkpoint_callback.best_model_path 86 | if ckpt_path == "": 87 | log.warning("Best ckpt not found! Using current weights for testing...") 88 | ckpt_path = None 89 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 90 | log.info(f"Best ckpt path: {ckpt_path}") 91 | 92 | test_metrics = trainer.callback_metrics 93 | 94 | # merge train and test metrics 95 | metric_dict = {**train_metrics, **test_metrics} 96 | 97 | return metric_dict, object_dict 98 | 99 | 100 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 101 | def main(cfg: DictConfig) -> Optional[float]: 102 | """Main entry point for training. 103 | 104 | :param cfg: DictConfig configuration composed by Hydra. 105 | :return: Optional[float] with optimized metric value. 106 | """ 107 | # apply extra utilities 108 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
109 | utils.extras(cfg) 110 | 111 | # train the model 112 | metric_dict, _ = train(cfg) 113 | 114 | # safely retrieve metric value for hydra-based hyperparameter optimization 115 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 116 | 117 | # return optimized metric 118 | return metric_value 119 | 120 | 121 | if __name__ == "__main__": 122 | main() # pylint: disable=no-value-for-parameter 123 | -------------------------------------------------------------------------------- /matcha/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from matcha.utils.logging_utils import log_hyperparameters 3 | from matcha.utils.pylogger import get_pylogger 4 | from matcha.utils.rich_utils import enforce_tags, print_config_tree 5 | from matcha.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /matcha/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 46 | if torch.min(y) < -1.0: 47 | print("min value is ", torch.min(y)) 48 | if torch.max(y) > 1.0: 49 | print("max value is ", torch.max(y)) 50 | 51 | global mel_basis, hann_window # pylint: disable=global-statement 52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 56 | 57 | y = torch.nn.functional.pad( 58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.view_as_real( 63 | torch.stft( 64 | y, 65 | n_fft, 66 | hop_length=hop_size, 67 | win_length=win_size, 68 | window=hann_window[str(y.device)], 69 | center=center, 70 | pad_mode="reflect", 71 | normalized=False, 72 | onesided=True, 73 | return_complex=True, 74 | ) 75 | ) 76 | 77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 78 | 79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 80 | spec = spectral_normalize_torch(spec) 81 | 82 | return 
spec 83 | -------------------------------------------------------------------------------- /matcha/utils/generate_data_statistics.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it 3 | when needed. 4 | 5 | Parameters from hparam.py will be used 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | from pathlib import Path 12 | 13 | import rootutils 14 | import torch 15 | from hydra import compose, initialize 16 | from omegaconf import open_dict 17 | from tqdm.auto import tqdm 18 | 19 | from matcha.data.text_mel_datamodule import TextMelDataModule 20 | from matcha.utils.logging_utils import pylogger 21 | 22 | log = pylogger.get_pylogger(__name__) 23 | 24 | 25 | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): 26 | """Generate data mean and standard deviation helpful in data normalisation 27 | 28 | Args: 29 | data_loader (torch.utils.data.Dataloader): _description_ 30 | out_channels (int): mel spectrogram channels 31 | """ 32 | total_mel_sum = 0 33 | total_mel_sq_sum = 0 34 | total_mel_len = 0 35 | 36 | for batch in tqdm(data_loader, leave=False): 37 | mels = batch["y"] 38 | mel_lengths = batch["y_lengths"] 39 | 40 | total_mel_len += torch.sum(mel_lengths) 41 | total_mel_sum += torch.sum(mels) 42 | total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) 43 | 44 | data_mean = total_mel_sum / (total_mel_len * out_channels) 45 | data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) 46 | 47 | return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | 53 | parser.add_argument( 54 | "-i", 55 | "--input-config", 56 | type=str, 57 | default="vctk.yaml", 58 | help="The name of the yaml config file under configs/data", 59 | ) 60 | 61 | parser.add_argument( 62 | "-b", 63 | "--batch-size", 64 | type=int, 65 | default="256", 66 | help="Can have increased batch size for faster computation", 67 | ) 68 | 69 | parser.add_argument( 70 | "-f", 71 | "--force", 72 | action="store_true", 73 | default=False, 74 | required=False, 75 | help="force overwrite the file", 76 | ) 77 | args = parser.parse_args() 78 | output_file = Path(args.input_config).with_suffix(".json") 79 | 80 | if os.path.exists(output_file) and not args.force: 81 | print("File already exists. Use -f to force overwrite") 82 | sys.exit(1) 83 | 84 | with initialize(version_base="1.3", config_path="../../configs/data"): 85 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 86 | 87 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 88 | 89 | with open_dict(cfg): 90 | del cfg["hydra"] 91 | del cfg["_target_"] 92 | cfg["data_statistics"] = None 93 | cfg["seed"] = 1234 94 | cfg["batch_size"] = args.batch_size 95 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 96 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 97 | 98 | text_mel_datamodule = TextMelDataModule(**cfg) 99 | text_mel_datamodule.setup() 100 | data_loader = text_mel_datamodule.train_dataloader() 101 | log.info("Dataloader loaded! 
Now computing stats...") 102 | params = compute_data_statistics(data_loader, cfg["n_feats"]) 103 | print(params) 104 | json.dump( 105 | params, 106 | open(output_file, "w"), 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /matcha/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from matcha.utils import pylogger 9 | 10 | log = pylogger.get_pylogger(__name__) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /matcha/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from matcha.utils import pylogger 7 | 8 | log = pylogger.get_pylogger(__name__) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /matcha/utils/model.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def sequence_mask(length, max_length=None): 8 | if max_length is None: 9 | max_length = length.max() 10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 11 | return x.unsqueeze(0) < length.unsqueeze(1) 12 | 13 | 14 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 15 | factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) 16 | length = (length / factor).ceil() * factor 17 | if not torch.onnx.is_in_onnx_export(): 18 | return length.int().item() 19 | else: 20 | return length 21 | 22 | 23 | def convert_pad_shape(pad_shape): 24 | inverted_shape = pad_shape[::-1] 25 | pad_shape = [item for sublist in inverted_shape for item in sublist] 26 | return pad_shape 27 | 28 | 29 | def generate_path(duration, mask): 30 | device = duration.device 31 | 32 | b, t_x, t_y = mask.shape 33 | cum_duration = torch.cumsum(duration, 1) 34 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 35 | 36 | cum_duration_flat = cum_duration.view(b * t_x) 37 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 38 | path = path.view(b, t_x, t_y) 39 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 40 | path = path * mask 41 | return path 42 | 43 | 44 | def duration_loss(logw, logw_, lengths): 45 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 46 | return loss 47 | 48 | 49 | def normalize(data, mu, std): 50 | if not isinstance(mu, (float, int)): 51 | if isinstance(mu, list): 52 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 53 | elif isinstance(mu, torch.Tensor): 54 | mu = mu.to(data.device) 55 | elif isinstance(mu, np.ndarray): 56 | mu = torch.from_numpy(mu).to(data.device) 57 | mu = mu.unsqueeze(-1) 58 | 59 | if not isinstance(std, (float, int)): 60 | if isinstance(std, list): 61 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 62 | elif isinstance(std, torch.Tensor): 63 | std = std.to(data.device) 64 | elif isinstance(std, np.ndarray): 65 | std = torch.from_numpy(std).to(data.device) 66 | std = std.unsqueeze(-1) 67 | 68 | return (data - mu) / std 69 | 70 | 71 | def denormalize(data, mu, std): 72 | if not isinstance(mu, float): 73 | if isinstance(mu, list): 74 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 75 | elif isinstance(mu, torch.Tensor): 76 | mu = 
mu.to(data.device) 77 | elif isinstance(mu, np.ndarray): 78 | mu = torch.from_numpy(mu).to(data.device) 79 | mu = mu.unsqueeze(-1) 80 | 81 | if not isinstance(std, float): 82 | if isinstance(std, list): 83 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 84 | elif isinstance(std, torch.Tensor): 85 | std = std.to(data.device) 86 | elif isinstance(std, np.ndarray): 87 | std = torch.from_numpy(std).to(data.device) 88 | std = std.unsqueeze(-1) 89 | 90 | return data * std + mu 91 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from matcha.utils.monotonic_align.core import maximum_path_c 5 | 6 | 7 | def maximum_path(value, mask): 8 | """Cython optimised version. 9 | value: [b, t_x, t_y] 10 | mask: [b, t_x, t_y] 11 | """ 12 | value = value * mask 13 | device = value.device 14 | dtype = value.dtype 15 | value = value.data.cpu().numpy().astype(np.float32) 16 | path = np.zeros_like(value).astype(np.int32) 17 | mask = mask.data.cpu().numpy() 18 | 19 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 20 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 21 | maximum_path_c(path, value, t_x_max, t_y_max) 22 | return torch.from_numpy(path).to(device=device, dtype=dtype) 23 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cimport cython 4 | cimport numpy as np 5 | 6 | from cython.parallel import prange 7 | 8 | 9 | @cython.boundscheck(False) 10 | @cython.wraparound(False) 11 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 12 | cdef int x 13 | cdef int y 14 | cdef float v_prev 15 | cdef float v_cur 16 | cdef float tmp 17 | cdef int index = t_x - 1 18 | 19 | for y in range(t_y): 20 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 21 | if x == y: 22 | v_cur = max_neg_val 23 | else: 24 | v_cur = value[x, y-1] 25 | if x == 0: 26 | if y == 0: 27 | v_prev = 0. 
28 | else: 29 | v_prev = max_neg_val 30 | else: 31 | v_prev = value[x-1, y-1] 32 | value[x, y] = max(v_cur, v_prev) + value[x, y] 33 | 34 | for y in range(t_y - 1, -1, -1): 35 | path[index, y] = 1 36 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 37 | index = index - 1 38 | 39 | 40 | @cython.boundscheck(False) 41 | @cython.wraparound(False) 42 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 43 | cdef int b = values.shape[0] 44 | 45 | cdef int i 46 | for i in prange(b, nogil=True): 47 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 48 | -------------------------------------------------------------------------------- /matcha/utils/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | # from distutils.core import setup 2 | # from Cython.Build import cythonize 3 | # import numpy 4 | 5 | # setup(name='monotonic_align', 6 | # ext_modules=cythonize("core.pyx"), 7 | # include_dirs=[numpy.get_include()]) 8 | -------------------------------------------------------------------------------- /matcha/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name: str = __name__) -> logging.Logger: 7 | """Initializes a multi-GPU-friendly python command line logger. 8 | 9 | :param name: The name of the logger, defaults to ``__name__``. 10 | 11 | :return: A logger object. 12 | """ 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /matcha/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from matcha.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 
39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | _ = ( 48 | queue.append(field) 49 | if field in cfg 50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") 51 | ) 52 | 53 | # add all the other fields to queue (not specified in `print_order`) 54 | for field in cfg: 55 | if field not in queue: 56 | queue.append(field) 57 | 58 | # generate config tree from queue 59 | for field in queue: 60 | branch = tree.add(field, style=style, guide_style=style) 61 | 62 | config_group = cfg[field] 63 | if isinstance(config_group, DictConfig): 64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 65 | else: 66 | branch_content = str(config_group) 67 | 68 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 69 | 70 | # print config tree 71 | rich.print(tree) 72 | 73 | # save config tree to file 74 | if save_to_file: 75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 76 | rich.print(tree, file=file) 77 | 78 | 79 | @rank_zero_only 80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 81 | """Prompts user to input tags from command line if no tags are provided in config. 82 | 83 | :param cfg: A DictConfig composed by Hydra. 84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 85 | """ 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 101 | rich.print(cfg.tags, file=file) 102 | -------------------------------------------------------------------------------- /matcha/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import warnings 4 | from importlib.util import find_spec 5 | from pathlib import Path 6 | from typing import Any, Callable, Dict, Tuple 7 | 8 | import gdown 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | import wget 13 | from omegaconf import DictConfig 14 | 15 | from matcha.utils import pylogger, rich_utils 16 | 17 | log = pylogger.get_pylogger(__name__) 18 | 19 | 20 | def extras(cfg: DictConfig) -> None: 21 | """Applies optional utilities before the task is started. 22 | 23 | Utilities: 24 | - Ignoring python warnings 25 | - Setting tags from command line 26 | - Rich config printing 27 | 28 | :param cfg: A DictConfig object containing the config tree. 29 | """ 30 | # return if no `extras` config 31 | if not cfg.get("extras"): 32 | log.warning("Extras config not found! ") 33 | return 34 | 35 | # disable python warnings 36 | if cfg.extras.get("ignore_warnings"): 37 | log.info("Disabling python warnings! ") 38 | warnings.filterwarnings("ignore") 39 | 40 | # prompt user to input tags from command line if none are provided in the config 41 | if cfg.extras.get("enforce_tags"): 42 | log.info("Enforcing tags! 
") 43 | rich_utils.enforce_tags(cfg, save_to_file=True) 44 | 45 | # pretty print config tree using Rich library 46 | if cfg.extras.get("print_config"): 47 | log.info("Printing config tree with Rich! ") 48 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) 49 | 50 | 51 | def task_wrapper(task_func: Callable) -> Callable: 52 | """Optional decorator that controls the failure behavior when executing the task function. 53 | 54 | This wrapper can be used to: 55 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) 56 | - save the exception to a `.log` file 57 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later) 58 | - etc. (adjust depending on your needs) 59 | 60 | Example: 61 | ``` 62 | @utils.task_wrapper 63 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 64 | ... 65 | return metric_dict, object_dict 66 | ``` 67 | 68 | :param task_func: The task function to be wrapped. 69 | 70 | :return: The wrapped task function. 71 | """ 72 | 73 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 74 | # execute the task 75 | try: 76 | metric_dict, object_dict = task_func(cfg=cfg) 77 | 78 | # things to do if exception occurs 79 | except Exception as ex: 80 | # save exception to `.log` file 81 | log.exception("") 82 | 83 | # some hyperparameter combinations might be invalid or cause out-of-memory errors 84 | # so when using hparam search plugins like Optuna, you might want to disable 85 | # raising the below exception to avoid multirun failure 86 | raise ex 87 | 88 | # things to always do after either success or exception 89 | finally: 90 | # display output dir path in terminal 91 | log.info(f"Output dir: {cfg.paths.output_dir}") 92 | 93 | # always close wandb run (even if exception occurs so multirun won't fail) 94 | if find_spec("wandb"): # check if wandb is installed 95 | import wandb 96 | 97 | if wandb.run: 98 | log.info("Closing wandb!") 99 | wandb.finish() 100 | 101 | return metric_dict, object_dict 102 | 103 | return wrap 104 | 105 | 106 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float: 107 | """Safely retrieves value of the metric logged in LightningModule. 108 | 109 | :param metric_dict: A dict containing metric values. 110 | :param metric_name: The name of the metric to retrieve. 111 | :return: The value of the metric. 112 | """ 113 | if not metric_name: 114 | log.info("Metric name is None! Skipping metric value retrieval...") 115 | return None 116 | 117 | if metric_name not in metric_dict: 118 | raise ValueError( 119 | f"Metric value not found! \n" 120 | "Make sure metric name logged in LightningModule is correct!\n" 121 | "Make sure `optimized_metric` name in `hparams_search` config is correct!" 122 | ) 123 | 124 | metric_value = metric_dict[metric_name].item() 125 | log.info(f"Retrieved metric value! 
<{metric_name}={metric_value}>") 126 | 127 | return metric_value 128 | 129 | 130 | def intersperse(lst, item): 131 | # Adds blank symbol 132 | result = [item] * (len(lst) * 2 + 1) 133 | result[1::2] = lst 134 | return result 135 | 136 | 137 | def save_figure_to_numpy(fig): 138 | data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="") 139 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 140 | return data 141 | 142 | 143 | def plot_tensor(tensor): 144 | plt.style.use("default") 145 | fig, ax = plt.subplots(figsize=(12, 3)) 146 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 147 | plt.colorbar(im, ax=ax) 148 | plt.tight_layout() 149 | fig.canvas.draw() 150 | data = save_figure_to_numpy(fig) 151 | plt.close() 152 | return data 153 | 154 | 155 | def save_plot(tensor, savepath): 156 | plt.style.use("default") 157 | fig, ax = plt.subplots(figsize=(12, 3)) 158 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none") 159 | plt.colorbar(im, ax=ax) 160 | plt.tight_layout() 161 | fig.canvas.draw() 162 | plt.savefig(savepath) 163 | plt.close() 164 | 165 | 166 | def to_numpy(tensor): 167 | if isinstance(tensor, np.ndarray): 168 | return tensor 169 | elif isinstance(tensor, torch.Tensor): 170 | return tensor.detach().cpu().numpy() 171 | elif isinstance(tensor, list): 172 | return np.array(tensor) 173 | else: 174 | raise TypeError("Unsupported type for conversion to numpy array") 175 | 176 | 177 | def get_user_data_dir(appname="matcha_tts"): 178 | """ 179 | Args: 180 | appname (str): Name of application 181 | 182 | Returns: 183 | Path: path to user data directory 184 | """ 185 | 186 | MATCHA_HOME = os.environ.get("MATCHA_HOME") 187 | if MATCHA_HOME is not None: 188 | ans = Path(MATCHA_HOME).expanduser().resolve(strict=False) 189 | elif sys.platform == "win32": 190 | import winreg # pylint: disable=import-outside-toplevel 191 | 192 | key = winreg.OpenKey( 193 | winreg.HKEY_CURRENT_USER, 194 | r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders", 195 | ) 196 | dir_, _ = winreg.QueryValueEx(key, "Local AppData") 197 | ans = Path(dir_).resolve(strict=False) 198 | elif sys.platform == "darwin": 199 | ans = Path("~/Library/Application Support/").expanduser() 200 | else: 201 | ans = Path.home().joinpath(".local/share") 202 | 203 | final_path = ans.joinpath(appname) 204 | final_path.mkdir(parents=True, exist_ok=True) 205 | return final_path 206 | 207 | 208 | def assert_model_downloaded(checkpoint_path, url, use_wget=True): 209 | if Path(checkpoint_path).exists(): 210 | log.debug(f"[+] Model already present at {checkpoint_path}!") 211 | print(f"[+] Model already present at {checkpoint_path}!") 212 | return 213 | log.info(f"[-] Model not found at {checkpoint_path}! Will download it") 214 | print(f"[-] Model not found at {checkpoint_path}! 
Will download it") 215 | checkpoint_path = str(checkpoint_path) 216 | if not use_wget: 217 | gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True) 218 | else: 219 | wget.download(url=url, out=checkpoint_path) 220 | -------------------------------------------------------------------------------- /nodes/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | now_dir = os.path.dirname(os.path.abspath(__file__)) 4 | node_root = os.path.dirname(now_dir) 5 | sys.path.append(node_root) -------------------------------------------------------------------------------- /nodes/sensevoice_nodes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: SpenserCai 3 | Date: 2024-10-04 12:13:43 4 | version: 5 | LastEditors: SpenserCai 6 | LastEditTime: 2024-10-05 11:03:31 7 | Description: file content 8 | ''' 9 | import folder_paths 10 | import os 11 | import numpy as np 12 | from funasr import AutoModel 13 | from funaudio_utils.pre import FunAudioLLMTool 14 | from funaudio_utils.download_models import download_sensevoice_small 15 | from funasr.utils import postprocess_utils 16 | 17 | fAudioTool = FunAudioLLMTool() 18 | 19 | CATEGORY_NAME = "FunAudioLLM - SenseVoice" 20 | 21 | folder_paths.add_model_folder_path("SenseVoice", os.path.join(folder_paths.models_dir, "SenseVoice")) 22 | 23 | def patch_emoji(emoji_dict): 24 | t_emoji_dict_key = emoji_dict.keys() 25 | emoji_dict_new = {} 26 | for t_e_k in t_emoji_dict_key: 27 | emoji_dict_new[t_e_k.lower()] = emoji_dict[t_e_k] 28 | emoji_dict.update(emoji_dict_new) 29 | return emoji_dict 30 | 31 | class SenseVoiceNode: 32 | @classmethod 33 | def INPUT_TYPES(s): 34 | return { 35 | "required":{ 36 | "audio":("AUDIO",), 37 | "use_fast_mode":("BOOLEAN",{ 38 | "default": False 39 | }), 40 | "punc_segment":("BOOLEAN",{ 41 | "default": False 42 | }), 43 | } 44 | } 45 | 46 | CATEGORY = CATEGORY_NAME 47 | RETURN_TYPES = ("STRING",) 48 | 49 | FUNCTION="generate" 50 | 51 | def generate(self,audio, use_fast_mode,punc_segment): 52 | sensevoice_code_path = os.path.join(folder_paths.base_path,"custom_nodes/ComfyUI-FunAudioLLM/sensevoice/model.py") 53 | speech = audio["waveform"] 54 | source_sr = audio["sample_rate"] 55 | speech = fAudioTool.audio_resample(speech, source_sr) 56 | speech = fAudioTool.postprocess(speech) 57 | # 判断语音长度是否大于30s 58 | if speech.shape[1] > 30 * 22050 and use_fast_mode: 59 | raise ValueError("Audio length is too long, please set use_fast_mode to False.") 60 | _, model_dir = download_sensevoice_small() 61 | model_arg = { 62 | "input":speech[0], 63 | "cache":{}, 64 | "language":"auto", 65 | "batch_size_s":60, 66 | } 67 | model_use_arg = { 68 | "model":model_dir, 69 | "trust_remote_code":True, 70 | "remote_code":sensevoice_code_path, 71 | "device":"cuda:0", 72 | } 73 | 74 | if not use_fast_mode: 75 | model_use_arg["vad_model"] = "fsmn-vad" 76 | model_use_arg["vad_kwargs"] = {"max_single_segment_time":30000} 77 | 78 | model_arg["merge_vad"] = True 79 | model_arg["merge_length_s"] = 15 80 | 81 | if punc_segment: 82 | model_use_arg["punc_model"] = "ct-punc-c" 83 | 84 | model = AutoModel(**model_use_arg) 85 | output = model.generate(**model_arg) 86 | postprocess_utils.emoji_dict = patch_emoji(postprocess_utils.emoji_dict) 87 | return (postprocess_utils.rich_transcription_postprocess(output[0]["text"]),) -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-funaudiollm" 3 | description = "Comfyui custom node for [FunAudioLLM](https://funaudiollm.github.io/) include [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) and [SenseVoice](https://github.com/FunAudioLLM/SenseVoice)." 4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["conformer", "deepspeed; sys_platform == 'linux'", "diffusers", "grpcio", "grpcio-tools", "hydra-core", "HyperPyYAML", "inflect", "librosa", "lightning", "matplotlib", "modelscope", "networkx", "omegaconf", "onnxruntime-gpu; sys_platform == 'linux'", "onnxruntime; sys_platform == 'darwin' or sys_platform == 'win32'", "openai-whisper", "protobuf", "pydantic", "rich", "soundfile", "tensorboard", "wget", "gdown", "pyarrow", "jieba", "pypinyin", "pydub", "audiosegment", "srt", "ffmpeg-python", "WeTextProcessing", "aliyun-python-sdk-core", "funasr>=1.1.3", "huggingface", "huggingface_hub"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/SpenserCai/ComfyUI-FunAudioLLM" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "spensercai" 14 | DisplayName = "ComfyUI-FunAudioLLM" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | conformer 2 | deepspeed; sys_platform == 'linux' 3 | diffusers 4 | grpcio 5 | grpcio-tools 6 | hydra-core 7 | HyperPyYAML 8 | inflect 9 | librosa 10 | lightning 11 | matplotlib 12 | modelscope 13 | networkx 14 | omegaconf 15 | onnxruntime-gpu; sys_platform == 'linux' 16 | onnxruntime; sys_platform == 'darwin' or sys_platform == 'win32' 17 | openai-whisper 18 | protobuf 19 | pydantic 20 | rich 21 | soundfile 22 | tensorboard 23 | wget 24 | gdown 25 | pyarrow 26 | jieba 27 | pypinyin 28 | pydub 29 | audiosegment 30 | srt 31 | ffmpeg-python 32 | WeTextProcessing 33 | funasr>=1.1.3 34 | huggingface 35 | huggingface_hub 36 | -------------------------------------------------------------------------------- /sensevoice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/sensevoice/__init__.py -------------------------------------------------------------------------------- /sensevoice/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/sensevoice/utils/__init__.py -------------------------------------------------------------------------------- /sensevoice/utils/export_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | def export( 6 | model, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs 7 | ): 8 | model_scripts = model.export(**kwargs) 9 | export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param"))) 10 | os.makedirs(export_dir, exist_ok=True) 11 | 12 | if not isinstance(model_scripts, (list, tuple)): 13 | model_scripts = (model_scripts,) 14 | for m in model_scripts: 15 | m.eval() 16 | if type == "onnx": 17 | _onnx( 18 | m, 19 | quantize=quantize, 20 | opset_version=opset_version, 21 | export_dir=export_dir, 22 | **kwargs, 23 | ) 24 | print("output 
dir: {}".format(export_dir)) 25 | 26 | return export_dir 27 | 28 | 29 | def _onnx( 30 | model, 31 | quantize: bool = False, 32 | opset_version: int = 14, 33 | export_dir: str = None, 34 | **kwargs, 35 | ): 36 | 37 | dummy_input = model.export_dummy_inputs() 38 | 39 | verbose = kwargs.get("verbose", False) 40 | 41 | export_name = model.export_name() 42 | model_path = os.path.join(export_dir, export_name) 43 | torch.onnx.export( 44 | model, 45 | dummy_input, 46 | model_path, 47 | verbose=verbose, 48 | opset_version=opset_version, 49 | input_names=model.export_input_names(), 50 | output_names=model.export_output_names(), 51 | dynamic_axes=model.export_dynamic_axes(), 52 | ) 53 | 54 | if quantize: 55 | from onnxruntime.quantization import QuantType, quantize_dynamic 56 | import onnx 57 | 58 | quant_model_path = model_path.replace(".onnx", "_quant.onnx") 59 | if not os.path.exists(quant_model_path): 60 | onnx_model = onnx.load(model_path) 61 | nodes = [n.name for n in onnx_model.graph.node] 62 | nodes_to_exclude = [ 63 | m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m 64 | ] 65 | quantize_dynamic( 66 | model_input=model_path, 67 | model_output=quant_model_path, 68 | op_types_to_quantize=["MatMul"], 69 | per_channel=True, 70 | reduce_range=False, 71 | weight_type=QuantType.QUInt8, 72 | nodes_to_exclude=nodes_to_exclude, 73 | ) 74 | -------------------------------------------------------------------------------- /sensevoice/utils/model_bin.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- encoding: utf-8 -*- 3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. 4 | # MIT License (https://opensource.org/licenses/MIT) 5 | 6 | import os.path 7 | from pathlib import Path 8 | from typing import List, Union, Tuple 9 | import torch 10 | import librosa 11 | import numpy as np 12 | 13 | from utils.infer_utils import ( 14 | CharTokenizer, 15 | Hypothesis, 16 | ONNXRuntimeError, 17 | OrtInferSession, 18 | TokenIDConverter, 19 | get_logger, 20 | read_yaml, 21 | ) 22 | from utils.frontend import WavFrontend 23 | from utils.infer_utils import pad_list 24 | 25 | logging = get_logger() 26 | 27 | 28 | class SenseVoiceSmallONNX: 29 | """ 30 | Author: Speech Lab of DAMO Academy, Alibaba Group 31 | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition 32 | https://arxiv.org/abs/2206.08317 33 | """ 34 | 35 | def __init__( 36 | self, 37 | model_dir: Union[str, Path] = None, 38 | batch_size: int = 1, 39 | device_id: Union[str, int] = "-1", 40 | plot_timestamp_to: str = "", 41 | quantize: bool = False, 42 | intra_op_num_threads: int = 4, 43 | cache_dir: str = None, 44 | **kwargs, 45 | ): 46 | if quantize: 47 | model_file = os.path.join(model_dir, "model_quant.onnx") 48 | else: 49 | model_file = os.path.join(model_dir, "model.onnx") 50 | 51 | config_file = os.path.join(model_dir, "config.yaml") 52 | cmvn_file = os.path.join(model_dir, "am.mvn") 53 | config = read_yaml(config_file) 54 | # token_list = os.path.join(model_dir, "tokens.json") 55 | # with open(token_list, "r", encoding="utf-8") as f: 56 | # token_list = json.load(f) 57 | 58 | # self.converter = TokenIDConverter(token_list) 59 | self.tokenizer = CharTokenizer() 60 | config["frontend_conf"]['cmvn_file'] = cmvn_file 61 | self.frontend = WavFrontend(**config["frontend_conf"]) 62 | self.ort_infer = OrtInferSession( 63 | model_file, device_id, 
intra_op_num_threads=intra_op_num_threads 64 | ) 65 | self.batch_size = batch_size 66 | self.blank_id = 0 67 | 68 | def __call__(self, 69 | wav_content: Union[str, np.ndarray, List[str]], 70 | language: List, 71 | textnorm: List, 72 | tokenizer=None, 73 | **kwargs) -> List: 74 | waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) 75 | waveform_nums = len(waveform_list) 76 | asr_res = [] 77 | for beg_idx in range(0, waveform_nums, self.batch_size): 78 | end_idx = min(waveform_nums, beg_idx + self.batch_size) 79 | feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) 80 | ctc_logits, encoder_out_lens = self.infer(feats, 81 | feats_len, 82 | np.array(language, dtype=np.int32), 83 | np.array(textnorm, dtype=np.int32) 84 | ) 85 | # back to torch.Tensor 86 | ctc_logits = torch.from_numpy(ctc_logits).float() 87 | # support batch_size=1 only currently 88 | x = ctc_logits[0, : encoder_out_lens[0].item(), :] 89 | yseq = x.argmax(dim=-1) 90 | yseq = torch.unique_consecutive(yseq, dim=-1) 91 | 92 | mask = yseq != self.blank_id 93 | token_int = yseq[mask].tolist() 94 | 95 | if tokenizer is not None: 96 | asr_res.append(tokenizer.tokens2text(token_int)) 97 | else: 98 | asr_res.append(token_int) 99 | return asr_res 100 | 101 | def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: 102 | def load_wav(path: str) -> np.ndarray: 103 | waveform, _ = librosa.load(path, sr=fs) 104 | return waveform 105 | 106 | if isinstance(wav_content, np.ndarray): 107 | return [wav_content] 108 | 109 | if isinstance(wav_content, str): 110 | return [load_wav(wav_content)] 111 | 112 | if isinstance(wav_content, list): 113 | return [load_wav(path) for path in wav_content] 114 | 115 | raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]") 116 | 117 | def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 118 | feats, feats_len = [], [] 119 | for waveform in waveform_list: 120 | speech, _ = self.frontend.fbank(waveform) 121 | feat, feat_len = self.frontend.lfr_cmvn(speech) 122 | feats.append(feat) 123 | feats_len.append(feat_len) 124 | 125 | feats = self.pad_feats(feats, np.max(feats_len)) 126 | feats_len = np.array(feats_len).astype(np.int32) 127 | return feats, feats_len 128 | 129 | @staticmethod 130 | def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: 131 | def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: 132 | pad_width = ((0, max_feat_len - cur_len), (0, 0)) 133 | return np.pad(feat, pad_width, "constant", constant_values=0) 134 | 135 | feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] 136 | feats = np.array(feat_res).astype(np.float32) 137 | return feats 138 | 139 | def infer(self, 140 | feats: np.ndarray, 141 | feats_len: np.ndarray, 142 | language: np.ndarray, 143 | textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]: 144 | outputs = self.ort_infer([feats, feats_len, language, textnorm]) 145 | return outputs 146 | -------------------------------------------------------------------------------- /web/PUT_WEB_JS_HERE: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/web/PUT_WEB_JS_HERE --------------------------------------------------------------------------------
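
For orientation, here is a minimal sketch of calling mel_spectrogram from matcha/utils/audio.py on a raw waveform batch. It assumes the ComfyUI-FunAudioLLM root is on sys.path with its requirements installed; the STFT/mel settings shown are illustrative HiFi-GAN-style values, not numbers taken from this repository's configs.

```python
import torch

from matcha.utils.audio import mel_spectrogram

# Any mono waveform batch of shape [batch, samples] with values in [-1, 1].
waveform = torch.rand(1, 22050) * 2.0 - 1.0

# Illustrative HiFi-GAN-style settings (assumed, not read from this repo's configs).
mel = mel_spectrogram(
    waveform,
    n_fft=1024,
    num_mels=80,
    sampling_rate=22050,
    hop_size=256,
    win_size=1024,
    fmin=0,
    fmax=8000,
    center=False,
)
print(mel.shape)  # -> torch.Size([1, 80, n_frames])
```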
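
The wrapper in matcha/utils/monotonic_align/__init__.py relies on the Cython routine maximum_path_c; because the bundled setup.py is commented out, the sketch below assumes core.pyx has already been compiled (for example via a manual cythonize step) and uses toy shapes with an all-ones mask.

```python
import torch

from matcha.utils.monotonic_align import maximum_path

b, t_x, t_y = 1, 5, 12                 # batch, text tokens, mel frames
value = torch.randn(b, t_x, t_y)       # alignment scores (e.g. log-likelihoods)
mask = torch.ones(b, t_x, t_y)         # 1 where both sequences are valid
path = maximum_path(value, mask)       # hard monotonic alignment, same shape
print(path.sum(dim=2))                 # frames assigned to each text position
```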
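
sensevoice/utils/model_bin.py defines SenseVoiceSmallONNX for ONNX inference. The sketch below is a minimal, assumption-heavy usage example: the model directory path is a placeholder, the sensevoice/ folder is assumed to be on sys.path (the module imports its helpers as utils.*), and the integer language/textnorm ids are placeholders whose real mapping depends on the exported model.

```python
import sys

# model_bin.py imports its helpers as `utils.*`, so the sensevoice/ directory
# itself is assumed to be on sys.path here (placeholder path below).
sys.path.append("custom_nodes/ComfyUI-FunAudioLLM/sensevoice")

from utils.model_bin import SenseVoiceSmallONNX

# The directory must already hold model.onnx (or model_quant.onnx), config.yaml
# and am.mvn produced by the export step; this path is a placeholder.
model = SenseVoiceSmallONNX(model_dir="/path/to/SenseVoiceSmall-onnx")

# language / textnorm are integer ids consumed by the exported graph; their real
# mapping depends on the export, so 0 is used purely as a placeholder here.
token_ids = model("example.wav", language=[0], textnorm=[0])
print(token_ids[0])  # raw CTC token ids; pass tokenizer=... to get text instead
```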