├── .github
│   └── workflows
│       └── publish.yml
├── .gitignore
├── README.md
├── __init__.py
├── assets
│   ├── LoadSpeakerModel.png
│   ├── PuncSegment.png
│   ├── SaveSpeakerModel.png
│   ├── SenseVoice.png
│   └── Workflow_FunAudioLLM.png
├── cosyvoice
│   ├── __init__.py
│   ├── bin
│   │   ├── export_jit.py
│   │   ├── export_onnx.py
│   │   ├── inference.py
│   │   └── train.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── cosyvoice.py
│   │   ├── frontend.py
│   │   └── model.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   └── processor.py
│   ├── flow
│   │   ├── decoder.py
│   │   ├── flow.py
│   │   ├── flow_matching.py
│   │   └── length_regulator.py
│   ├── hifigan
│   │   ├── f0_predictor.py
│   │   └── generator.py
│   ├── llm
│   │   └── llm.py
│   ├── tokenizer
│   │   ├── assets
│   │   │   └── multilingual_zh_ja_yue_char_del.tiktoken
│   │   └── tokenizer.py
│   ├── transformer
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── convolution.py
│   │   ├── decoder.py
│   │   ├── decoder_layer.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   └── subsampling.py
│   └── utils
│       ├── __init__.py
│       ├── class_utils.py
│       ├── common.py
│       ├── executor.py
│       ├── file_utils.py
│       ├── frontend_utils.py
│       ├── mask.py
│       ├── scheduler.py
│       └── train_utils.py
├── funaudio_utils
│   ├── __init__.py
│   ├── cosyvoice_plus.py
│   ├── download_models.py
│   └── pre.py
├── matcha
│   ├── VERSION
│   ├── __init__.py
│   ├── app.py
│   ├── cli.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── components
│   │   │   └── __init__.py
│   │   └── text_mel_datamodule.py
│   ├── hifigan
│   │   ├── LICENSE
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── denoiser.py
│   │   ├── env.py
│   │   ├── meldataset.py
│   │   ├── models.py
│   │   └── xutils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── baselightningmodule.py
│   │   ├── components
│   │   │   ├── __init__.py
│   │   │   ├── decoder.py
│   │   │   ├── flow_matching.py
│   │   │   ├── text_encoder.py
│   │   │   └── transformer.py
│   │   └── matcha_tts.py
│   ├── onnx
│   │   ├── __init__.py
│   │   ├── export.py
│   │   └── infer.py
│   ├── text
│   │   ├── __init__.py
│   │   ├── cleaners.py
│   │   ├── numbers.py
│   │   └── symbols.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       ├── audio.py
│       ├── generate_data_statistics.py
│       ├── instantiators.py
│       ├── logging_utils.py
│       ├── model.py
│       ├── monotonic_align
│       │   ├── __init__.py
│       │   ├── core.pyx
│       │   └── setup.py
│       ├── pylogger.py
│       ├── rich_utils.py
│       └── utils.py
├── nodes
│   ├── __init__.py
│   ├── cosyvoice_nodes.py
│   └── sensevoice_nodes.py
├── pyproject.toml
├── requirements.txt
├── sensevoice
│   ├── __init__.py
│   ├── model.py
│   └── utils
│       ├── __init__.py
│       ├── export_utils.py
│       ├── frontend.py
│       ├── infer_utils.py
│       └── model_bin.py
├── web
│   └── PUT_WEB_JS_HERE
└── workflow
    └── FunAudioLLM.json
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to Comfy registry
2 | on:
3 |   workflow_dispatch:
4 |   push:
5 |     branches:
6 |       - main
7 |       - master
8 |     paths:
9 |       - "pyproject.toml"
10 |
11 | jobs:
12 |   publish-node:
13 |     name: Publish Custom Node to registry
14 |     runs-on: ubuntu-latest
15 |     # If this is a forked repository, skip the workflow.
16 |     if: github.event.repository.fork == false
17 |     steps:
18 |       - name: Check out code
19 |         uses: actions/checkout@v4
20 |       - name: Publish Custom Node
21 |         uses: Comfy-Org/publish-node-action@main
22 |         with:
23 |           ## Add your own personal access token to your GitHub repository secrets and reference it here.
24 |           personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }}
25 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
9 | # ComfyUI-FunAudioLLM
10 | ComfyUI custom nodes for [FunAudioLLM](https://funaudiollm.github.io/), including [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) and [SenseVoice](https://github.com/FunAudioLLM/SenseVoice).
11 |
12 | ## Features
13 |
14 | ### CosyVoice
15 | - CosyVoice Version: 2024-10-04
16 | - Supports SFT, Zero-shot, Cross-lingual, and Instruct inference
17 | - Supports CosyVoice-300M-25Hz for zero-shot and cross-lingual inference
18 | - Supports SFT at 25Hz (unofficial)
19 | - Save and load speaker models in zero-shot mode
20 |
21 |
22 |
23 |
24 |
25 | ### SenseVoice
26 | - SenseVoice Version: 2024-10-04
27 | - Supports SenseVoice-Small
28 | - Supports punctuation segmentation (requires turning off `use_fast_mode`)
29 |
30 |
31 |
32 |
33 |
34 | ## How to use
35 | ```bash
36 | apt update
37 | apt install ffmpeg
38 |
39 | ## in ComfyUI/custom_nodes
40 | git clone https://github.com/SpenserCai/ComfyUI-FunAudioLLM
41 | cd ComfyUI-FunAudioLLM
42 | pip install -r requirements.txt
43 |
44 | ```
45 |
46 | ### Windows
47 | On Windows, use conda to install pynini first:
48 | ```bash
49 | conda install -c conda-forge pynini=2.1.6
50 | pip install -r requirements.txt
51 |
52 | ```
53 |
54 | ### MacOS
55 | If you hit errors during installation, install OpenFst with Homebrew and point the compiler at it:
56 | ```bash
57 | brew install openfst
58 | export CPPFLAGS="-I/opt/homebrew/include"
59 | export LDFLAGS="-L/opt/homebrew/lib"
60 | pip install -r requirements.txt
61 | ```
62 |
63 | If your network is unstable, you can pre-download the models from the following sources and place them in the corresponding directories (see the example below the list).
64 |
65 | - [CosyVoice-300M](https://modelscope.cn/models/iic/CosyVoice-300M) -> `ComfyUI/models/CosyVoice/CosyVoice-300M`
66 | - [CosyVoice-300M-25Hz](https://modelscope.cn/models/iic/CosyVoice-300M-25Hz) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-25Hz`
67 | - [CosyVoice-300M-SFT](https://modelscope.cn/models/iic/CosyVoice-300M-SFT) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-SFT`
68 | - [CosyVoice-300M-SFT-25Hz](https://modelscope.cn/models/MachineS/CosyVoice-300M-SFT-25Hz) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-SFT-25Hz`
69 | - [CosyVoice-300M-Instruct](https://modelscope.cn/models/iic/CosyVoice-300M-Instruct) -> `ComfyUI/models/CosyVoice/CosyVoice-300M-Instruct`
70 | - [SenseVoiceSmall](https://modelscope.cn/models/iic/SenseVoiceSmall) -> `ComfyUI/models/SenseVoice/SenseVoiceSmall`
71 |
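A minimal sketch of one way to pre-download, assuming `git` and `git-lfs` are installed and that the ModelScope repositories follow the usual `https://www.modelscope.cn/<namespace>/<model>.git` clone pattern:

```bash
# Example: pre-download two of the models listed above into the ComfyUI models folder
cd ComfyUI/models
mkdir -p CosyVoice SenseVoice
git lfs install
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git CosyVoice/CosyVoice-300M
git clone https://www.modelscope.cn/iic/SenseVoiceSmall.git SenseVoice/SenseVoiceSmall
```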
72 | ## Workflow
73 |
74 |
75 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: SpenserCai
3 | Date: 2024-10-04 12:14:22
4 | version:
5 | LastEditors: SpenserCai
6 | LastEditTime: 2024-10-04 22:29:33
7 | Description: file content
8 | '''
9 | from .nodes.cosyvoice_nodes import *
10 | from .nodes.sensevoice_nodes import *
11 |
12 | NODE_CONFIG = {
13 |     "CosyVoiceZeroShotNode": {
14 |         "class": CosyVoiceZeroShotNode,
15 |         "name": "CosyVoice 3s极速克隆"
16 |     },
17 |     "CosyVoiceSFTNode": {
18 |         "class": CosyVoiceSFTNode,
19 |         "name": "CosyVoice 预训练音色"
20 |     },
21 |     "CosyVoiceCrossLingualNode": {
22 |         "class": CosyVoiceCrossLingualNode,
23 |         "name": "CosyVoice 跨语言克隆"
24 |     },
25 |     "CosyVoiceInstructNode": {
26 |         "class": CosyVoiceInstructNode,
27 |         "name": "CosyVoice 自然语言控制"
28 |     },
29 |     "CosyVoiceSaveSpeakerModelNode": {
30 |         "class": CosyVoiceSaveSpeakerModelNode,
31 |         "name": "CosyVoice 保存说话人模型"
32 |     },
33 |     "CosyVoiceLoadSpeakerModelNode": {
34 |         "class": CosyVoiceLoadSpeakerModelNode,
35 |         "name": "CosyVoice 加载说话人模型"
36 |     },
37 |     "CosyVoiceLoadSpeakerModelFromUrlNode": {
38 |         "class": CosyVoiceLoadSpeakerModelFromUrlNode,
39 |         "name": "CosyVoice 从URL加载说话人模型"
40 |     },
41 |     "SenseVoiceNode": {
42 |         "class": SenseVoiceNode,
43 |         "name": "SenseVoice 语音识别"
44 |     }
45 | }
46 |
47 | def generate_node_mappings(node_config):
48 |     node_class_mappings = {}
49 |     node_display_name_mappings = {}
50 |
51 |     for node_name, node_info in node_config.items():
52 |         node_class_mappings[node_name] = node_info["class"]
53 |         node_display_name_mappings[node_name] = node_info.get("name", node_info["class"].__name__)
54 |
55 |     return node_class_mappings, node_display_name_mappings
56 |
57 | NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)
58 |
59 | WEB_DIRECTORY = "./web"
60 |
61 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS", "WEB_DIRECTORY"]
62 |
63 |
64 |
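For reference, a minimal sketch of how an additional node could be registered through `NODE_CONFIG`. The class below is hypothetical and only illustrates the ComfyUI node interface that `generate_node_mappings` expects; it is not part of this repository and assumes it is appended to this `__init__.py`:

```python
# Hypothetical example node: echoes its text input back unchanged.
class MyEchoTextNode:
    @classmethod
    def INPUT_TYPES(cls):
        # One required multiline string input.
        return {"required": {"text": ("STRING", {"default": "", "multiline": True})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"
    CATEGORY = "FunAudioLLM"

    def run(self, text):
        # A real node would do actual work here.
        return (text,)

# Register it and regenerate the mappings after modifying NODE_CONFIG.
NODE_CONFIG["MyEchoTextNode"] = {"class": MyEchoTextNode, "name": "Echo Text (example)"}
NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS = generate_node_mappings(NODE_CONFIG)
```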
--------------------------------------------------------------------------------
/assets/LoadSpeakerModel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/LoadSpeakerModel.png
--------------------------------------------------------------------------------
/assets/PuncSegment.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/PuncSegment.png
--------------------------------------------------------------------------------
/assets/SaveSpeakerModel.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/SaveSpeakerModel.png
--------------------------------------------------------------------------------
/assets/SenseVoice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/SenseVoice.png
--------------------------------------------------------------------------------
/assets/Workflow_FunAudioLLM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/assets/Workflow_FunAudioLLM.png
--------------------------------------------------------------------------------
/cosyvoice/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/bin/export_jit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import print_function
16 |
17 | import argparse
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | import os
21 | import sys
22 | import torch
23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24 | sys.path.append('{}/../..'.format(ROOT_DIR))
25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26 | from cosyvoice.cli.cosyvoice import CosyVoice
27 |
28 |
29 | def get_args():
30 | parser = argparse.ArgumentParser(description='export your model for deployment')
31 | parser.add_argument('--model_dir',
32 | type=str,
33 | default='pretrained_models/CosyVoice-300M',
34 | help='local path')
35 | args = parser.parse_args()
36 | print(args)
37 | return args
38 |
39 |
40 | def main():
41 | args = get_args()
42 | logging.basicConfig(level=logging.DEBUG,
43 | format='%(asctime)s %(levelname)s %(message)s')
44 |
45 | torch._C._jit_set_fusion_strategy([('STATIC', 1)])
46 | torch._C._jit_set_profiling_mode(False)
47 | torch._C._jit_set_profiling_executor(False)
48 |
49 | cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
50 |
51 | # 1. export llm text_encoder
52 | llm_text_encoder = cosyvoice.model.llm.text_encoder.half()
53 | script = torch.jit.script(llm_text_encoder)
54 | script = torch.jit.freeze(script)
55 | script = torch.jit.optimize_for_inference(script)
56 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
57 |
58 | # 2. export llm llm
59 | llm_llm = cosyvoice.model.llm.llm.half()
60 | script = torch.jit.script(llm_llm)
61 | script = torch.jit.freeze(script, preserved_attrs=['forward_chunk'])
62 | script = torch.jit.optimize_for_inference(script)
63 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
64 |
65 | # 3. export flow encoder
66 | flow_encoder = cosyvoice.model.flow.encoder
67 | script = torch.jit.script(flow_encoder)
68 | script = torch.jit.freeze(script)
69 | script = torch.jit.optimize_for_inference(script)
70 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
71 |
72 |
73 | if __name__ == '__main__':
74 | main()
75 |
--------------------------------------------------------------------------------
/cosyvoice/bin/export_onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import print_function
17 |
18 | import argparse
19 | import logging
20 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
21 | import os
22 | import sys
23 | import onnxruntime
24 | import random
25 | import torch
26 | from tqdm import tqdm
27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28 | sys.path.append('{}/../..'.format(ROOT_DIR))
29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30 | from cosyvoice.cli.cosyvoice import CosyVoice
31 |
32 |
33 | def get_dummy_input(batch_size, seq_len, out_channels, device):
34 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
35 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
36 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
37 | t = torch.rand((batch_size), dtype=torch.float32, device=device)
38 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
39 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
40 | return x, mask, mu, t, spks, cond
41 |
42 |
43 | def get_args():
44 | parser = argparse.ArgumentParser(description='export your model for deployment')
45 | parser.add_argument('--model_dir',
46 | type=str,
47 | default='pretrained_models/CosyVoice-300M',
48 | help='local path')
49 | args = parser.parse_args()
50 | print(args)
51 | return args
52 |
53 |
54 | def main():
55 | args = get_args()
56 | logging.basicConfig(level=logging.DEBUG,
57 | format='%(asctime)s %(levelname)s %(message)s')
58 |
59 | cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False)
60 |
61 | # 1. export flow decoder estimator
62 | estimator = cosyvoice.model.flow.decoder.estimator
63 |
64 | device = cosyvoice.model.device
65 | batch_size, seq_len = 1, 256
66 | out_channels = cosyvoice.model.flow.decoder.estimator.out_channels
67 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
68 | torch.onnx.export(
69 | estimator,
70 | (x, mask, mu, t, spks, cond),
71 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
72 | export_params=True,
73 | opset_version=18,
74 | do_constant_folding=True,
75 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
76 | output_names=['estimator_out'],
77 | dynamic_axes={
78 | 'x': {0: 'batch_size', 2: 'seq_len'},
79 | 'mask': {0: 'batch_size', 2: 'seq_len'},
80 | 'mu': {0: 'batch_size', 2: 'seq_len'},
81 | 'cond': {0: 'batch_size', 2: 'seq_len'},
82 | 't': {0: 'batch_size'},
83 | 'spks': {0: 'batch_size'},
84 | 'estimator_out': {0: 'batch_size', 2: 'seq_len'},
85 | }
86 | )
87 |
88 | # 2. test computation consistency
89 | option = onnxruntime.SessionOptions()
90 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
91 | option.intra_op_num_threads = 1
92 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
93 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
94 | sess_options=option, providers=providers)
95 |
96 | for _ in tqdm(range(10)):
97 | x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
98 | output_pytorch = estimator(x, mask, mu, t, spks, cond)
99 | ort_inputs = {
100 | 'x': x.cpu().numpy(),
101 | 'mask': mask.cpu().numpy(),
102 | 'mu': mu.cpu().numpy(),
103 | 't': t.cpu().numpy(),
104 | 'spks': spks.cpu().numpy(),
105 | 'cond': cond.cpu().numpy()
106 | }
107 | output_onnx = estimator_onnx.run(None, ort_inputs)[0]
108 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
109 |
110 |
111 | if __name__ == "__main__":
112 | main()
113 |
--------------------------------------------------------------------------------
/cosyvoice/bin/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import print_function
16 |
17 | import argparse
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | import os
21 | import torch
22 | from torch.utils.data import DataLoader
23 | import torchaudio
24 | from hyperpyyaml import load_hyperpyyaml
25 | from tqdm import tqdm
26 | from cosyvoice.cli.model import CosyVoiceModel
27 | from cosyvoice.dataset.dataset import Dataset
28 |
29 |
30 | def get_args():
31 | parser = argparse.ArgumentParser(description='inference with your model')
32 | parser.add_argument('--config', required=True, help='config file')
33 | parser.add_argument('--prompt_data', required=True, help='prompt data file')
34 | parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
35 | parser.add_argument('--tts_text', required=True, help='tts input file')
36 | parser.add_argument('--llm_model', required=True, help='llm model file')
37 | parser.add_argument('--flow_model', required=True, help='flow model file')
38 | parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
39 | parser.add_argument('--gpu',
40 | type=int,
41 | default=-1,
42 | help='gpu id for this rank, -1 for cpu')
43 | parser.add_argument('--mode',
44 | default='sft',
45 | choices=['sft', 'zero_shot'],
46 | help='inference mode')
47 | parser.add_argument('--result_dir', required=True, help='asr result file')
48 | args = parser.parse_args()
49 | print(args)
50 | return args
51 |
52 |
53 | def main():
54 | args = get_args()
55 | logging.basicConfig(level=logging.DEBUG,
56 | format='%(asctime)s %(levelname)s %(message)s')
57 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
58 |
59 | # Init cosyvoice models from configs
60 | use_cuda = args.gpu >= 0 and torch.cuda.is_available()
61 | device = torch.device('cuda' if use_cuda else 'cpu')
62 | with open(args.config, 'r') as f:
63 | configs = load_hyperpyyaml(f)
64 |
65 | model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
66 | model.load(args.llm_model, args.flow_model, args.hifigan_model)
67 |
68 | test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
69 | tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
70 | test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
71 |
72 | del configs
73 | os.makedirs(args.result_dir, exist_ok=True)
74 | fn = os.path.join(args.result_dir, 'wav.scp')
75 | f = open(fn, 'w')
76 | with torch.no_grad():
77 | for _, batch in tqdm(enumerate(test_data_loader)):
78 | utts = batch["utts"]
79 | assert len(utts) == 1, "inference mode only supports batch size 1"
80 | text_token = batch["text_token"].to(device)
81 | text_token_len = batch["text_token_len"].to(device)
82 | tts_index = batch["tts_index"]
83 | tts_text_token = batch["tts_text_token"].to(device)
84 | tts_text_token_len = batch["tts_text_token_len"].to(device)
85 | speech_token = batch["speech_token"].to(device)
86 | speech_token_len = batch["speech_token_len"].to(device)
87 | speech_feat = batch["speech_feat"].to(device)
88 | speech_feat_len = batch["speech_feat_len"].to(device)
89 | utt_embedding = batch["utt_embedding"].to(device)
90 | spk_embedding = batch["spk_embedding"].to(device)
91 | if args.mode == 'sft':
92 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
93 | 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
94 | else:
95 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
96 | 'prompt_text': text_token, 'prompt_text_len': text_token_len,
97 | 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
98 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
99 | 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
100 | 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
101 | tts_speeches = []
102 | for model_output in model.inference(**model_input):
103 | tts_speeches.append(model_output['tts_speech'])
104 | tts_speeches = torch.concat(tts_speeches, dim=1)
105 | tts_key = '{}_{}'.format(utts[0], tts_index[0])
106 | tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
107 | torchaudio.save(tts_fn, tts_speeches, sample_rate=22050)
108 | f.write('{} {}\n'.format(tts_key, tts_fn))
109 | f.flush()
110 | f.close()
111 | logging.info('Result wav.scp saved in {}'.format(fn))
112 |
113 |
114 | if __name__ == '__main__':
115 | main()
116 |
--------------------------------------------------------------------------------
/cosyvoice/bin/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import print_function
16 | import argparse
17 | import datetime
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | from copy import deepcopy
21 | import torch
22 | import torch.distributed as dist
23 | import deepspeed
24 |
25 | from hyperpyyaml import load_hyperpyyaml
26 |
27 | from torch.distributed.elastic.multiprocessing.errors import record
28 |
29 | from cosyvoice.utils.executor import Executor
30 | from cosyvoice.utils.train_utils import (
31 | init_distributed,
32 | init_dataset_and_dataloader,
33 | init_optimizer_and_scheduler,
34 | init_summarywriter, save_model,
35 | wrap_cuda_model, check_modify_and_save_config)
36 |
37 |
38 | def get_args():
39 | parser = argparse.ArgumentParser(description='training your network')
40 | parser.add_argument('--train_engine',
41 | default='torch_ddp',
42 | choices=['torch_ddp', 'deepspeed'],
43 | help='Engine for paralleled training')
44 | parser.add_argument('--model', required=True, help='model which will be trained')
45 | parser.add_argument('--config', required=True, help='config file')
46 | parser.add_argument('--train_data', required=True, help='train data file')
47 | parser.add_argument('--cv_data', required=True, help='cv data file')
48 | parser.add_argument('--checkpoint', help='checkpoint model')
49 | parser.add_argument('--model_dir', required=True, help='save model dir')
50 | parser.add_argument('--tensorboard_dir',
51 | default='tensorboard',
52 | help='tensorboard log dir')
53 | parser.add_argument('--ddp.dist_backend',
54 | dest='dist_backend',
55 | default='nccl',
56 | choices=['nccl', 'gloo'],
57 | help='distributed backend')
58 | parser.add_argument('--num_workers',
59 | default=0,
60 | type=int,
61 | help='num of subprocess workers for reading')
62 | parser.add_argument('--prefetch',
63 | default=100,
64 | type=int,
65 | help='prefetch number')
66 | parser.add_argument('--pin_memory',
67 | action='store_true',
68 | default=False,
69 | help='Use pinned memory buffers used for reading')
70 | parser.add_argument('--deepspeed.save_states',
71 | dest='save_states',
72 | default='model_only',
73 | choices=['model_only', 'model+optimizer'],
74 | help='save model/optimizer states')
75 | parser.add_argument('--timeout',
76 | default=30,
77 | type=int,
78 | help='timeout (in seconds) of cosyvoice_join.')
79 | parser = deepspeed.add_config_arguments(parser)
80 | args = parser.parse_args()
81 | return args
82 |
83 |
84 | @record
85 | def main():
86 | args = get_args()
87 | logging.basicConfig(level=logging.DEBUG,
88 | format='%(asctime)s %(levelname)s %(message)s')
89 |
90 | override_dict = {k: None for k in ['llm', 'flow', 'hift'] if k != args.model}
91 | with open(args.config, 'r') as f:
92 | configs = load_hyperpyyaml(f, overrides=override_dict)
93 | configs['train_conf'].update(vars(args))
94 |
95 | # Init env for ddp
96 | init_distributed(args)
97 |
98 | # Get dataset & dataloader
99 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
100 | init_dataset_and_dataloader(args, configs)
101 |
102 | # Do some sanity checks and save config to args.model_dir
103 | configs = check_modify_and_save_config(args, configs)
104 |
105 | # Tensorboard summary
106 | writer = init_summarywriter(args)
107 |
108 | # load checkpoint
109 | model = configs[args.model]
110 | if args.checkpoint is not None:
111 | model.load_state_dict(torch.load(args.checkpoint, map_location='cpu'))
112 |
113 | # Dispatch model from cpu to gpu
114 | model = wrap_cuda_model(args, model)
115 |
116 | # Get optimizer & scheduler
117 | model, optimizer, scheduler = init_optimizer_and_scheduler(args, configs, model)
118 |
119 | # Save init checkpoints
120 | info_dict = deepcopy(configs['train_conf'])
121 | save_model(model, 'init', info_dict)
122 |
123 | # Get executor
124 | executor = Executor()
125 |
126 | # Start training loop
127 | for epoch in range(info_dict['max_epoch']):
128 | executor.epoch = epoch
129 | train_dataset.set_epoch(epoch)
130 | dist.barrier()
131 | group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
132 | executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
133 | dist.destroy_process_group(group_join)
134 |
135 |
136 | if __name__ == '__main__':
137 | main()
138 |
--------------------------------------------------------------------------------
/cosyvoice/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/cli/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/cli/cosyvoice.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: SpenserCai
3 | Date: 2024-10-04 11:30:15
4 | version:
5 | LastEditors: SpenserCai
6 | LastEditTime: 2024-10-04 14:31:23
7 | Description: file content
8 | '''
9 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
10 | #
11 | # Licensed under the Apache License, Version 2.0 (the "License");
12 | # you may not use this file except in compliance with the License.
13 | # You may obtain a copy of the License at
14 | #
15 | # http://www.apache.org/licenses/LICENSE-2.0
16 | #
17 | # Unless required by applicable law or agreed to in writing, software
18 | # distributed under the License is distributed on an "AS IS" BASIS,
19 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20 | # See the License for the specific language governing permissions and
21 | # limitations under the License.
22 | import os
23 | import time
24 | from tqdm import tqdm
25 | from hyperpyyaml import load_hyperpyyaml
26 | from modelscope import snapshot_download
27 | from cosyvoice.cli.frontend import CosyVoiceFrontEnd
28 | from cosyvoice.cli.model import CosyVoiceModel
29 | from cosyvoice.utils.file_utils import logging
30 |
31 |
32 | class CosyVoice:
33 |
34 | def __init__(self, model_dir, load_jit=True, load_onnx=False):
35 | instruct = True if '-Instruct' in model_dir else False
36 | self.model_dir = model_dir
37 | if not os.path.exists(model_dir):
38 | model_dir = snapshot_download(model_dir)
39 | with open('{}/cosyvoice.yaml'.format(model_dir), 'r') as f:
40 | configs = load_hyperpyyaml(f)
41 | self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
42 | configs['feat_extractor'],
43 | '{}/campplus.onnx'.format(model_dir),
44 | '{}/speech_tokenizer_v1.onnx'.format(model_dir),
45 | '{}/spk2info.pt'.format(model_dir),
46 | instruct,
47 | configs['allowed_special'])
48 | self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
49 | self.model.load('{}/llm.pt'.format(model_dir),
50 | '{}/flow.pt'.format(model_dir),
51 | '{}/hift.pt'.format(model_dir))
52 | if load_jit:
53 | self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
54 | '{}/llm.llm.fp16.zip'.format(model_dir),
55 | '{}/flow.encoder.fp32.zip'.format(model_dir))
56 | if load_onnx:
57 | self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
58 | del configs
59 |
60 | def list_avaliable_spks(self):
61 | spks = list(self.frontend.spk2info.keys())
62 | return spks
63 |
64 | def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
65 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
66 | model_input = self.frontend.frontend_sft(i, spk_id)
67 | start_time = time.time()
68 | logging.info('synthesis text {}'.format(i))
69 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
70 | speech_len = model_output['tts_speech'].shape[1] / 22050
71 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
72 | yield model_output
73 | start_time = time.time()
74 |
75 | def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False, speed=1.0):
76 | prompt_text = self.frontend.text_normalize(prompt_text, split=False)
77 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
78 | model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
79 | start_time = time.time()
80 | logging.info('synthesis text {}'.format(i))
81 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
82 | speech_len = model_output['tts_speech'].shape[1] / 22050
83 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
84 | yield model_output
85 | start_time = time.time()
86 |
87 | def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, speed=1.0):
88 | if self.frontend.instruct is True:
89 | raise ValueError('{} does not support cross_lingual inference'.format(self.model_dir))
90 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
91 | model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
92 | start_time = time.time()
93 | logging.info('synthesis text {}'.format(i))
94 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
95 | speech_len = model_output['tts_speech'].shape[1] / 22050
96 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
97 | yield model_output
98 | start_time = time.time()
99 |
100 | def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0):
101 | if self.frontend.instruct is False:
102 | raise ValueError('{} does not support instruct inference'.format(self.model_dir))
103 | instruct_text = self.frontend.text_normalize(instruct_text, split=False)
104 | for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
105 | model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
106 | start_time = time.time()
107 | logging.info('synthesis text {}'.format(i))
108 | for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
109 | speech_len = model_output['tts_speech'].shape[1] / 22050
110 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
111 | yield model_output
112 | start_time = time.time()
113 |
114 | def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
115 | model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
116 | start_time = time.time()
117 | for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
118 | speech_len = model_output['tts_speech'].shape[1] / 22050
119 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
120 | yield model_output
121 | start_time = time.time()
122 |
--------------------------------------------------------------------------------
/cosyvoice/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/dataset/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import random
17 | import json
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from torch.utils.data import IterableDataset
24 | from cosyvoice.utils.file_utils import read_lists, read_json_lists
25 |
26 |
27 | class Processor(IterableDataset):
28 |
29 | def __init__(self, source, f, *args, **kw):
30 | assert callable(f)
31 | self.source = source
32 | self.f = f
33 | self.args = args
34 | self.kw = kw
35 |
36 | def set_epoch(self, epoch):
37 | self.source.set_epoch(epoch)
38 |
39 | def __iter__(self):
40 | """ Return an iterator over the source dataset processed by the
41 | given processor.
42 | """
43 | assert self.source is not None
44 | assert callable(self.f)
45 | return self.f(iter(self.source), *self.args, **self.kw)
46 |
47 | def apply(self, f):
48 | assert callable(f)
49 | return Processor(self, f, *self.args, **self.kw)
50 |
51 |
52 | class DistributedSampler:
53 |
54 | def __init__(self, shuffle=True, partition=True):
55 | self.epoch = -1
56 | self.update()
57 | self.shuffle = shuffle
58 | self.partition = partition
59 |
60 | def update(self):
61 | assert dist.is_available()
62 | if dist.is_initialized():
63 | self.rank = dist.get_rank()
64 | self.world_size = dist.get_world_size()
65 | else:
66 | self.rank = 0
67 | self.world_size = 1
68 | worker_info = torch.utils.data.get_worker_info()
69 | if worker_info is None:
70 | self.worker_id = 0
71 | self.num_workers = 1
72 | else:
73 | self.worker_id = worker_info.id
74 | self.num_workers = worker_info.num_workers
75 | return dict(rank=self.rank,
76 | world_size=self.world_size,
77 | worker_id=self.worker_id,
78 | num_workers=self.num_workers)
79 |
80 | def set_epoch(self, epoch):
81 | self.epoch = epoch
82 |
83 | def sample(self, data):
84 | """ Sample data according to rank/world_size/num_workers
85 |
86 | Args:
87 | data(List): input data list
88 |
89 | Returns:
90 | List: data list after sample
91 | """
92 | data = list(range(len(data)))
93 | # pad the data list so it divides evenly across ranks and workers
94 | if self.partition:
95 | if self.shuffle:
96 | random.Random(self.epoch).shuffle(data)
97 | if len(data) < self.world_size:
98 | data = data * math.ceil(self.world_size / len(data))
99 | data = data[:self.world_size]
100 | data = data[self.rank::self.world_size]
101 | if len(data) < self.num_workers:
102 | data = data * math.ceil(self.num_workers / len(data))
103 | data = data[:self.num_workers]
104 | data = data[self.worker_id::self.num_workers]
105 | return data
106 |
107 |
108 | class DataList(IterableDataset):
109 |
110 | def __init__(self, lists, shuffle=True, partition=True):
111 | self.lists = lists
112 | self.sampler = DistributedSampler(shuffle, partition)
113 |
114 | def set_epoch(self, epoch):
115 | self.sampler.set_epoch(epoch)
116 |
117 | def __iter__(self):
118 | sampler_info = self.sampler.update()
119 | indexes = self.sampler.sample(self.lists)
120 | for index in indexes:
121 | data = dict(src=self.lists[index])
122 | data.update(sampler_info)
123 | yield data
124 |
125 |
126 | def Dataset(data_list_file,
127 | data_pipeline,
128 | mode='train',
129 | shuffle=True,
130 | partition=True,
131 | tts_file='',
132 | prompt_utt2data=''):
133 | """ Construct dataset from arguments
134 |
135 | We have two shuffle stage in the Dataset. The first is global
136 | shuffle at shards tar/raw file level. The second is global shuffle
137 | at training samples level.
138 |
139 | Args:
140 | data_type(str): raw/shard
141 | tokenizer (BaseTokenizer): tokenizer to tokenize
142 | partition(bool): whether to do data partition in terms of rank
143 | """
144 | assert mode in ['train', 'inference']
145 | lists = read_lists(data_list_file)
146 | if mode == 'inference':
147 | with open(tts_file) as f:
148 | tts_data = json.load(f)
149 | utt2lists = read_json_lists(prompt_utt2data)
150 | # filter out unnecessary files in inference mode
151 | lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
152 | dataset = DataList(lists,
153 | shuffle=shuffle,
154 | partition=partition)
155 | if mode == 'inference':
156 | # map partial arg tts_data in inference mode
157 | data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
158 | for func in data_pipeline:
159 | dataset = Processor(dataset, func, mode=mode)
160 | return dataset
161 |
--------------------------------------------------------------------------------
/cosyvoice/flow/flow.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import logging
15 | import random
16 | from typing import Dict, Optional
17 | import torch
18 | import torch.nn as nn
19 | from torch.nn import functional as F
20 | from omegaconf import DictConfig
21 | from cosyvoice.utils.mask import make_pad_mask
22 |
23 |
24 | class MaskedDiffWithXvec(torch.nn.Module):
25 | def __init__(self,
26 | input_size: int = 512,
27 | output_size: int = 80,
28 | spk_embed_dim: int = 192,
29 | output_type: str = "mel",
30 | vocab_size: int = 4096,
31 | input_frame_rate: int = 50,
32 | only_mask_loss: bool = True,
33 | encoder: torch.nn.Module = None,
34 | length_regulator: torch.nn.Module = None,
35 | decoder: torch.nn.Module = None,
36 | decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
37 | 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
38 | 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
39 | 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
40 | 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
41 | mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
42 | 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
43 | super().__init__()
44 | self.input_size = input_size
45 | self.output_size = output_size
46 | self.decoder_conf = decoder_conf
47 | self.mel_feat_conf = mel_feat_conf
48 | self.vocab_size = vocab_size
49 | self.output_type = output_type
50 | self.input_frame_rate = input_frame_rate
51 | logging.info(f"input frame rate={self.input_frame_rate}")
52 | self.input_embedding = nn.Embedding(vocab_size, input_size)
53 | self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
54 | self.encoder = encoder
55 | self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
56 | self.decoder = decoder
57 | self.length_regulator = length_regulator
58 | self.only_mask_loss = only_mask_loss
59 |
60 | def forward(
61 | self,
62 | batch: dict,
63 | device: torch.device,
64 | ) -> Dict[str, Optional[torch.Tensor]]:
65 | token = batch['speech_token'].to(device)
66 | token_len = batch['speech_token_len'].to(device)
67 | feat = batch['speech_feat'].to(device)
68 | feat_len = batch['speech_feat_len'].to(device)
69 | embedding = batch['embedding'].to(device)
70 |
71 | # xvec projection
72 | embedding = F.normalize(embedding, dim=1)
73 | embedding = self.spk_embed_affine_layer(embedding)
74 |
75 | # concat text and prompt_text
76 | mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
77 | token = self.input_embedding(torch.clamp(token, min=0)) * mask
78 |
79 | # text encode
80 | h, h_lengths = self.encoder(token, token_len)
81 | h = self.encoder_proj(h)
82 | h, h_lengths = self.length_regulator(h, feat_len)
83 |
84 | # get conditions
85 | conds = torch.zeros(feat.shape, device=token.device)
86 | for i, j in enumerate(feat_len):
87 | if random.random() < 0.5:
88 | continue
89 | index = random.randint(0, int(0.3 * j))
90 | conds[i, :index] = feat[i, :index]
91 | conds = conds.transpose(1, 2)
92 |
93 | mask = (~make_pad_mask(feat_len)).to(h)
94 | feat = F.interpolate(feat.unsqueeze(dim=1), size=h.shape[1:], mode="nearest").squeeze(dim=1)
95 | loss, _ = self.decoder.compute_loss(
96 | feat.transpose(1, 2).contiguous(),
97 | mask.unsqueeze(1),
98 | h.transpose(1, 2).contiguous(),
99 | embedding,
100 | cond=conds
101 | )
102 | return {'loss': loss}
103 |
104 | @torch.inference_mode()
105 | def inference(self,
106 | token,
107 | token_len,
108 | prompt_token,
109 | prompt_token_len,
110 | prompt_feat,
111 | prompt_feat_len,
112 | embedding):
113 | assert token.shape[0] == 1
114 | # xvec projection
115 | embedding = F.normalize(embedding, dim=1)
116 | embedding = self.spk_embed_affine_layer(embedding)
117 |
118 | # concat text and prompt_text
119 | token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
120 | token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
121 | mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
122 | token = self.input_embedding(torch.clamp(token, min=0)) * mask
123 |
124 | # text encode
125 | h, h_lengths = self.encoder(token, token_len)
126 | h = self.encoder_proj(h)
127 | mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
128 | h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
129 |
130 | # get conditions
131 | conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
132 | conds[:, :mel_len1] = prompt_feat
133 | conds = conds.transpose(1, 2)
134 |
135 | mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
136 | feat = self.decoder(
137 | mu=h.transpose(1, 2).contiguous(),
138 | mask=mask.unsqueeze(1),
139 | spks=embedding,
140 | cond=conds,
141 | n_timesteps=10
142 | )
143 | feat = feat[:, :, mel_len1:]
144 | assert feat.shape[2] == mel_len2
145 | return feat
146 |
--------------------------------------------------------------------------------
/cosyvoice/flow/flow_matching.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import torch.nn.functional as F
16 | from matcha.models.components.flow_matching import BASECFM
17 |
18 |
19 | class ConditionalCFM(BASECFM):
20 | def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
21 | super().__init__(
22 | n_feats=in_channels,
23 | cfm_params=cfm_params,
24 | n_spks=n_spks,
25 | spk_emb_dim=spk_emb_dim,
26 | )
27 | self.t_scheduler = cfm_params.t_scheduler
28 | self.training_cfg_rate = cfm_params.training_cfg_rate
29 | self.inference_cfg_rate = cfm_params.inference_cfg_rate
30 | in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
31 | # Just change the architecture of the estimator here
32 | self.estimator = estimator
33 |
34 | @torch.inference_mode()
35 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
36 | """Forward diffusion
37 |
38 | Args:
39 | mu (torch.Tensor): output of encoder
40 | shape: (batch_size, n_feats, mel_timesteps)
41 | mask (torch.Tensor): output_mask
42 | shape: (batch_size, 1, mel_timesteps)
43 | n_timesteps (int): number of diffusion steps
44 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
45 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
46 | shape: (batch_size, spk_emb_dim)
47 | cond: Not used but kept for future purposes
48 |
49 | Returns:
50 | sample: generated mel-spectrogram
51 | shape: (batch_size, n_feats, mel_timesteps)
52 | """
53 | z = torch.randn_like(mu) * temperature
54 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
55 | if self.t_scheduler == 'cosine':
56 | t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
57 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
58 |
59 | def solve_euler(self, x, t_span, mu, mask, spks, cond):
60 | """
61 | Fixed euler solver for ODEs.
62 | Args:
63 | x (torch.Tensor): random noise
64 | t_span (torch.Tensor): n_timesteps interpolated
65 | shape: (n_timesteps + 1,)
66 | mu (torch.Tensor): output of encoder
67 | shape: (batch_size, n_feats, mel_timesteps)
68 | mask (torch.Tensor): output_mask
69 | shape: (batch_size, 1, mel_timesteps)
70 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
71 | shape: (batch_size, spk_emb_dim)
72 | cond: Not used but kept for future purposes
73 | """
74 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
75 | t = t.unsqueeze(dim=0)
76 |
77 | # Store intermediate states so they can be inspected/plotted while debugging;
78 | # a return_all_steps flag may be added in the future.
79 | sol = []
80 |
81 | for step in range(1, len(t_span)):
82 | dphi_dt = self.forward_estimator(x, mask, mu, t, spks, cond)
83 | # Classifier-Free Guidance inference introduced in VoiceBox
84 | if self.inference_cfg_rate > 0:
85 | cfg_dphi_dt = self.forward_estimator(
86 | x, mask,
87 | torch.zeros_like(mu), t,
88 | torch.zeros_like(spks) if spks is not None else None,
89 | torch.zeros_like(cond)
90 | )
91 | dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt -
92 | self.inference_cfg_rate * cfg_dphi_dt)
93 | x = x + dt * dphi_dt
94 | t = t + dt
95 | sol.append(x)
96 | if step < len(t_span) - 1:
97 | dt = t_span[step + 1] - t
98 |
99 | return sol[-1]
100 |
101 | def forward_estimator(self, x, mask, mu, t, spks, cond):
102 | if isinstance(self.estimator, torch.nn.Module):
103 | return self.estimator.forward(x, mask, mu, t, spks, cond)
104 | else:
105 | ort_inputs = {
106 | 'x': x.cpu().numpy(),
107 | 'mask': mask.cpu().numpy(),
108 | 'mu': mu.cpu().numpy(),
109 | 't': t.cpu().numpy(),
110 | 'spks': spks.cpu().numpy(),
111 | 'cond': cond.cpu().numpy()
112 | }
113 | output = self.estimator.run(None, ort_inputs)[0]
114 | return torch.tensor(output, dtype=x.dtype, device=x.device)
115 |
116 | def compute_loss(self, x1, mask, mu, spks=None, cond=None):
117 | """Computes diffusion loss
118 |
119 | Args:
120 | x1 (torch.Tensor): Target
121 | shape: (batch_size, n_feats, mel_timesteps)
122 | mask (torch.Tensor): target mask
123 | shape: (batch_size, 1, mel_timesteps)
124 | mu (torch.Tensor): output of encoder
125 | shape: (batch_size, n_feats, mel_timesteps)
126 | spks (torch.Tensor, optional): speaker embedding. Defaults to None.
127 | shape: (batch_size, spk_emb_dim)
128 |
129 | Returns:
130 | loss: conditional flow matching loss
131 | y: conditional flow
132 | shape: (batch_size, n_feats, mel_timesteps)
133 | """
134 | b, _, t = mu.shape
135 |
136 | # random timestep
137 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
138 | if self.t_scheduler == 'cosine':
139 | t = 1 - torch.cos(t * 0.5 * torch.pi)
140 | # sample noise p(x_0)
141 | z = torch.randn_like(x1)
142 |
143 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1
144 | u = x1 - (1 - self.sigma_min) * z
145 |
146 | # during training, we randomly drop condition to trade off mode coverage and sample fidelity
147 | if self.training_cfg_rate > 0:
148 | cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
149 | mu = mu * cfg_mask.view(-1, 1, 1)
150 | spks = spks * cfg_mask.view(-1, 1)
151 | cond = cond * cfg_mask.view(-1, 1, 1)
152 |
153 | pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond)
154 | loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
155 | return loss, y
156 |
--------------------------------------------------------------------------------
/cosyvoice/flow/length_regulator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Tuple
15 | import torch.nn as nn
16 | import torch
17 | from torch.nn import functional as F
18 | from cosyvoice.utils.mask import make_pad_mask
19 |
20 |
21 | class InterpolateRegulator(nn.Module):
22 | def __init__(
23 | self,
24 | channels: int,
25 | sampling_ratios: Tuple,
26 | out_channels: int = None,
27 | groups: int = 1,
28 | ):
29 | super().__init__()
30 | self.sampling_ratios = sampling_ratios
31 | out_channels = out_channels or channels
32 | model = nn.ModuleList([])
33 | if len(sampling_ratios) > 0:
34 | for _ in sampling_ratios:
35 | module = nn.Conv1d(channels, channels, 3, 1, 1)
36 | norm = nn.GroupNorm(groups, channels)
37 | act = nn.Mish()
38 | model.extend([module, norm, act])
39 | model.append(
40 | nn.Conv1d(channels, out_channels, 1, 1)
41 | )
42 | self.model = nn.Sequential(*model)
43 |
44 | def forward(self, x, ylens=None):
45 | # x in (B, T, D)
46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48 | out = self.model(x).transpose(1, 2).contiguous()
49 | olens = ylens
50 | return out * mask, olens
51 |
52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
53 |         # in inference mode, interpolate the prompt tokens and the generated tokens (head/mid/tail) separately, so we get a clear separation point in the mel
54 | # x in (B, T, D)
55 | if x2.shape[1] > 40:
56 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
57 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
58 | mode='linear')
59 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
60 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
61 | else:
62 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
63 | if x1.shape[1] != 0:
64 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
65 | x = torch.concat([x1, x2], dim=2)
66 | else:
67 | x = x2
68 | out = self.model(x).transpose(1, 2).contiguous()
69 | return out, mel_len1 + mel_len2
70 |
--------------------------------------------------------------------------------
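Before its conv stack, InterpolateRegulator simply stretches token-rate features to the target mel length with linear interpolation. A toy standalone sketch of that step (the sizes are made up):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 25, 80)                     # (B, T, D): 25 tokens with 80-dim features
    mel_len = 110                                  # assumed target number of mel frames
    x = F.interpolate(x.transpose(1, 2).contiguous(), size=mel_len, mode='linear')
    print(x.shape)                                 # torch.Size([1, 80, 110])
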
/cosyvoice/hifigan/f0_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import torch.nn as nn
16 | from torch.nn.utils import weight_norm
17 |
18 |
19 | class ConvRNNF0Predictor(nn.Module):
20 | def __init__(self,
21 | num_class: int = 1,
22 | in_channels: int = 80,
23 | cond_channels: int = 512
24 | ):
25 | super().__init__()
26 |
27 | self.num_class = num_class
28 | self.condnet = nn.Sequential(
29 | weight_norm(
30 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
31 | ),
32 | nn.ELU(),
33 | weight_norm(
34 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
35 | ),
36 | nn.ELU(),
37 | weight_norm(
38 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
39 | ),
40 | nn.ELU(),
41 | weight_norm(
42 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
43 | ),
44 | nn.ELU(),
45 | weight_norm(
46 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
47 | ),
48 | nn.ELU(),
49 | )
50 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
51 |
52 | def forward(self, x: torch.Tensor) -> torch.Tensor:
53 | x = self.condnet(x)
54 | x = x.transpose(1, 2)
55 | return torch.abs(self.classifier(x).squeeze(-1))
56 |
--------------------------------------------------------------------------------
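A quick shape check of the predictor above, assuming the cosyvoice package is importable; the 120-frame mel length is arbitrary:

    import torch
    from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor

    predictor = ConvRNNF0Predictor()               # defaults: 80 mel bins in, 512 cond channels
    mel = torch.randn(1, 80, 120)                  # (batch, n_mels, frames)
    f0 = predictor(mel)
    print(f0.shape)                                # torch.Size([1, 120]): one non-negative F0 value per frame
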
/cosyvoice/tokenizer/tokenizer.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | from functools import lru_cache
4 | from typing import Optional
5 | from whisper.tokenizer import Tokenizer
6 |
7 | import tiktoken
8 |
9 | LANGUAGES = {
10 | "en": "english",
11 | "zh": "chinese",
12 | "de": "german",
13 | "es": "spanish",
14 | "ru": "russian",
15 | "ko": "korean",
16 | "fr": "french",
17 | "ja": "japanese",
18 | "pt": "portuguese",
19 | "tr": "turkish",
20 | "pl": "polish",
21 | "ca": "catalan",
22 | "nl": "dutch",
23 | "ar": "arabic",
24 | "sv": "swedish",
25 | "it": "italian",
26 | "id": "indonesian",
27 | "hi": "hindi",
28 | "fi": "finnish",
29 | "vi": "vietnamese",
30 | "he": "hebrew",
31 | "uk": "ukrainian",
32 | "el": "greek",
33 | "ms": "malay",
34 | "cs": "czech",
35 | "ro": "romanian",
36 | "da": "danish",
37 | "hu": "hungarian",
38 | "ta": "tamil",
39 | "no": "norwegian",
40 | "th": "thai",
41 | "ur": "urdu",
42 | "hr": "croatian",
43 | "bg": "bulgarian",
44 | "lt": "lithuanian",
45 | "la": "latin",
46 | "mi": "maori",
47 | "ml": "malayalam",
48 | "cy": "welsh",
49 | "sk": "slovak",
50 | "te": "telugu",
51 | "fa": "persian",
52 | "lv": "latvian",
53 | "bn": "bengali",
54 | "sr": "serbian",
55 | "az": "azerbaijani",
56 | "sl": "slovenian",
57 | "kn": "kannada",
58 | "et": "estonian",
59 | "mk": "macedonian",
60 | "br": "breton",
61 | "eu": "basque",
62 | "is": "icelandic",
63 | "hy": "armenian",
64 | "ne": "nepali",
65 | "mn": "mongolian",
66 | "bs": "bosnian",
67 | "kk": "kazakh",
68 | "sq": "albanian",
69 | "sw": "swahili",
70 | "gl": "galician",
71 | "mr": "marathi",
72 | "pa": "punjabi",
73 | "si": "sinhala",
74 | "km": "khmer",
75 | "sn": "shona",
76 | "yo": "yoruba",
77 | "so": "somali",
78 | "af": "afrikaans",
79 | "oc": "occitan",
80 | "ka": "georgian",
81 | "be": "belarusian",
82 | "tg": "tajik",
83 | "sd": "sindhi",
84 | "gu": "gujarati",
85 | "am": "amharic",
86 | "yi": "yiddish",
87 | "lo": "lao",
88 | "uz": "uzbek",
89 | "fo": "faroese",
90 | "ht": "haitian creole",
91 | "ps": "pashto",
92 | "tk": "turkmen",
93 | "nn": "nynorsk",
94 | "mt": "maltese",
95 | "sa": "sanskrit",
96 | "lb": "luxembourgish",
97 | "my": "myanmar",
98 | "bo": "tibetan",
99 | "tl": "tagalog",
100 | "mg": "malagasy",
101 | "as": "assamese",
102 | "tt": "tatar",
103 | "haw": "hawaiian",
104 | "ln": "lingala",
105 | "ha": "hausa",
106 | "ba": "bashkir",
107 | "jw": "javanese",
108 | "su": "sundanese",
109 | "yue": "cantonese",
110 | "minnan": "minnan",
111 | "wuyu": "wuyu",
112 | "dialect": "dialect",
113 | "zh/en": "zh/en",
114 | "en/zh": "en/zh",
115 | }
116 |
117 | # language code lookup by name, with a few language aliases
118 | TO_LANGUAGE_CODE = {
119 | **{language: code for code, language in LANGUAGES.items()},
120 | "burmese": "my",
121 | "valencian": "ca",
122 | "flemish": "nl",
123 | "haitian": "ht",
124 | "letzeburgesch": "lb",
125 | "pushto": "ps",
126 | "panjabi": "pa",
127 | "moldavian": "ro",
128 | "moldovan": "ro",
129 | "sinhalese": "si",
130 | "castilian": "es",
131 | "mandarin": "zh",
132 | }
133 |
134 | AUDIO_EVENT = {
135 | "ASR": "ASR",
136 | "AED": "AED",
137 | "SER": "SER",
138 | "Speech": "Speech",
139 | "/Speech": "/Speech",
140 | "BGM": "BGM",
141 | "/BGM": "/BGM",
142 | "Laughter": "Laughter",
143 | "/Laughter": "/Laughter",
144 | "Applause": "Applause",
145 | "/Applause": "/Applause",
146 | }
147 |
148 | EMOTION = {
149 | "HAPPY": "HAPPY",
150 | "SAD": "SAD",
151 | "ANGRY": "ANGRY",
152 | "NEUTRAL": "NEUTRAL",
153 | }
154 |
155 | TTS_Vocal_Token = {
156 | "TTS/B": "TTS/B",
157 | "TTS/O": "TTS/O",
158 | "TTS/Q": "TTS/Q",
159 | "TTS/A": "TTS/A",
160 | "TTS/CO": "TTS/CO",
161 | "TTS/CL": "TTS/CL",
162 | "TTS/H": "TTS/H",
163 | **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
164 | }
165 |
166 |
167 | @lru_cache(maxsize=None)
168 | def get_encoding(name: str = "gpt2", num_languages: int = 99):
169 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
170 | ranks = {
171 | base64.b64decode(token): int(rank)
172 | for token, rank in (line.split() for line in open(vocab_path) if line)
173 | }
174 | n_vocab = len(ranks)
175 | special_tokens = {}
176 |
177 | specials = [
178 | "<|endoftext|>",
179 | "<|startoftranscript|>",
180 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
181 | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
182 | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
183 | "<|translate|>",
184 | "<|transcribe|>",
185 | "<|startoflm|>",
186 | "<|startofprev|>",
187 | "<|nospeech|>",
188 | "<|notimestamps|>",
189 | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
190 | *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
191 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
192 | ]
193 |
194 | for token in specials:
195 | special_tokens[token] = n_vocab
196 | n_vocab += 1
197 |
198 | return tiktoken.Encoding(
199 | name=os.path.basename(vocab_path),
200 | explicit_n_vocab=n_vocab,
201 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
202 | mergeable_ranks=ranks,
203 | special_tokens=special_tokens,
204 | )
205 |
206 |
207 | @lru_cache(maxsize=None)
208 | def get_tokenizer(
209 | multilingual: bool,
210 | *,
211 | num_languages: int = 99,
212 | language: Optional[str] = None,
213 | task: Optional[str] = None, # Literal["transcribe", "translate", None]
214 | ) -> Tokenizer:
215 | if language is not None:
216 | language = language.lower()
217 | if language not in LANGUAGES:
218 | if language in TO_LANGUAGE_CODE:
219 | language = TO_LANGUAGE_CODE[language]
220 | else:
221 | raise ValueError(f"Unsupported language: {language}")
222 |
223 | if multilingual:
224 | encoding_name = "multilingual_zh_ja_yue_char_del"
225 | language = language or "en"
226 | task = task or "transcribe"
227 | else:
228 | encoding_name = "gpt2"
229 | language = None
230 | task = None
231 |
232 | encoding = get_encoding(name=encoding_name, num_languages=num_languages)
233 |
234 | return Tokenizer(
235 | encoding=encoding, num_languages=num_languages, language=language, task=task
236 | )
237 |
--------------------------------------------------------------------------------
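get_encoding appends every special token after the base vocabulary, each taking the next free id. A minimal sketch of that assignment with made-up numbers (the real rank count comes from the .tiktoken file):

    base_vocab_size = 50000                        # hypothetical rank count
    specials = ["<|endoftext|>", "<|startoftranscript|>", "<|zh|>", "<|HAPPY|>"]
    special_tokens = {tok: base_vocab_size + i for i, tok in enumerate(specials)}
    print(special_tokens["<|zh|>"])                # 50002
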
/cosyvoice/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/transformer/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/transformer/activation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3 | # 2020 Mobvoi Inc (Binbin Zhang)
4 | # 2024 Alibaba Inc (Xiang Lyu)
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """Swish() activation function for Conformer."""
18 |
19 | import torch
20 | from torch import nn, sin, pow
21 | from torch.nn import Parameter
22 |
23 |
24 | class Swish(torch.nn.Module):
25 |     """Construct a Swish object."""
26 |
27 | def forward(self, x: torch.Tensor) -> torch.Tensor:
28 | """Return Swish activation function."""
29 | return x * torch.sigmoid(x)
30 |
31 |
32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33 | # LICENSE is in incl_licenses directory.
34 | class Snake(nn.Module):
35 | '''
36 | Implementation of a sine-based periodic activation function
37 | Shape:
38 | - Input: (B, C, T)
39 | - Output: (B, C, T), same shape as the input
40 | Parameters:
41 | - alpha - trainable parameter
42 | References:
43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44 | https://arxiv.org/abs/2006.08195
45 | Examples:
46 |         >>> a1 = Snake(256)
47 |         >>> x = torch.randn(1, 256, 50)
48 |         >>> x = a1(x)
49 | '''
50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51 | '''
52 | Initialization.
53 | INPUT:
54 | - in_features: shape of the input
55 | - alpha: trainable parameter
56 | alpha is initialized to 1 by default, higher values = higher-frequency.
57 | alpha will be trained along with the rest of your model.
58 | '''
59 | super(Snake, self).__init__()
60 | self.in_features = in_features
61 |
62 | # initialize alpha
63 | self.alpha_logscale = alpha_logscale
64 | if self.alpha_logscale: # log scale alphas initialized to zeros
65 | self.alpha = Parameter(torch.zeros(in_features) * alpha)
66 | else: # linear scale alphas initialized to ones
67 | self.alpha = Parameter(torch.ones(in_features) * alpha)
68 |
69 | self.alpha.requires_grad = alpha_trainable
70 |
71 | self.no_div_by_zero = 0.000000001
72 |
73 | def forward(self, x):
74 | '''
75 | Forward pass of the function.
76 | Applies the function to the input elementwise.
77 |         Snake(x) := x + (1/alpha) * sin^2(alpha * x)
78 | '''
79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80 | if self.alpha_logscale:
81 | alpha = torch.exp(alpha)
82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83 |
84 | return x
85 |
--------------------------------------------------------------------------------
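The Snake formula above can be verified elementwise. A standalone sanity check with alpha = 1 (toy values, not a test from the repo):

    import torch

    x = torch.linspace(-3, 3, 7)
    alpha = 1.0
    y = x + (1.0 / alpha) * torch.sin(alpha * x) ** 2
    print(torch.allclose(y, x + torch.sin(x) ** 2))    # True
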
/cosyvoice/transformer/convolution.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2 | # 2024 Alibaba Inc (Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # Modified from ESPnet(https://github.com/espnet/espnet)
16 | """ConvolutionModule definition."""
17 |
18 | from typing import Tuple
19 |
20 | import torch
21 | from torch import nn
22 |
23 |
24 | class ConvolutionModule(nn.Module):
25 | """ConvolutionModule in Conformer model."""
26 |
27 | def __init__(self,
28 | channels: int,
29 | kernel_size: int = 15,
30 | activation: nn.Module = nn.ReLU(),
31 | norm: str = "batch_norm",
32 | causal: bool = False,
33 | bias: bool = True):
34 |         """Construct a ConvolutionModule object.
35 | Args:
36 | channels (int): The number of channels of conv layers.
37 | kernel_size (int): Kernel size of conv layers.
38 |             causal (bool): Whether to use causal convolution or not
39 | """
40 | super().__init__()
41 |
42 | self.pointwise_conv1 = nn.Conv1d(
43 | channels,
44 | 2 * channels,
45 | kernel_size=1,
46 | stride=1,
47 | padding=0,
48 | bias=bias,
49 | )
50 | # self.lorder is used to distinguish if it's a causal convolution,
51 | # if self.lorder > 0: it's a causal convolution, the input will be
52 | # padded with self.lorder frames on the left in forward.
53 | # else: it's a symmetrical convolution
54 | if causal:
55 | padding = 0
56 | self.lorder = kernel_size - 1
57 | else:
58 |             # kernel_size should be an odd number for non-causal convolution
59 | assert (kernel_size - 1) % 2 == 0
60 | padding = (kernel_size - 1) // 2
61 | self.lorder = 0
62 | self.depthwise_conv = nn.Conv1d(
63 | channels,
64 | channels,
65 | kernel_size,
66 | stride=1,
67 | padding=padding,
68 | groups=channels,
69 | bias=bias,
70 | )
71 |
72 | assert norm in ['batch_norm', 'layer_norm']
73 | if norm == "batch_norm":
74 | self.use_layer_norm = False
75 | self.norm = nn.BatchNorm1d(channels)
76 | else:
77 | self.use_layer_norm = True
78 | self.norm = nn.LayerNorm(channels)
79 |
80 | self.pointwise_conv2 = nn.Conv1d(
81 | channels,
82 | channels,
83 | kernel_size=1,
84 | stride=1,
85 | padding=0,
86 | bias=bias,
87 | )
88 | self.activation = activation
89 |
90 | def forward(
91 | self,
92 | x: torch.Tensor,
93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94 | cache: torch.Tensor = torch.zeros((0, 0, 0)),
95 | ) -> Tuple[torch.Tensor, torch.Tensor]:
96 | """Compute convolution module.
97 | Args:
98 | x (torch.Tensor): Input tensor (#batch, time, channels).
99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100 | (0, 0, 0) means fake mask.
101 | cache (torch.Tensor): left context cache, it is only
102 | used in causal convolution (#batch, channels, cache_t),
103 |                 (0, 0, 0) means fake cache.
104 | Returns:
105 | torch.Tensor: Output tensor (#batch, time, channels).
106 | """
107 | # exchange the temporal dimension and the feature dimension
108 | x = x.transpose(1, 2) # (#batch, channels, time)
109 |
110 | # mask batch padding
111 | if mask_pad.size(2) > 0: # time > 0
112 | x.masked_fill_(~mask_pad, 0.0)
113 |
114 | if self.lorder > 0:
115 | if cache.size(2) == 0: # cache_t == 0
116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117 | else:
118 | assert cache.size(0) == x.size(0) # equal batch
119 | assert cache.size(1) == x.size(1) # equal channel
120 | x = torch.cat((cache, x), dim=2)
121 | assert (x.size(2) > self.lorder)
122 | new_cache = x[:, :, -self.lorder:]
123 | else:
124 |             # It would be better to return None when no cache is required,
125 |             # but for JIT export we return an empty placeholder tensor
126 |             # instead of None.
127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128 |
129 | # GLU mechanism
130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132 |
133 | # 1D Depthwise Conv
134 | x = self.depthwise_conv(x)
135 | if self.use_layer_norm:
136 | x = x.transpose(1, 2)
137 | x = self.activation(self.norm(x))
138 | if self.use_layer_norm:
139 | x = x.transpose(1, 2)
140 | x = self.pointwise_conv2(x)
141 | # mask batch padding
142 | if mask_pad.size(2) > 0: # time > 0
143 | x.masked_fill_(~mask_pad, 0.0)
144 |
145 | return x.transpose(1, 2), new_cache
146 |
--------------------------------------------------------------------------------
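In the causal branch above, lorder = kernel_size - 1 frames of left padding keep the depthwise conv from looking ahead while preserving the sequence length. A toy sketch (shapes are made up):

    import torch
    import torch.nn as nn

    kernel_size, channels, t_len = 15, 4, 20
    lorder = kernel_size - 1
    x = torch.randn(1, channels, t_len)
    x = nn.functional.pad(x, (lorder, 0), 'constant', 0.0)     # left-pad only
    depthwise = nn.Conv1d(channels, channels, kernel_size, groups=channels)
    print(depthwise(x).shape)                                  # torch.Size([1, 4, 20])
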
/cosyvoice/transformer/decoder_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Shigeki Karita
2 | # 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Decoder self-attention layer definition."""
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 | from torch import nn
20 |
21 |
22 | class DecoderLayer(nn.Module):
23 | """Single decoder layer module.
24 |
25 | Args:
26 | size (int): Input dimension.
27 | self_attn (torch.nn.Module): Self-attention module instance.
28 | `MultiHeadedAttention` instance can be used as the argument.
29 | src_attn (torch.nn.Module): Inter-attention module instance.
30 | `MultiHeadedAttention` instance can be used as the argument.
31 |             If `None` is passed, inter-attention is not used, as in
32 |             CIF, GPT, and other decoder-only models.
33 | feed_forward (torch.nn.Module): Feed-forward module instance.
34 | `PositionwiseFeedForward` instance can be used as the argument.
35 | dropout_rate (float): Dropout rate.
36 | normalize_before (bool):
37 | True: use layer_norm before each sub-block.
38 |             False: use layer_norm after each sub-block.
39 | """
40 |
41 | def __init__(
42 | self,
43 | size: int,
44 | self_attn: nn.Module,
45 | src_attn: Optional[nn.Module],
46 | feed_forward: nn.Module,
47 | dropout_rate: float,
48 | normalize_before: bool = True,
49 | ):
50 |         """Construct a DecoderLayer object."""
51 | super().__init__()
52 | self.size = size
53 | self.self_attn = self_attn
54 | self.src_attn = src_attn
55 | self.feed_forward = feed_forward
56 | self.norm1 = nn.LayerNorm(size, eps=1e-5)
57 | self.norm2 = nn.LayerNorm(size, eps=1e-5)
58 | self.norm3 = nn.LayerNorm(size, eps=1e-5)
59 | self.dropout = nn.Dropout(dropout_rate)
60 | self.normalize_before = normalize_before
61 |
62 | def forward(
63 | self,
64 | tgt: torch.Tensor,
65 | tgt_mask: torch.Tensor,
66 | memory: torch.Tensor,
67 | memory_mask: torch.Tensor,
68 | cache: Optional[torch.Tensor] = None
69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70 | """Compute decoded features.
71 |
72 | Args:
73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74 | tgt_mask (torch.Tensor): Mask for input tensor
75 | (#batch, maxlen_out).
76 | memory (torch.Tensor): Encoded memory
77 | (#batch, maxlen_in, size).
78 | memory_mask (torch.Tensor): Encoded memory mask
79 | (#batch, maxlen_in).
80 | cache (torch.Tensor): cached tensors.
81 | (#batch, maxlen_out - 1, size).
82 |
83 | Returns:
84 | torch.Tensor: Output tensor (#batch, maxlen_out, size).
85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88 |
89 | """
90 | residual = tgt
91 | if self.normalize_before:
92 | tgt = self.norm1(tgt)
93 |
94 | if cache is None:
95 | tgt_q = tgt
96 | tgt_q_mask = tgt_mask
97 | else:
98 | # compute only the last frame query keeping dim: max_time_out -> 1
99 | assert cache.shape == (
100 | tgt.shape[0],
101 | tgt.shape[1] - 1,
102 | self.size,
103 |             ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104 | tgt_q = tgt[:, -1:, :]
105 | residual = residual[:, -1:, :]
106 | tgt_q_mask = tgt_mask[:, -1:, :]
107 |
108 | x = residual + self.dropout(
109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
110 | if not self.normalize_before:
111 | x = self.norm1(x)
112 |
113 | if self.src_attn is not None:
114 | residual = x
115 | if self.normalize_before:
116 | x = self.norm2(x)
117 | x = residual + self.dropout(
118 | self.src_attn(x, memory, memory, memory_mask)[0])
119 | if not self.normalize_before:
120 | x = self.norm2(x)
121 |
122 | residual = x
123 | if self.normalize_before:
124 | x = self.norm3(x)
125 | x = residual + self.dropout(self.feed_forward(x))
126 | if not self.normalize_before:
127 | x = self.norm3(x)
128 |
129 | if cache is not None:
130 | x = torch.cat([cache, x], dim=1)
131 |
132 | return x, tgt_mask, memory, memory_mask
133 |
--------------------------------------------------------------------------------
/cosyvoice/transformer/label_smoothing_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Shigeki Karita
2 | # 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Label smoothing module."""
16 |
17 | import torch
18 | from torch import nn
19 |
20 |
21 | class LabelSmoothingLoss(nn.Module):
22 | """Label-smoothing loss.
23 |
24 | In a standard CE loss, the label's data distribution is:
25 | [0,1,2] ->
26 | [
27 | [1.0, 0.0, 0.0],
28 | [0.0, 1.0, 0.0],
29 | [0.0, 0.0, 1.0],
30 | ]
31 |
32 |     In the label-smoothing version of CE loss, some probability mass
33 |     is taken from the true label prob (1.0) and divided
34 |     among the other labels.
35 |
36 | e.g.
37 | smoothing=0.1
38 | [0,1,2] ->
39 | [
40 | [0.9, 0.05, 0.05],
41 | [0.05, 0.9, 0.05],
42 | [0.05, 0.05, 0.9],
43 | ]
44 |
45 | Args:
46 |         size (int): the number of classes
47 | padding_idx (int): padding class id which will be ignored for loss
48 | smoothing (float): smoothing rate (0.0 means the conventional CE)
49 | normalize_length (bool):
50 | normalize loss by sequence length if True
51 | normalize loss by batch size if False
52 | """
53 |
54 | def __init__(self,
55 | size: int,
56 | padding_idx: int,
57 | smoothing: float,
58 | normalize_length: bool = False):
59 |         """Construct a LabelSmoothingLoss object."""
60 | super(LabelSmoothingLoss, self).__init__()
61 | self.criterion = nn.KLDivLoss(reduction="none")
62 | self.padding_idx = padding_idx
63 | self.confidence = 1.0 - smoothing
64 | self.smoothing = smoothing
65 | self.size = size
66 | self.normalize_length = normalize_length
67 |
68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
69 | """Compute loss between x and target.
70 |
71 |         The model output and data label tensors are flattened to
72 |         (batch*seqlen, class) shape, and a mask is applied to the
73 |         padding positions, which are excluded from the loss.
74 |
75 | Args:
76 | x (torch.Tensor): prediction (batch, seqlen, class)
77 | target (torch.Tensor):
78 | target signal masked with self.padding_id (batch, seqlen)
79 | Returns:
80 | loss (torch.Tensor) : The KL loss, scalar float value
81 | """
82 | assert x.size(2) == self.size
83 | batch_size = x.size(0)
84 | x = x.view(-1, self.size)
85 | target = target.view(-1)
86 | # use zeros_like instead of torch.no_grad() for true_dist,
87 | # since no_grad() can not be exported by JIT
88 | true_dist = torch.zeros_like(x)
89 | true_dist.fill_(self.smoothing / (self.size - 1))
90 | ignore = target == self.padding_idx # (B,)
91 | total = len(target) - ignore.sum().item()
92 | target = target.masked_fill(ignore, 0) # avoid -1 index
93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
95 | denom = total if self.normalize_length else batch_size
96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
97 |
--------------------------------------------------------------------------------
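The smoothed target distribution built in forward matches the docstring example. A standalone reconstruction for size=3 and smoothing=0.1:

    import torch

    size, smoothing = 3, 0.1
    target = torch.tensor([0, 1, 2])
    true_dist = torch.full((len(target), size), smoothing / (size - 1))
    true_dist.scatter_(1, target.unsqueeze(1), 1.0 - smoothing)
    print(true_dist)
    # tensor([[0.9000, 0.0500, 0.0500],
    #         [0.0500, 0.9000, 0.0500],
    #         [0.0500, 0.0500, 0.9000]])
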
/cosyvoice/transformer/positionwise_feed_forward.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Shigeki Karita
2 | # 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Positionwise feed forward layer definition."""
16 |
17 | import torch
18 |
19 |
20 | class PositionwiseFeedForward(torch.nn.Module):
21 | """Positionwise feed forward layer.
22 |
23 |     The feed-forward layers are applied at each position of the sequence.
24 |     The output dim is the same as the input dim.
25 |
26 | Args:
27 |         idim (int): Input dimension.
28 | hidden_units (int): The number of hidden units.
29 | dropout_rate (float): Dropout rate.
30 | activation (torch.nn.Module): Activation function
31 | """
32 |
33 | def __init__(
34 | self,
35 | idim: int,
36 | hidden_units: int,
37 | dropout_rate: float,
38 | activation: torch.nn.Module = torch.nn.ReLU(),
39 | ):
40 | """Construct a PositionwiseFeedForward object."""
41 | super(PositionwiseFeedForward, self).__init__()
42 | self.w_1 = torch.nn.Linear(idim, hidden_units)
43 | self.activation = activation
44 | self.dropout = torch.nn.Dropout(dropout_rate)
45 | self.w_2 = torch.nn.Linear(hidden_units, idim)
46 |
47 | def forward(self, xs: torch.Tensor) -> torch.Tensor:
48 | """Forward function.
49 |
50 | Args:
51 | xs: input tensor (B, L, D)
52 | Returns:
53 | output tensor, (B, L, D)
54 | """
55 | return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56 |
57 |
58 | class MoEFFNLayer(torch.nn.Module):
59 | """
60 |     Mixture-of-experts positionwise feed forward layer.
61 |     See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62 |     The output dim is the same as the input dim.
63 |
64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66 | Args:
67 | n_expert: number of expert.
68 | n_expert_per_token: The actual number of experts used for each frame
69 |         idim (int): Input dimension.
70 | hidden_units (int): The number of hidden units.
71 | dropout_rate (float): Dropout rate.
72 | activation (torch.nn.Module): Activation function
73 | """
74 |
75 | def __init__(
76 | self,
77 | n_expert: int,
78 | n_expert_per_token: int,
79 | idim: int,
80 | hidden_units: int,
81 | dropout_rate: float,
82 | activation: torch.nn.Module = torch.nn.ReLU(),
83 | ):
84 | super(MoEFFNLayer, self).__init__()
85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86 | self.experts = torch.nn.ModuleList(
87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88 | activation) for _ in range(n_expert))
89 | self.n_expert_per_token = n_expert_per_token
90 |
91 | def forward(self, xs: torch.Tensor) -> torch.Tensor:
92 |         """Forward function.
93 | Args:
94 | xs: input tensor (B, L, D)
95 | Returns:
96 | output tensor, (B, L, D)
97 |
98 | """
99 |         # batch size, sequence length, embedding dimension (idim)
100 |         B, L, D = xs.size()
101 | xs = xs.view(-1, D) # (B*L, D)
102 | router = self.gate(xs) # (B*L, n_expert)
103 | logits, indices = torch.topk(
104 | router, self.n_expert_per_token
105 |         )  # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token)
106 | weights = torch.nn.functional.softmax(
107 | logits, dim=1,
108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109 | output = torch.zeros_like(xs) # (B*L, D)
110 | for i, expert in enumerate(self.experts):
111 | mask = indices == i
112 | batch_idx, ith_expert = torch.where(mask)
113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114 | xs[batch_idx])
115 | return output.view(B, L, D)
116 |
--------------------------------------------------------------------------------
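The router step in MoEFFNLayer.forward picks the top n_expert_per_token experts per frame and softmax-normalizes their logits into mixture weights. A toy sketch with made-up sizes:

    import torch

    n_expert, n_expert_per_token = 4, 2
    router = torch.randn(6, n_expert)              # (B*L, n_expert) gate outputs
    logits, indices = torch.topk(router, n_expert_per_token)
    weights = torch.softmax(logits, dim=1)         # (B*L, n_expert_per_token)
    print(indices.shape, weights.sum(dim=1))       # torch.Size([6, 2]); each row of weights sums to 1
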
/cosyvoice/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/cosyvoice/utils/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/utils/class_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright [2023-11-28]
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import torch
16 |
17 | from cosyvoice.transformer.activation import Swish
18 | from cosyvoice.transformer.subsampling import (
19 | LinearNoSubsampling,
20 | EmbedinigNoSubsampling,
21 | Conv1dSubsampling2,
22 | Conv2dSubsampling4,
23 | Conv2dSubsampling6,
24 | Conv2dSubsampling8,
25 | )
26 | from cosyvoice.transformer.embedding import (PositionalEncoding,
27 | RelPositionalEncoding,
28 | WhisperPositionalEncoding,
29 | LearnablePositionalEncoding,
30 | NoPositionalEncoding)
31 | from cosyvoice.transformer.attention import (MultiHeadedAttention,
32 | RelPositionMultiHeadedAttention)
33 | from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
34 | from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
35 |
36 |
37 | COSYVOICE_ACTIVATION_CLASSES = {
38 | "hardtanh": torch.nn.Hardtanh,
39 | "tanh": torch.nn.Tanh,
40 | "relu": torch.nn.ReLU,
41 | "selu": torch.nn.SELU,
42 | "swish": getattr(torch.nn, "SiLU", Swish),
43 | "gelu": torch.nn.GELU,
44 | }
45 |
46 | COSYVOICE_SUBSAMPLE_CLASSES = {
47 | "linear": LinearNoSubsampling,
48 | "linear_legacy": LegacyLinearNoSubsampling,
49 | "embed": EmbedinigNoSubsampling,
50 | "conv1d2": Conv1dSubsampling2,
51 | "conv2d": Conv2dSubsampling4,
52 | "conv2d6": Conv2dSubsampling6,
53 | "conv2d8": Conv2dSubsampling8,
54 | 'paraformer_dummy': torch.nn.Identity
55 | }
56 |
57 | COSYVOICE_EMB_CLASSES = {
58 | "embed": PositionalEncoding,
59 | "abs_pos": PositionalEncoding,
60 | "rel_pos": RelPositionalEncoding,
61 | "rel_pos_espnet": EspnetRelPositionalEncoding,
62 | "no_pos": NoPositionalEncoding,
63 | "abs_pos_whisper": WhisperPositionalEncoding,
64 | "embed_learnable_pe": LearnablePositionalEncoding,
65 | }
66 |
67 | COSYVOICE_ATTENTION_CLASSES = {
68 | "selfattn": MultiHeadedAttention,
69 | "rel_selfattn": RelPositionMultiHeadedAttention,
70 | }
71 |
--------------------------------------------------------------------------------
/cosyvoice/utils/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # Modified from ESPnet(https://github.com/espnet/espnet)
16 | """Utility functions for Transformer."""
17 |
18 | import random
19 | from typing import List
20 |
21 | import numpy as np
22 | import torch
23 |
24 | IGNORE_ID = -1
25 |
26 |
27 | def pad_list(xs: List[torch.Tensor], pad_value: int):
28 | """Perform padding for the list of tensors.
29 |
30 | Args:
31 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
32 | pad_value (float): Value for padding.
33 |
34 | Returns:
35 | Tensor: Padded tensor (B, Tmax, `*`).
36 |
37 | Examples:
38 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
39 | >>> x
40 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
41 | >>> pad_list(x, 0)
42 | tensor([[1., 1., 1., 1.],
43 | [1., 1., 0., 0.],
44 | [1., 0., 0., 0.]])
45 |
46 | """
47 | max_len = max([len(item) for item in xs])
48 | batchs = len(xs)
49 | ndim = xs[0].ndim
50 | if ndim == 1:
51 | pad_res = torch.zeros(batchs,
52 | max_len,
53 | dtype=xs[0].dtype,
54 | device=xs[0].device)
55 | elif ndim == 2:
56 | pad_res = torch.zeros(batchs,
57 | max_len,
58 | xs[0].shape[1],
59 | dtype=xs[0].dtype,
60 | device=xs[0].device)
61 | elif ndim == 3:
62 | pad_res = torch.zeros(batchs,
63 | max_len,
64 | xs[0].shape[1],
65 | xs[0].shape[2],
66 | dtype=xs[0].dtype,
67 | device=xs[0].device)
68 | else:
69 | raise ValueError(f"Unsupported ndim: {ndim}")
70 | pad_res.fill_(pad_value)
71 | for i in range(batchs):
72 | pad_res[i, :len(xs[i])] = xs[i]
73 | return pad_res
74 |
75 |
76 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
77 | ignore_label: int) -> torch.Tensor:
78 | """Calculate accuracy.
79 |
80 | Args:
81 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
82 | pad_targets (LongTensor): Target label tensors (B, Lmax).
83 | ignore_label (int): Ignore label id.
84 |
85 | Returns:
86 | torch.Tensor: Accuracy value (0.0 - 1.0).
87 |
88 | """
89 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
90 | pad_outputs.size(1)).argmax(2)
91 | mask = pad_targets != ignore_label
92 | numerator = torch.sum(
93 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
94 | denominator = torch.sum(mask)
95 | return (numerator / denominator).detach()
96 |
97 |
98 | def get_padding(kernel_size, dilation=1):
99 | return int((kernel_size * dilation - dilation) / 2)
100 |
101 |
102 | def init_weights(m, mean=0.0, std=0.01):
103 | classname = m.__class__.__name__
104 | if classname.find("Conv") != -1:
105 | m.weight.data.normal_(mean, std)
106 |
107 |
108 | # Repetition Aware Sampling in VALL-E 2
109 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
110 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
111 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
112 | if rep_num >= win_size * tau_r:
113 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
114 | return top_ids
115 |
116 |
117 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
118 | prob, indices = [], []
119 | cum_prob = 0.0
120 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
121 | for i in range(len(sorted_idx)):
122 |         # keep tokens until the top-p mass or the top_k count is reached.
123 | if cum_prob < top_p and len(prob) < top_k:
124 | cum_prob += sorted_value[i]
125 | prob.append(sorted_value[i])
126 | indices.append(sorted_idx[i])
127 | else:
128 | break
129 | prob = torch.tensor(prob).to(weighted_scores)
130 | indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
131 | top_ids = indices[prob.multinomial(1, replacement=True)]
132 | return top_ids
133 |
134 |
135 | def random_sampling(weighted_scores, decoded_tokens, sampling):
136 | top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
137 | return top_ids
138 |
139 |
140 | def fade_in_out(fade_in_mel, fade_out_mel, window):
141 | device = fade_in_mel.device
142 | fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
143 | mel_overlap_len = int(window.shape[0] / 2)
144 | fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
145 | fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
146 | return fade_in_mel.to(device)
147 |
148 |
149 | def set_all_random_seed(seed):
150 | random.seed(seed)
151 | np.random.seed(seed)
152 | torch.manual_seed(seed)
153 | torch.cuda.manual_seed_all(seed)
154 |
--------------------------------------------------------------------------------
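nucleus_sampling above keeps sorted probabilities until either the cumulative mass reaches top_p or top_k tokens have been collected; ras_sampling falls back to plain random sampling when the chosen id repeats too often in the recent history. A standalone sketch of the top-p loop on a toy logit vector:

    import torch

    torch.manual_seed(0)
    weighted_scores = torch.randn(16)              # toy logits for 16 tokens
    top_p, top_k = 0.8, 25
    sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
    keep, cum_prob = [], 0.0
    for p, idx in zip(sorted_value, sorted_idx):
        if cum_prob >= top_p or len(keep) >= top_k:
            break
        cum_prob += p.item()
        keep.append(idx.item())
    print(keep)                                    # candidate ids to sample from
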
/cosyvoice/utils/executor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | from contextlib import nullcontext
18 | import os
19 |
20 | import torch
21 | import torch.distributed as dist
22 |
23 | from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
24 |
25 |
26 | class Executor:
27 |
28 | def __init__(self):
29 | self.step = 0
30 | self.epoch = 0
31 | self.rank = int(os.environ.get('RANK', 0))
32 | self.device = torch.device('cuda:{}'.format(self.rank))
33 |
34 | def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join):
35 | ''' Train one epoch
36 | '''
37 |
38 | lr = optimizer.param_groups[0]['lr']
39 | logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
40 |         logging.info('using gradient accumulation, effective batch size is {} times'
41 |                      ' larger than before'.format(info_dict['accum_grad']))
42 | # A context manager to be used in conjunction with an instance of
43 | # torch.nn.parallel.DistributedDataParallel to be able to train
44 | # with uneven inputs across participating processes.
45 | model.train()
46 | model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
47 | with model_context():
48 | for batch_idx, batch_dict in enumerate(train_data_loader):
49 | info_dict["tag"] = "TRAIN"
50 | info_dict["step"] = self.step
51 | info_dict["epoch"] = self.epoch
52 | info_dict["batch_idx"] = batch_idx
53 | if cosyvoice_join(group_join, info_dict):
54 | break
55 |
56 | # Disable gradient synchronizations across DDP processes.
57 | # Within this context, gradients will be accumulated on module
58 | # variables, which will later be synchronized.
59 | if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
60 | context = model.no_sync
61 | # Used for single gpu training and DDP gradient synchronization
62 | # processes.
63 | else:
64 | context = nullcontext
65 |
66 | with context():
67 | info_dict = batch_forward(model, batch_dict, info_dict)
68 | info_dict = batch_backward(model, info_dict)
69 |
70 | info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
71 | log_per_step(writer, info_dict)
72 | # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
73 | if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
74 | (batch_idx + 1) % info_dict["accum_grad"] == 0:
75 | dist.barrier()
76 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
77 | model.train()
78 | if (batch_idx + 1) % info_dict["accum_grad"] == 0:
79 | self.step += 1
80 | dist.barrier()
81 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
82 |
83 | @torch.inference_mode()
84 | def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
85 |         ''' Cross validation
86 | '''
87 | logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
88 | model.eval()
89 | total_num_utts, total_loss_dict = 0, {} # avoid division by 0
90 | for batch_idx, batch_dict in enumerate(cv_data_loader):
91 | info_dict["tag"] = "CV"
92 | info_dict["step"] = self.step
93 | info_dict["epoch"] = self.epoch
94 | info_dict["batch_idx"] = batch_idx
95 |
96 | num_utts = len(batch_dict["utts"])
97 | total_num_utts += num_utts
98 |
99 | info_dict = batch_forward(model, batch_dict, info_dict)
100 |
101 | for k, v in info_dict['loss_dict'].items():
102 | if k not in total_loss_dict:
103 | total_loss_dict[k] = []
104 | total_loss_dict[k].append(v.item() * num_utts)
105 | log_per_step(None, info_dict)
106 | for k, v in total_loss_dict.items():
107 | total_loss_dict[k] = sum(v) / total_num_utts
108 | info_dict['loss_dict'] = total_loss_dict
109 | log_per_save(writer, info_dict)
110 | model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
111 | save_model(model, model_name, info_dict)
112 |
--------------------------------------------------------------------------------
/cosyvoice/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import json
17 | import torchaudio
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | logging.basicConfig(level=logging.DEBUG,
21 | format='%(asctime)s %(levelname)s %(message)s')
22 |
23 |
24 | def read_lists(list_file):
25 | lists = []
26 | with open(list_file, 'r', encoding='utf8') as fin:
27 | for line in fin:
28 | lists.append(line.strip())
29 | return lists
30 |
31 |
32 | def read_json_lists(list_file):
33 | lists = read_lists(list_file)
34 | results = {}
35 | for fn in lists:
36 | with open(fn, 'r', encoding='utf8') as fin:
37 | results.update(json.load(fin))
38 | return results
39 |
40 |
41 | def load_wav(wav, target_sr):
42 | speech, sample_rate = torchaudio.load(wav)
43 | speech = speech.mean(dim=0, keepdim=True)
44 | if sample_rate != target_sr:
45 | assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
46 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
47 | return speech
48 |
--------------------------------------------------------------------------------
/cosyvoice/utils/frontend_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
17 |
18 |
19 | # whether the text contains Chinese characters
20 | def contains_chinese(text):
21 | return bool(chinese_char_pattern.search(text))
22 |
23 |
24 | # replace special symbol
25 | def replace_corner_mark(text):
26 | text = text.replace('²', '平方')
27 | text = text.replace('³', '立方')
28 | return text
29 |
30 |
31 | # remove meaningless symbol
32 | def remove_bracket(text):
33 | text = text.replace('(', '').replace(')', '')
34 | text = text.replace('【', '').replace('】', '')
35 | text = text.replace('`', '').replace('`', '')
36 | text = text.replace("——", " ")
37 | return text
38 |
39 |
40 | # spell out Arabic numerals
41 | def spell_out_number(text: str, inflect_parser):
42 | new_text = []
43 | st = None
44 | for i, c in enumerate(text):
45 | if not c.isdigit():
46 | if st is not None:
47 | num_str = inflect_parser.number_to_words(text[st: i])
48 | new_text.append(num_str)
49 | st = None
50 | new_text.append(c)
51 | else:
52 | if st is None:
53 | st = i
54 | if st is not None and st < len(text):
55 | num_str = inflect_parser.number_to_words(text[st:])
56 | new_text.append(num_str)
57 | return ''.join(new_text)
58 |
59 |
60 | # split paragraph logic:
61 | # 1. each sentence has max len token_max_n and min len token_min_n; merge if the last sentence is shorter than merge_len
62 | # 2. calculate sentence length according to lang
63 | # 3. split sentences according to punctuation
64 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
65 | def calc_utt_length(_text: str):
66 | if lang == "zh":
67 | return len(_text)
68 | else:
69 | return len(tokenize(_text))
70 |
71 | def should_merge(_text: str):
72 | if lang == "zh":
73 | return len(_text) < merge_len
74 | else:
75 | return len(tokenize(_text)) < merge_len
76 |
77 | if lang == "zh":
78 | pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
79 | else:
80 | pounc = ['.', '?', '!', ';', ':']
81 | if comma_split:
82 | pounc.extend([',', ','])
83 |
84 | if text[-1] not in pounc:
85 | if lang == "zh":
86 | text += "。"
87 | else:
88 | text += "."
89 |
90 | st = 0
91 | utts = []
92 | for i, c in enumerate(text):
93 | if c in pounc:
94 | if len(text[st: i]) > 0:
95 | utts.append(text[st: i] + c)
96 | if i + 1 < len(text) and text[i + 1] in ['"', '”']:
97 | tmp = utts.pop(-1)
98 | utts.append(tmp + text[i + 1])
99 | st = i + 2
100 | else:
101 | st = i + 1
102 |
103 | final_utts = []
104 | cur_utt = ""
105 | for utt in utts:
106 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
107 | final_utts.append(cur_utt)
108 | cur_utt = ""
109 | cur_utt = cur_utt + utt
110 | if len(cur_utt) > 0:
111 | if should_merge(cur_utt) and len(final_utts) != 0:
112 | final_utts[-1] = final_utts[-1] + cur_utt
113 | else:
114 | final_utts.append(cur_utt)
115 |
116 | return final_utts
117 |
118 |
119 | # remove blanks between Chinese characters
120 | def replace_blank(text: str):
121 | out_str = []
122 | for i, c in enumerate(text):
123 | if c == " ":
124 | if ((text[i + 1].isascii() and text[i + 1] != " ") and
125 | (text[i - 1].isascii() and text[i - 1] != " ")):
126 | out_str.append(c)
127 | else:
128 | out_str.append(c)
129 | return "".join(out_str)
130 |
--------------------------------------------------------------------------------
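A hypothetical call of split_paragraph (assuming the cosyvoice package is on the Python path); str.split stands in for the real tokenizer and the length limits are made up:

    from cosyvoice.utils.frontend_utils import split_paragraph

    text = ("CosyVoice splits long inputs. Each sentence becomes one chunk. "
            "Short tails are merged.")
    chunks = split_paragraph(text, tokenize=str.split, lang="en",
                             token_max_n=12, token_min_n=4, merge_len=2)
    print(chunks)                                  # a short list of sentence chunks
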
/funaudio_utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/funaudio_utils/__init__.py
--------------------------------------------------------------------------------
/funaudio_utils/cosyvoice_plus.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: SpenserCai
3 | Date: 2024-10-04 14:21:08
4 | version:
5 | LastEditors: SpenserCai
6 | LastEditTime: 2024-10-04 16:07:20
7 | Description: file content
8 | '''
9 | from cosyvoice.cli.cosyvoice import CosyVoice
10 | from cosyvoice.utils.file_utils import logging
11 | from tqdm import tqdm
12 | import time
13 |
14 | class CosyVoicePlus(CosyVoice):
15 |
16 |     def inference_zero_shot_with_spkmodel(self, tts_text, spkmodel, stream=False, speed=1.0):
17 |         for i in tqdm(self.frontend.text_normalize(tts_text, split=True)):
18 |             tts_text_token, tts_text_token_len = self.frontend._extract_text_token(i)  # tokenize the current chunk, not the full text
19 | spkmodel["text"] = tts_text_token
20 | spkmodel["text_len"] = tts_text_token_len
21 | start_time = time.time()
22 | logging.info('synthesis text {}'.format(i))
23 | for model_output in self.model.tts(**spkmodel, stream=stream, speed=speed):
24 | speech_len = model_output['tts_speech'].shape[1] / 22050
25 | logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
26 | yield model_output
27 | start_time = time.time()
--------------------------------------------------------------------------------
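A hypothetical usage sketch of the method above; the model directory, speaker-model file name, and the contents of the loaded dict are assumptions, not taken from the repo:

    import torch
    from funaudio_utils.cosyvoice_plus import CosyVoicePlus

    cosyvoice = CosyVoicePlus("models/CosyVoice/CosyVoice-300M")    # assumed local model dir
    spkmodel = torch.load("speaker.pt")                             # assumed: dict saved by the speaker-model nodes
    for out in cosyvoice.inference_zero_shot_with_spkmodel("你好，世界。", spkmodel, stream=False, speed=1.0):
        waveform = out["tts_speech"]                                # (1, num_samples) at 22.05 kHz
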
/funaudio_utils/download_models.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: SpenserCai
3 | Date: 2024-10-04 13:54:01
4 | version:
5 | LastEditors: SpenserCai
6 | LastEditTime: 2024-10-04 22:26:22
7 | Description: file content
8 | '''
9 | import modelscope
10 | import os
11 | import folder_paths
12 | from modelscope import snapshot_download
13 |
14 | # Download the model
15 | base_cosyvoice_model_path = os.path.join(folder_paths.models_dir, "CosyVoice")
16 | base_sensevoice_model_path = os.path.join(folder_paths.models_dir, "SenseVoice")
17 |
18 | def download_cosyvoice_300m(is_25hz=False):
19 | model_name = "CosyVoice-300M"
20 | model_id = "iic/CosyVoice-300M"
21 | if is_25hz:
22 | model_name = "CosyVoice-300M-25Hz"
23 | model_id = "iic/CosyVoice-300M-25Hz"
24 | model_dir = os.path.join(base_cosyvoice_model_path, model_name)
25 | snapshot_download(model_id=model_id, local_dir=model_dir)
26 | return model_name, model_dir
27 |
28 | def download_cosyvoice_300m_sft(is_25hz=False):
29 | model_name = "CosyVoice-300M-SFT"
30 | model_id = "iic/CosyVoice-300M-SFT"
31 | if is_25hz:
32 | model_name = "CosyVoice-300M-SFT-25Hz"
33 | model_id = "MachineS/CosyVoice-300M-SFT-25Hz"
34 | model_dir = os.path.join(base_cosyvoice_model_path, model_name)
35 | snapshot_download(model_id=model_id, local_dir=model_dir)
36 | return model_name, model_dir
37 |
38 | def download_sensevoice_small():
39 | model_name = "SenseVoiceSmall"
40 | model_id = "iic/SenseVoiceSmall"
41 | model_dir = os.path.join(base_sensevoice_model_path, model_name)
42 | snapshot_download(model_id=model_id, local_dir=model_dir)
43 | return model_name, model_dir
44 |
45 | def download_cosyvoice_300m_instruct():
46 | model_name = "CosyVoice-300M-Instruct"
47 | model_id = "iic/CosyVoice-300M-Instruct"
48 | model_dir = os.path.join(base_cosyvoice_model_path, model_name)
49 | snapshot_download(model_id=model_id, local_dir=model_dir)
50 | return model_name, model_dir
51 |
52 | def get_speaker_default_path():
53 | return os.path.join(base_cosyvoice_model_path, "Speaker")
--------------------------------------------------------------------------------
/funaudio_utils/pre.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: SpenserCai
3 | Date: 2024-10-04 13:22:31
4 | version:
5 | LastEditors: SpenserCai
6 | LastEditTime: 2024-10-04 14:09:40
7 | Description: file content
8 | '''
9 | import librosa
10 | import torch
11 | import torchaudio
12 |
13 | class FunAudioLLMTool:
14 | def __init__(self):
15 | self.max_val = 0.8
16 | self.prompt_sr, self.target_sr = 16000, 22050
17 |
18 | def postprocess(self,speech, top_db=60, hop_length=220, win_length=440):
19 | speech, _ = librosa.effects.trim(
20 | speech, top_db=top_db,
21 | frame_length=win_length,
22 | hop_length=hop_length
23 | )
24 | if speech.abs().max() > self.max_val:
25 | speech = speech / speech.abs().max() * self.max_val
26 | speech = torch.concat([speech, torch.zeros(1, int(self.target_sr * 0.2))], dim=1)
27 | return speech
28 |
29 | def audio_resample(self, waveform, source_sr):
30 | waveform = waveform.squeeze(0)
31 | speech = waveform.mean(dim=0,keepdim=True)
32 | if source_sr != self.prompt_sr:
33 | speech = torchaudio.transforms.Resample(orig_freq=source_sr, new_freq=self.prompt_sr)(speech)
34 | return speech
35 |
--------------------------------------------------------------------------------
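A toy call of audio_resample above (assuming torchaudio is installed and the package is importable): a fake one-second 44.1 kHz stereo clip is mixed down to mono and resampled to the 16 kHz prompt rate:

    import torch
    from funaudio_utils.pre import FunAudioLLMTool

    tool = FunAudioLLMTool()
    waveform = torch.randn(1, 2, 44100)            # (batch, channels, samples)
    speech = tool.audio_resample(waveform, 44100)
    print(speech.shape)                            # torch.Size([1, 16000])
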
/matcha/VERSION:
--------------------------------------------------------------------------------
1 | 0.0.5.1
2 |
--------------------------------------------------------------------------------
/matcha/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/__init__.py
--------------------------------------------------------------------------------
/matcha/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/data/__init__.py
--------------------------------------------------------------------------------
/matcha/data/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/data/components/__init__.py
--------------------------------------------------------------------------------
/matcha/hifigan/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Jungil Kong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/matcha/hifigan/README.md:
--------------------------------------------------------------------------------
1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
2 |
3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
4 |
5 | In our [paper](https://arxiv.org/abs/2010.05646),
6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository.
8 |
9 | **Abstract :**
10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
11 | Although such methods improve the sampling efficiency and memory usage,
12 | their sample quality has not yet reached that of autoregressive and flow-based generative models.
13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
14 | As speech audio consists of sinusoidal signals with various periods,
15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart.
21 |
22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
23 |
24 | ## Pre-requisites
25 |
26 | 1. Python >= 3.6
27 | 2. Clone this repository.
28 | 3. Install Python requirements. Please refer to [requirements.txt](requirements.txt)
29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
30 | And move all wav files to `LJSpeech-1.1/wavs`
31 |
32 | ## Training
33 |
34 | ```
35 | python train.py --config config_v1.json
36 | ```
37 |
38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
40 | You can change the path by adding `--checkpoint_path` option.
41 |
42 | Validation loss during training with V1 generator.
43 | 
44 |
45 | ## Pretrained Model
46 |
47 | You can also use pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as follows:
50 |
51 | | Folder Name | Generator | Dataset | Fine-Tuned |
52 | | ------------ | --------- | --------- | ------------------------------------------------------ |
53 | | LJ_V1 | V1 | LJSpeech | No |
54 | | LJ_V2 | V2 | LJSpeech | No |
55 | | LJ_V3 | V3 | LJSpeech | No |
56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
59 | | VCTK_V1 | V1 | VCTK | No |
60 | | VCTK_V2 | V2 | VCTK | No |
61 | | VCTK_V3 | V3 | VCTK | No |
62 | | UNIVERSAL_V1 | V1 | Universal | No |
63 |
64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
65 |
66 | ## Fine-Tuning
67 |
68 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
69 | The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
70 | Example:
71 | ` Audio File : LJ001-0001.wav
72 | Mel-Spectrogram File : LJ001-0001.npy`
73 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command.
75 | ```
76 | python train.py --fine_tuning True --config config_v1.json
77 | ```
78 | For other command line options, please refer to the training section.
79 |
80 | ## Inference from wav file
81 |
82 | 1. Make `test_files` directory and copy wav files into the directory.
83 | 2. Run the following command.
84 | ` python inference.py --checkpoint_file [generator checkpoint file path]`
85 | Generated wav files are saved in `generated_files` by default.
86 | You can change the path by adding `--output_dir` option.
87 |
88 | ## Inference for end-to-end speech synthesis
89 |
90 | 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
93 | 2. Run the following command.
94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
95 | Generated wav files are saved in `generated_files_from_mel` by default.
96 | You can change the path by adding `--output_dir` option.
97 |
98 | ## Acknowledgements
99 |
100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
102 |
--------------------------------------------------------------------------------
/matcha/hifigan/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/hifigan/__init__.py
--------------------------------------------------------------------------------
/matcha/hifigan/config.py:
--------------------------------------------------------------------------------
1 | v1 = {
2 | "resblock": "1",
3 | "num_gpus": 0,
4 | "batch_size": 16,
5 | "learning_rate": 0.0004,
6 | "adam_b1": 0.8,
7 | "adam_b2": 0.99,
8 | "lr_decay": 0.999,
9 | "seed": 1234,
10 | "upsample_rates": [8, 8, 2, 2],
11 | "upsample_kernel_sizes": [16, 16, 4, 4],
12 | "upsample_initial_channel": 512,
13 | "resblock_kernel_sizes": [3, 7, 11],
14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
15 | "resblock_initial_channel": 256,
16 | "segment_size": 8192,
17 | "num_mels": 80,
18 | "num_freq": 1025,
19 | "n_fft": 1024,
20 | "hop_size": 256,
21 | "win_size": 1024,
22 | "sampling_rate": 22050,
23 | "fmin": 0,
24 | "fmax": 8000,
25 | "fmax_loss": None,
26 | "num_workers": 4,
27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
28 | }
29 |
--------------------------------------------------------------------------------
/matcha/hifigan/denoiser.py:
--------------------------------------------------------------------------------
1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
2 |
3 | """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
4 | import torch
5 |
6 |
7 | class Denoiser(torch.nn.Module):
8 | """Removes model bias from audio produced with waveglow"""
9 |
10 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
11 | super().__init__()
12 | self.filter_length = filter_length
13 | self.hop_length = int(filter_length / n_overlap)
14 | self.win_length = win_length
15 |
16 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
17 | self.device = device
18 | if mode == "zeros":
19 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
20 | elif mode == "normal":
21 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
22 | else:
23 |             raise Exception(f"Mode {mode} is not supported")
24 |
25 | def stft_fn(audio, n_fft, hop_length, win_length, window):
26 | spec = torch.stft(
27 | audio,
28 | n_fft=n_fft,
29 | hop_length=hop_length,
30 | win_length=win_length,
31 | window=window,
32 | return_complex=True,
33 | )
34 | spec = torch.view_as_real(spec)
35 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
36 |
37 | self.stft = lambda x: stft_fn(
38 | audio=x,
39 | n_fft=self.filter_length,
40 | hop_length=self.hop_length,
41 | win_length=self.win_length,
42 | window=torch.hann_window(self.win_length, device=device),
43 | )
44 | self.istft = lambda x, y: torch.istft(
45 | torch.complex(x * torch.cos(y), x * torch.sin(y)),
46 | n_fft=self.filter_length,
47 | hop_length=self.hop_length,
48 | win_length=self.win_length,
49 | window=torch.hann_window(self.win_length, device=device),
50 | )
51 |
52 | with torch.no_grad():
53 | bias_audio = vocoder(mel_input).float().squeeze(0)
54 | bias_spec, _ = self.stft(bias_audio)
55 |
56 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
57 |
58 | @torch.inference_mode()
59 | def forward(self, audio, strength=0.0005):
60 | audio_spec, audio_angles = self.stft(audio)
61 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
62 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
63 | audio_denoised = self.istft(audio_spec_denoised, audio_angles)
64 | return audio_denoised
65 |
--------------------------------------------------------------------------------
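`Denoiser` estimates the vocoder's bias spectrum from an all-zero mel input and subtracts a scaled copy of it from the STFT magnitude of generated audio. A minimal sketch, using a single `ConvTranspose1d` as a hypothetical stand-in for a trained HiFi-GAN generator (anything mapping a `(batch, 80, frames)` mel to `(batch, 1, samples)` waveform works for illustration):

```
import torch
from matcha.hifigan.denoiser import Denoiser

# Hypothetical vocoder stand-in: mel (1, 80, frames) -> audio (1, 1, samples).
fake_vocoder = torch.nn.ConvTranspose1d(80, 1, kernel_size=256, stride=256)

denoiser = Denoiser(fake_vocoder, filter_length=1024, n_overlap=4, win_length=1024)
audio = fake_vocoder(torch.randn(1, 80, 88)).squeeze(1).detach()  # (1, samples)
cleaned = denoiser(audio, strength=0.0005)                        # bias removed
print(cleaned.shape)
```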
/matcha/hifigan/env.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jik876/hifi-gan """
2 |
3 | import os
4 | import shutil
5 |
6 |
7 | class AttrDict(dict):
8 | def __init__(self, *args, **kwargs):
9 | super().__init__(*args, **kwargs)
10 | self.__dict__ = self
11 |
12 |
13 | def build_env(config, config_name, path):
14 | t_path = os.path.join(path, config_name)
15 | if config != t_path:
16 | os.makedirs(path, exist_ok=True)
17 | shutil.copyfile(config, os.path.join(path, config_name))
18 |
--------------------------------------------------------------------------------
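`AttrDict` is the small helper HiFi-GAN uses to expose a config dict as attributes; combined with the `v1` hyperparameters from `config.py` above, the usual access pattern looks like this (a minimal sketch, not tied to any specific generator):

```
from matcha.hifigan.config import v1
from matcha.hifigan.env import AttrDict

h = AttrDict(v1)
# Hyperparameters can now be read as attributes instead of dict keys.
print(h.sampling_rate, h.n_fft, h.hop_size, h.num_mels)  # 22050 1024 256 80
```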
/matcha/hifigan/meldataset.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jik876/hifi-gan """
2 |
3 | import math
4 | import os
5 | import random
6 |
7 | import numpy as np
8 | import torch
9 | import torch.utils.data
10 | from librosa.filters import mel as librosa_mel_fn
11 | from librosa.util import normalize
12 | from scipy.io.wavfile import read
13 |
14 | MAX_WAV_VALUE = 32768.0
15 |
16 |
17 | def load_wav(full_path):
18 | sampling_rate, data = read(full_path)
19 | return data, sampling_rate
20 |
21 |
22 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
24 |
25 |
26 | def dynamic_range_decompression(x, C=1):
27 | return np.exp(x) / C
28 |
29 |
30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
31 | return torch.log(torch.clamp(x, min=clip_val) * C)
32 |
33 |
34 | def dynamic_range_decompression_torch(x, C=1):
35 | return torch.exp(x) / C
36 |
37 |
38 | def spectral_normalize_torch(magnitudes):
39 | output = dynamic_range_compression_torch(magnitudes)
40 | return output
41 |
42 |
43 | def spectral_de_normalize_torch(magnitudes):
44 | output = dynamic_range_decompression_torch(magnitudes)
45 | return output
46 |
47 |
48 | mel_basis = {}
49 | hann_window = {}
50 |
51 |
52 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
53 | if torch.min(y) < -1.0:
54 | print("min value is ", torch.min(y))
55 | if torch.max(y) > 1.0:
56 | print("max value is ", torch.max(y))
57 |
58 | global mel_basis, hann_window # pylint: disable=global-statement
59 |     if str(fmax) + "_" + str(y.device) not in mel_basis:
60 |         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
61 |         mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
62 |         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
63 |
64 | y = torch.nn.functional.pad(
65 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
66 | )
67 | y = y.squeeze(1)
68 |
69 | spec = torch.view_as_real(
70 | torch.stft(
71 | y,
72 | n_fft,
73 | hop_length=hop_size,
74 | win_length=win_size,
75 | window=hann_window[str(y.device)],
76 | center=center,
77 | pad_mode="reflect",
78 | normalized=False,
79 | onesided=True,
80 | return_complex=True,
81 | )
82 | )
83 |
84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
85 |
86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
87 | spec = spectral_normalize_torch(spec)
88 |
89 | return spec
90 |
91 |
92 | def get_dataset_filelist(a):
93 | with open(a.input_training_file, encoding="utf-8") as fi:
94 | training_files = [
95 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
96 | ]
97 |
98 | with open(a.input_validation_file, encoding="utf-8") as fi:
99 | validation_files = [
100 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
101 | ]
102 | return training_files, validation_files
103 |
104 |
105 | class MelDataset(torch.utils.data.Dataset):
106 | def __init__(
107 | self,
108 | training_files,
109 | segment_size,
110 | n_fft,
111 | num_mels,
112 | hop_size,
113 | win_size,
114 | sampling_rate,
115 | fmin,
116 | fmax,
117 | split=True,
118 | shuffle=True,
119 | n_cache_reuse=1,
120 | device=None,
121 | fmax_loss=None,
122 | fine_tuning=False,
123 | base_mels_path=None,
124 | ):
125 | self.audio_files = training_files
126 | random.seed(1234)
127 | if shuffle:
128 | random.shuffle(self.audio_files)
129 | self.segment_size = segment_size
130 | self.sampling_rate = sampling_rate
131 | self.split = split
132 | self.n_fft = n_fft
133 | self.num_mels = num_mels
134 | self.hop_size = hop_size
135 | self.win_size = win_size
136 | self.fmin = fmin
137 | self.fmax = fmax
138 | self.fmax_loss = fmax_loss
139 | self.cached_wav = None
140 | self.n_cache_reuse = n_cache_reuse
141 | self._cache_ref_count = 0
142 | self.device = device
143 | self.fine_tuning = fine_tuning
144 | self.base_mels_path = base_mels_path
145 |
146 | def __getitem__(self, index):
147 | filename = self.audio_files[index]
148 | if self._cache_ref_count == 0:
149 | audio, sampling_rate = load_wav(filename)
150 | audio = audio / MAX_WAV_VALUE
151 | if not self.fine_tuning:
152 | audio = normalize(audio) * 0.95
153 | self.cached_wav = audio
154 | if sampling_rate != self.sampling_rate:
155 | raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
156 | self._cache_ref_count = self.n_cache_reuse
157 | else:
158 | audio = self.cached_wav
159 | self._cache_ref_count -= 1
160 |
161 | audio = torch.FloatTensor(audio)
162 | audio = audio.unsqueeze(0)
163 |
164 | if not self.fine_tuning:
165 | if self.split:
166 | if audio.size(1) >= self.segment_size:
167 | max_audio_start = audio.size(1) - self.segment_size
168 | audio_start = random.randint(0, max_audio_start)
169 | audio = audio[:, audio_start : audio_start + self.segment_size]
170 | else:
171 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
172 |
173 | mel = mel_spectrogram(
174 | audio,
175 | self.n_fft,
176 | self.num_mels,
177 | self.sampling_rate,
178 | self.hop_size,
179 | self.win_size,
180 | self.fmin,
181 | self.fmax,
182 | center=False,
183 | )
184 | else:
185 | mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
186 | mel = torch.from_numpy(mel)
187 |
188 | if len(mel.shape) < 3:
189 | mel = mel.unsqueeze(0)
190 |
191 | if self.split:
192 | frames_per_seg = math.ceil(self.segment_size / self.hop_size)
193 |
194 | if audio.size(1) >= self.segment_size:
195 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
196 | mel = mel[:, :, mel_start : mel_start + frames_per_seg]
197 | audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
198 | else:
199 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
200 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
201 |
202 | mel_loss = mel_spectrogram(
203 | audio,
204 | self.n_fft,
205 | self.num_mels,
206 | self.sampling_rate,
207 | self.hop_size,
208 | self.win_size,
209 | self.fmin,
210 | self.fmax_loss,
211 | center=False,
212 | )
213 |
214 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
215 |
216 | def __len__(self):
217 | return len(self.audio_files)
218 |
--------------------------------------------------------------------------------
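`mel_spectrogram` is the feature extractor shared by training and fine-tuning: it reflect-pads the signal, takes an STFT and applies a mel filterbank followed by log compression. A minimal sketch computing mels for one second of silence with the `v1` hyperparameters:

```
import torch
from matcha.hifigan.config import v1
from matcha.hifigan.meldataset import mel_spectrogram

y = torch.zeros(1, v1["sampling_rate"])  # (batch, samples), values in [-1, 1]
mel = mel_spectrogram(
    y,
    n_fft=v1["n_fft"],
    num_mels=v1["num_mels"],
    sampling_rate=v1["sampling_rate"],
    hop_size=v1["hop_size"],
    win_size=v1["win_size"],
    fmin=v1["fmin"],
    fmax=v1["fmax"],
)
print(mel.shape)  # torch.Size([1, 80, frames])
```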
/matcha/hifigan/xutils.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jik876/hifi-gan """
2 |
3 | import glob
4 | import os
5 |
6 | import matplotlib
7 | import torch
8 | from torch.nn.utils import weight_norm
9 |
10 | matplotlib.use("Agg")
11 | import matplotlib.pylab as plt
12 |
13 |
14 | def plot_spectrogram(spectrogram):
15 | fig, ax = plt.subplots(figsize=(10, 2))
16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
17 | plt.colorbar(im, ax=ax)
18 |
19 | fig.canvas.draw()
20 | plt.close()
21 |
22 | return fig
23 |
24 |
25 | def init_weights(m, mean=0.0, std=0.01):
26 | classname = m.__class__.__name__
27 | if classname.find("Conv") != -1:
28 | m.weight.data.normal_(mean, std)
29 |
30 |
31 | def apply_weight_norm(m):
32 | classname = m.__class__.__name__
33 | if classname.find("Conv") != -1:
34 | weight_norm(m)
35 |
36 |
37 | def get_padding(kernel_size, dilation=1):
38 | return int((kernel_size * dilation - dilation) / 2)
39 |
40 |
41 | def load_checkpoint(filepath, device):
42 | assert os.path.isfile(filepath)
43 | print(f"Loading '{filepath}'")
44 | checkpoint_dict = torch.load(filepath, map_location=device)
45 | print("Complete.")
46 | return checkpoint_dict
47 |
48 |
49 | def save_checkpoint(filepath, obj):
50 | print(f"Saving checkpoint to {filepath}")
51 | torch.save(obj, filepath)
52 | print("Complete.")
53 |
54 |
55 | def scan_checkpoint(cp_dir, prefix):
56 | pattern = os.path.join(cp_dir, prefix + "????????")
57 | cp_list = glob.glob(pattern)
58 | if len(cp_list) == 0:
59 | return None
60 | return sorted(cp_list)[-1]
61 |
--------------------------------------------------------------------------------
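`scan_checkpoint` and `load_checkpoint` implement the usual resume pattern: find the newest file matching `<prefix>????????` in a checkpoint directory, then load it onto a device. A minimal sketch with a hypothetical `cp_hifigan` directory and a hypothetical `g_` generator prefix:

```
from matcha.hifigan.xutils import scan_checkpoint, load_checkpoint

cp_dir = "cp_hifigan"                   # hypothetical checkpoint directory
latest = scan_checkpoint(cp_dir, "g_")  # newest file named like g_00010000
if latest is None:
    print("No generator checkpoint found in", cp_dir)
else:
    state = load_checkpoint(latest, device="cpu")
    print("Loaded", latest)
```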
/matcha/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/models/__init__.py
--------------------------------------------------------------------------------
/matcha/models/baselightningmodule.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a base lightning module that can be used to train a model.
3 | The benefit of this abstraction is that all the logic outside of model definition can be reused for different models.
4 | """
5 | import inspect
6 | from abc import ABC
7 | from typing import Any, Dict
8 |
9 | import torch
10 | from lightning import LightningModule
11 | from lightning.pytorch.utilities import grad_norm
12 |
13 | from matcha import utils
14 | from matcha.utils.utils import plot_tensor
15 |
16 | log = utils.get_pylogger(__name__)
17 |
18 |
19 | class BaseLightningClass(LightningModule, ABC):
20 | def update_data_statistics(self, data_statistics):
21 | if data_statistics is None:
22 | data_statistics = {
23 | "mel_mean": 0.0,
24 | "mel_std": 1.0,
25 | }
26 |
27 | self.register_buffer("mel_mean", torch.tensor(data_statistics["mel_mean"]))
28 | self.register_buffer("mel_std", torch.tensor(data_statistics["mel_std"]))
29 |
30 | def configure_optimizers(self) -> Any:
31 | optimizer = self.hparams.optimizer(params=self.parameters())
32 | if self.hparams.scheduler not in (None, {}):
33 | scheduler_args = {}
34 | # Manage last epoch for exponential schedulers
35 | if "last_epoch" in inspect.signature(self.hparams.scheduler.scheduler).parameters:
36 | if hasattr(self, "ckpt_loaded_epoch"):
37 | current_epoch = self.ckpt_loaded_epoch - 1
38 | else:
39 | current_epoch = -1
40 |
41 | scheduler_args.update({"optimizer": optimizer})
42 | scheduler = self.hparams.scheduler.scheduler(**scheduler_args)
43 | scheduler.last_epoch = current_epoch
44 | return {
45 | "optimizer": optimizer,
46 | "lr_scheduler": {
47 | "scheduler": scheduler,
48 | "interval": self.hparams.scheduler.lightning_args.interval,
49 | "frequency": self.hparams.scheduler.lightning_args.frequency,
50 | "name": "learning_rate",
51 | },
52 | }
53 |
54 | return {"optimizer": optimizer}
55 |
56 | def get_losses(self, batch):
57 | x, x_lengths = batch["x"], batch["x_lengths"]
58 | y, y_lengths = batch["y"], batch["y_lengths"]
59 | spks = batch["spks"]
60 |
61 | dur_loss, prior_loss, diff_loss = self(
62 | x=x,
63 | x_lengths=x_lengths,
64 | y=y,
65 | y_lengths=y_lengths,
66 | spks=spks,
67 | out_size=self.out_size,
68 | )
69 | return {
70 | "dur_loss": dur_loss,
71 | "prior_loss": prior_loss,
72 | "diff_loss": diff_loss,
73 | }
74 |
75 | def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
76 | self.ckpt_loaded_epoch = checkpoint["epoch"] # pylint: disable=attribute-defined-outside-init
77 |
78 | def training_step(self, batch: Any, batch_idx: int):
79 | loss_dict = self.get_losses(batch)
80 | self.log(
81 | "step",
82 | float(self.global_step),
83 | on_step=True,
84 | prog_bar=True,
85 | logger=True,
86 | sync_dist=True,
87 | )
88 |
89 | self.log(
90 | "sub_loss/train_dur_loss",
91 | loss_dict["dur_loss"],
92 | on_step=True,
93 | on_epoch=True,
94 | logger=True,
95 | sync_dist=True,
96 | )
97 | self.log(
98 | "sub_loss/train_prior_loss",
99 | loss_dict["prior_loss"],
100 | on_step=True,
101 | on_epoch=True,
102 | logger=True,
103 | sync_dist=True,
104 | )
105 | self.log(
106 | "sub_loss/train_diff_loss",
107 | loss_dict["diff_loss"],
108 | on_step=True,
109 | on_epoch=True,
110 | logger=True,
111 | sync_dist=True,
112 | )
113 |
114 | total_loss = sum(loss_dict.values())
115 | self.log(
116 | "loss/train",
117 | total_loss,
118 | on_step=True,
119 | on_epoch=True,
120 | logger=True,
121 | prog_bar=True,
122 | sync_dist=True,
123 | )
124 |
125 | return {"loss": total_loss, "log": loss_dict}
126 |
127 | def validation_step(self, batch: Any, batch_idx: int):
128 | loss_dict = self.get_losses(batch)
129 | self.log(
130 | "sub_loss/val_dur_loss",
131 | loss_dict["dur_loss"],
132 | on_step=True,
133 | on_epoch=True,
134 | logger=True,
135 | sync_dist=True,
136 | )
137 | self.log(
138 | "sub_loss/val_prior_loss",
139 | loss_dict["prior_loss"],
140 | on_step=True,
141 | on_epoch=True,
142 | logger=True,
143 | sync_dist=True,
144 | )
145 | self.log(
146 | "sub_loss/val_diff_loss",
147 | loss_dict["diff_loss"],
148 | on_step=True,
149 | on_epoch=True,
150 | logger=True,
151 | sync_dist=True,
152 | )
153 |
154 | total_loss = sum(loss_dict.values())
155 | self.log(
156 | "loss/val",
157 | total_loss,
158 | on_step=True,
159 | on_epoch=True,
160 | logger=True,
161 | prog_bar=True,
162 | sync_dist=True,
163 | )
164 |
165 | return total_loss
166 |
167 | def on_validation_end(self) -> None:
168 | if self.trainer.is_global_zero:
169 | one_batch = next(iter(self.trainer.val_dataloaders))
170 | if self.current_epoch == 0:
171 | log.debug("Plotting original samples")
172 | for i in range(2):
173 | y = one_batch["y"][i].unsqueeze(0).to(self.device)
174 | self.logger.experiment.add_image(
175 | f"original/{i}",
176 | plot_tensor(y.squeeze().cpu()),
177 | self.current_epoch,
178 | dataformats="HWC",
179 | )
180 |
181 | log.debug("Synthesising...")
182 | for i in range(2):
183 | x = one_batch["x"][i].unsqueeze(0).to(self.device)
184 |                 x_lengths = one_batch["x_lengths"][i].unsqueeze(0).to(self.device)
185 | spks = one_batch["spks"][i].unsqueeze(0).to(self.device) if one_batch["spks"] is not None else None
186 | output = self.synthesise(x[:, :x_lengths], x_lengths, n_timesteps=10, spks=spks)
187 | y_enc, y_dec = output["encoder_outputs"], output["decoder_outputs"]
188 | attn = output["attn"]
189 | self.logger.experiment.add_image(
190 | f"generated_enc/{i}",
191 | plot_tensor(y_enc.squeeze().cpu()),
192 | self.current_epoch,
193 | dataformats="HWC",
194 | )
195 | self.logger.experiment.add_image(
196 | f"generated_dec/{i}",
197 | plot_tensor(y_dec.squeeze().cpu()),
198 | self.current_epoch,
199 | dataformats="HWC",
200 | )
201 | self.logger.experiment.add_image(
202 | f"alignment/{i}",
203 | plot_tensor(attn.squeeze().cpu()),
204 | self.current_epoch,
205 | dataformats="HWC",
206 | )
207 |
208 | def on_before_optimizer_step(self, optimizer):
209 | self.log_dict({f"grad_norm/{k}": v for k, v in grad_norm(self, norm_type=2).items()})
210 |
--------------------------------------------------------------------------------
/matcha/models/components/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/models/components/__init__.py
--------------------------------------------------------------------------------
/matcha/models/components/flow_matching.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from matcha.models.components.decoder import Decoder
7 | from matcha.utils.pylogger import get_pylogger
8 |
9 | log = get_pylogger(__name__)
10 |
11 |
12 | class BASECFM(torch.nn.Module, ABC):
13 | def __init__(
14 | self,
15 | n_feats,
16 | cfm_params,
17 | n_spks=1,
18 | spk_emb_dim=128,
19 | ):
20 | super().__init__()
21 | self.n_feats = n_feats
22 | self.n_spks = n_spks
23 | self.spk_emb_dim = spk_emb_dim
24 | self.solver = cfm_params.solver
25 | if hasattr(cfm_params, "sigma_min"):
26 | self.sigma_min = cfm_params.sigma_min
27 | else:
28 | self.sigma_min = 1e-4
29 |
30 | self.estimator = None
31 |
32 | @torch.inference_mode()
33 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
34 | """Forward diffusion
35 |
36 | Args:
37 | mu (torch.Tensor): output of encoder
38 | shape: (batch_size, n_feats, mel_timesteps)
39 | mask (torch.Tensor): output_mask
40 | shape: (batch_size, 1, mel_timesteps)
41 | n_timesteps (int): number of diffusion steps
42 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
43 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
44 | shape: (batch_size, spk_emb_dim)
45 | cond: Not used but kept for future purposes
46 |
47 | Returns:
48 | sample: generated mel-spectrogram
49 | shape: (batch_size, n_feats, mel_timesteps)
50 | """
51 | z = torch.randn_like(mu) * temperature
52 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
53 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
54 |
55 | def solve_euler(self, x, t_span, mu, mask, spks, cond):
56 | """
57 |         Fixed-step Euler solver for ODEs.
58 | Args:
59 | x (torch.Tensor): random noise
60 | t_span (torch.Tensor): n_timesteps interpolated
61 | shape: (n_timesteps + 1,)
62 | mu (torch.Tensor): output of encoder
63 | shape: (batch_size, n_feats, mel_timesteps)
64 | mask (torch.Tensor): output_mask
65 | shape: (batch_size, 1, mel_timesteps)
66 | spks (torch.Tensor, optional): speaker ids. Defaults to None.
67 | shape: (batch_size, spk_emb_dim)
68 | cond: Not used but kept for future purposes
69 | """
70 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
71 |
72 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file
73 | # Or in future might add like a return_all_steps flag
74 | sol = []
75 |
76 | for step in range(1, len(t_span)):
77 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond)
78 |
79 | x = x + dt * dphi_dt
80 | t = t + dt
81 | sol.append(x)
82 | if step < len(t_span) - 1:
83 | dt = t_span[step + 1] - t
84 |
85 | return sol[-1]
86 |
87 | def compute_loss(self, x1, mask, mu, spks=None, cond=None):
88 | """Computes diffusion loss
89 |
90 | Args:
91 | x1 (torch.Tensor): Target
92 | shape: (batch_size, n_feats, mel_timesteps)
93 | mask (torch.Tensor): target mask
94 | shape: (batch_size, 1, mel_timesteps)
95 | mu (torch.Tensor): output of encoder
96 | shape: (batch_size, n_feats, mel_timesteps)
97 | spks (torch.Tensor, optional): speaker embedding. Defaults to None.
98 | shape: (batch_size, spk_emb_dim)
99 |
100 | Returns:
101 | loss: conditional flow matching loss
102 | y: conditional flow
103 | shape: (batch_size, n_feats, mel_timesteps)
104 | """
105 | b, _, t = mu.shape
106 |
107 | # random timestep
108 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
109 | # sample noise p(x_0)
110 | z = torch.randn_like(x1)
111 |
112 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1
113 | u = x1 - (1 - self.sigma_min) * z
114 |
115 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / (
116 | torch.sum(mask) * u.shape[1]
117 | )
118 | return loss, y
119 |
120 |
121 | class CFM(BASECFM):
122 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64):
123 | super().__init__(
124 | n_feats=in_channels,
125 | cfm_params=cfm_params,
126 | n_spks=n_spks,
127 | spk_emb_dim=spk_emb_dim,
128 | )
129 |
130 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0)
131 | # Just change the architecture of the estimator here
132 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params)
133 |
--------------------------------------------------------------------------------
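`solve_euler` integrates the learned vector field with a fixed-step Euler scheme over a uniform `t_span`; each step moves the sample by `dt` times the estimator's output. Below is a toy sketch of the same update rule applied to a known ODE (`dx/dt = -x`) as a sanity check; in the real model the right-hand side is `self.estimator(x, mask, mu, t, spks, cond)`:

```
import torch

def toy_euler(x0, n_timesteps):
    t_span = torch.linspace(0, 1, n_timesteps + 1)
    x, t, dt = x0, t_span[0], t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        dphi_dt = -x              # stand-in for the learned vector field
        x = x + dt * dphi_dt
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
    return x

# 0.9**10 ~= 0.349, approaching exp(-1) ~= 0.368 as n_timesteps grows.
print(toy_euler(torch.tensor(1.0), 10))
```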
/matcha/onnx/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/matcha/onnx/__init__.py
--------------------------------------------------------------------------------
/matcha/onnx/export.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import torch
7 | from lightning import LightningModule
8 |
9 | from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder
10 |
11 | DEFAULT_OPSET = 15
12 |
13 | SEED = 1234
14 | random.seed(SEED)
15 | np.random.seed(SEED)
16 | torch.manual_seed(SEED)
17 | torch.cuda.manual_seed(SEED)
18 | torch.backends.cudnn.deterministic = True
19 | torch.backends.cudnn.benchmark = False
20 |
21 |
22 | class MatchaWithVocoder(LightningModule):
23 | def __init__(self, matcha, vocoder):
24 | super().__init__()
25 | self.matcha = matcha
26 | self.vocoder = vocoder
27 |
28 | def forward(self, x, x_lengths, scales, spks=None):
29 | mel, mel_lengths = self.matcha(x, x_lengths, scales, spks)
30 | wavs = self.vocoder(mel).clamp(-1, 1)
31 | lengths = mel_lengths * 256
32 | return wavs.squeeze(1), lengths
33 |
34 |
35 | def get_exportable_module(matcha, vocoder, n_timesteps):
36 | """
37 |     Return an appropriate `LightningModule` and output-node names
38 | based on whether the vocoder is embedded in the final graph
39 | """
40 |
41 | def onnx_forward_func(x, x_lengths, scales, spks=None):
42 | """
43 | Custom forward function for accepting
44 |         scale parameters (temperature, length_scale) as a tensor
45 | """
46 |         # Extract scale parameters from the tensor
47 | temperature = scales[0]
48 | length_scale = scales[1]
49 | output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale)
50 | return output["mel"], output["mel_lengths"]
51 |
52 | # Monkey-patch Matcha's forward function
53 | matcha.forward = onnx_forward_func
54 |
55 | if vocoder is None:
56 | model, output_names = matcha, ["mel", "mel_lengths"]
57 | else:
58 | model = MatchaWithVocoder(matcha, vocoder)
59 | output_names = ["wav", "wav_lengths"]
60 | return model, output_names
61 |
62 |
63 | def get_inputs(is_multi_speaker):
64 | """
65 | Create dummy inputs for tracing
66 | """
67 | dummy_input_length = 50
68 | x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long)
69 | x_lengths = torch.LongTensor([dummy_input_length])
70 |
71 | # Scales
72 | temperature = 0.667
73 | length_scale = 1.0
74 | scales = torch.Tensor([temperature, length_scale])
75 |
76 | model_inputs = [x, x_lengths, scales]
77 | input_names = [
78 | "x",
79 | "x_lengths",
80 | "scales",
81 | ]
82 |
83 | if is_multi_speaker:
84 | spks = torch.LongTensor([1])
85 | model_inputs.append(spks)
86 | input_names.append("spks")
87 |
88 | return tuple(model_inputs), input_names
89 |
90 |
91 | def main():
92 | parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX")
93 |
94 | parser.add_argument(
95 | "checkpoint_path",
96 | type=str,
97 | help="Path to the model checkpoint",
98 | )
99 | parser.add_argument("output", type=str, help="Path to output `.onnx` file")
100 | parser.add_argument(
101 | "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)"
102 | )
103 | parser.add_argument(
104 | "--vocoder-name",
105 | type=str,
106 | choices=list(VOCODER_URLS.keys()),
107 | default=None,
108 | help="Name of the vocoder to embed in the ONNX graph",
109 | )
110 | parser.add_argument(
111 | "--vocoder-checkpoint-path",
112 | type=str,
113 | default=None,
114 | help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience",
115 | )
116 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15")
117 |
118 | args = parser.parse_args()
119 |
120 | print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}")
121 | print(f"Setting n_timesteps to {args.n_timesteps}")
122 |
123 | checkpoint_path = Path(args.checkpoint_path)
124 | matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu")
125 |
126 | if args.vocoder_name or args.vocoder_checkpoint_path:
127 | assert (
128 | args.vocoder_name and args.vocoder_checkpoint_path
129 | ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph."
130 | vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu")
131 | else:
132 | vocoder = None
133 |
134 | is_multi_speaker = matcha.n_spks > 1
135 |
136 | dummy_input, input_names = get_inputs(is_multi_speaker)
137 | model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps)
138 |
139 | # Set dynamic shape for inputs/outputs
140 | dynamic_axes = {
141 | "x": {0: "batch_size", 1: "time"},
142 | "x_lengths": {0: "batch_size"},
143 | }
144 |
145 | if vocoder is None:
146 | dynamic_axes.update(
147 | {
148 | "mel": {0: "batch_size", 2: "time"},
149 | "mel_lengths": {0: "batch_size"},
150 | }
151 | )
152 | else:
153 | print("Embedding the vocoder in the ONNX graph")
154 | dynamic_axes.update(
155 | {
156 | "wav": {0: "batch_size", 1: "time"},
157 | "wav_lengths": {0: "batch_size"},
158 | }
159 | )
160 |
161 | if is_multi_speaker:
162 | dynamic_axes["spks"] = {0: "batch_size"}
163 |
164 | # Create the output directory (if not exists)
165 | Path(args.output).parent.mkdir(parents=True, exist_ok=True)
166 |
167 | model.to_onnx(
168 | args.output,
169 | dummy_input,
170 | input_names=input_names,
171 | output_names=output_names,
172 | dynamic_axes=dynamic_axes,
173 | opset_version=args.opset,
174 | export_params=True,
175 | do_constant_folding=True,
176 | )
177 | print(f"[🍵] ONNX model exported to {args.output}")
178 |
179 |
180 | if __name__ == "__main__":
181 | main()
182 |
--------------------------------------------------------------------------------
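Putting the arguments above together, a typical export invocation looks roughly like the following (paths are placeholders; pass `--vocoder-name` and `--vocoder-checkpoint-path` together if the vocoder should be baked into the graph):

```
python matcha/onnx/export.py /path/to/matcha_checkpoint.ckpt matcha.onnx --n-timesteps 5 --opset 15
```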
/matcha/onnx/infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import warnings
4 | from pathlib import Path
5 | from time import perf_counter
6 |
7 | import numpy as np
8 | import onnxruntime as ort
9 | import soundfile as sf
10 | import torch
11 |
12 | from matcha.cli import plot_spectrogram_to_numpy, process_text
13 |
14 |
15 | def validate_args(args):
16 | assert (
17 | args.text or args.file
18 | ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms."
19 | assert args.temperature >= 0, "Sampling temperature cannot be negative"
20 |     assert args.speaking_rate > 0, "Speaking rate must be greater than 0"
21 | return args
22 |
23 |
24 | def write_wavs(model, inputs, output_dir, external_vocoder=None):
25 | if external_vocoder is None:
26 | print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly")
27 | t0 = perf_counter()
28 | wavs, wav_lengths = model.run(None, inputs)
29 | infer_secs = perf_counter() - t0
30 | mel_infer_secs = vocoder_infer_secs = None
31 | else:
32 | print("[🍵] Generating mel using Matcha")
33 | mel_t0 = perf_counter()
34 | mels, mel_lengths = model.run(None, inputs)
35 | mel_infer_secs = perf_counter() - mel_t0
36 | print("Generating waveform from mel using external vocoder")
37 | vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels}
38 | vocoder_t0 = perf_counter()
39 | wavs = external_vocoder.run(None, vocoder_inputs)[0]
40 | vocoder_infer_secs = perf_counter() - vocoder_t0
41 | wavs = wavs.squeeze(1)
42 | wav_lengths = mel_lengths * 256
43 | infer_secs = mel_infer_secs + vocoder_infer_secs
44 |
45 | output_dir = Path(output_dir)
46 | output_dir.mkdir(parents=True, exist_ok=True)
47 | for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)):
48 | output_filename = output_dir.joinpath(f"output_{i + 1}.wav")
49 | audio = wav[:wav_length]
50 | print(f"Writing audio to {output_filename}")
51 | sf.write(output_filename, audio, 22050, "PCM_24")
52 |
53 | wav_secs = wav_lengths.sum() / 22050
54 | print(f"Inference seconds: {infer_secs}")
55 | print(f"Generated wav seconds: {wav_secs}")
56 | rtf = infer_secs / wav_secs
57 | if mel_infer_secs is not None:
58 | mel_rtf = mel_infer_secs / wav_secs
59 | print(f"Matcha RTF: {mel_rtf}")
60 | if vocoder_infer_secs is not None:
61 | vocoder_rtf = vocoder_infer_secs / wav_secs
62 | print(f"Vocoder RTF: {vocoder_rtf}")
63 | print(f"Overall RTF: {rtf}")
64 |
65 |
66 | def write_mels(model, inputs, output_dir):
67 | t0 = perf_counter()
68 | mels, mel_lengths = model.run(None, inputs)
69 | infer_secs = perf_counter() - t0
70 |
71 | output_dir = Path(output_dir)
72 | output_dir.mkdir(parents=True, exist_ok=True)
73 | for i, mel in enumerate(mels):
74 | output_stem = output_dir.joinpath(f"output_{i + 1}")
75 | plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png"))
76 |         np.save(output_stem.with_suffix(".npy"), mel)
77 |
78 | wav_secs = (mel_lengths * 256).sum() / 22050
79 | print(f"Inference seconds: {infer_secs}")
80 | print(f"Generated wav seconds: {wav_secs}")
81 | rtf = infer_secs / wav_secs
82 | print(f"RTF: {rtf}")
83 |
84 |
85 | def main():
86 | parser = argparse.ArgumentParser(
87 | description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching"
88 | )
89 | parser.add_argument(
90 | "model",
91 | type=str,
92 | help="ONNX model to use",
93 | )
94 | parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)")
95 | parser.add_argument("--text", type=str, default=None, help="Text to synthesize")
96 | parser.add_argument("--file", type=str, default=None, help="Text file to synthesize")
97 | parser.add_argument("--spk", type=int, default=None, help="Speaker ID")
98 | parser.add_argument(
99 | "--temperature",
100 | type=float,
101 | default=0.667,
102 | help="Variance of the x0 noise (default: 0.667)",
103 | )
104 | parser.add_argument(
105 | "--speaking-rate",
106 | type=float,
107 | default=1.0,
108 | help="change the speaking rate, a higher value means slower speaking rate (default: 1.0)",
109 | )
110 | parser.add_argument("--gpu", action="store_true", help="Use CPU for inference (default: use GPU if available)")
111 | parser.add_argument(
112 | "--output-dir",
113 | type=str,
114 | default=os.getcwd(),
115 | help="Output folder to save results (default: current dir)",
116 | )
117 |
118 | args = parser.parse_args()
119 | args = validate_args(args)
120 |
121 | if args.gpu:
122 | providers = ["GPUExecutionProvider"]
123 | else:
124 | providers = ["CPUExecutionProvider"]
125 | model = ort.InferenceSession(args.model, providers=providers)
126 |
127 | model_inputs = model.get_inputs()
128 | model_outputs = list(model.get_outputs())
129 |
130 | if args.text:
131 | text_lines = args.text.splitlines()
132 | else:
133 | with open(args.file, encoding="utf-8") as file:
134 | text_lines = file.read().splitlines()
135 |
136 | processed_lines = [process_text(0, line, "cpu") for line in text_lines]
137 | x = [line["x"].squeeze() for line in processed_lines]
138 | # Pad
139 | x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True)
140 | x = x.detach().cpu().numpy()
141 | x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64)
142 | inputs = {
143 | "x": x,
144 | "x_lengths": x_lengths,
145 | "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32),
146 | }
147 | is_multi_speaker = len(model_inputs) == 4
148 | if is_multi_speaker:
149 | if args.spk is None:
150 | args.spk = 0
151 | warn = "[!] Speaker ID not provided! Using speaker ID 0"
152 | warnings.warn(warn, UserWarning)
153 | inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64)
154 |
155 | has_vocoder_embedded = model_outputs[0].name == "wav"
156 | if has_vocoder_embedded:
157 | write_wavs(model, inputs, args.output_dir)
158 | elif args.vocoder:
159 | external_vocoder = ort.InferenceSession(args.vocoder, providers=providers)
160 | write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder)
161 | else:
162 | warn = "[!] A vocoder is not embedded in the graph nor an external vocoder is provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory"
163 | warnings.warn(warn, UserWarning)
164 | write_mels(model, inputs, args.output_dir)
165 |
166 |
167 | if __name__ == "__main__":
168 | main()
169 |
--------------------------------------------------------------------------------
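A matching inference call might look like the sketch below (file names are placeholders). If the ONNX graph already embeds a vocoder, wav files are written directly; otherwise pass `--vocoder`, or the mels are dumped as `.npy` files as warned above.

```
python matcha/onnx/infer.py matcha.onnx --text "Hello world" --vocoder hifigan.onnx --output-dir ./onnx_outputs
```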
/matcha/text/__init__.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 | from matcha.text import cleaners
3 | from matcha.text.symbols import symbols
4 |
5 | # Mappings from symbol to numeric ID and vice versa:
6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)}
7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension
8 |
9 |
10 | def text_to_sequence(text, cleaner_names):
11 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
12 | Args:
13 | text: string to convert to a sequence
14 | cleaner_names: names of the cleaner functions to run the text through
15 | Returns:
16 | List of integers corresponding to the symbols in the text
17 | """
18 | sequence = []
19 |
20 | clean_text = _clean_text(text, cleaner_names)
21 | for symbol in clean_text:
22 | symbol_id = _symbol_to_id[symbol]
23 | sequence += [symbol_id]
24 | return sequence
25 |
26 |
27 | def cleaned_text_to_sequence(cleaned_text):
28 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
29 | Args:
30 | text: string to convert to a sequence
31 | Returns:
32 | List of integers corresponding to the symbols in the text
33 | """
34 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text]
35 | return sequence
36 |
37 |
38 | def sequence_to_text(sequence):
39 | """Converts a sequence of IDs back to a string"""
40 | result = ""
41 | for symbol_id in sequence:
42 | s = _id_to_symbol[symbol_id]
43 | result += s
44 | return result
45 |
46 |
47 | def _clean_text(text, cleaner_names):
48 | for name in cleaner_names:
49 | cleaner = getattr(cleaners, name)
50 | if not cleaner:
51 | raise Exception("Unknown cleaner: %s" % name)
52 | text = cleaner(text)
53 | return text
54 |
--------------------------------------------------------------------------------
/matcha/text/cleaners.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron
2 |
3 | Cleaners are transformations that run over the input text at both training and eval time.
4 |
5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use:
7 | 1. "english_cleaners" for English text
8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode)
10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
11 | the symbols in symbols.py to match your data).
12 | """
13 |
14 | import logging
15 | import re
16 |
17 | import phonemizer
18 | import piper_phonemize
19 | from unidecode import unidecode
20 |
21 | # To avoid excessive logging we set the log level of the phonemizer package to Critical
22 | critical_logger = logging.getLogger("phonemizer")
23 | critical_logger.setLevel(logging.CRITICAL)
24 |
25 | # Initializing the phonemizer globally significantly improves speed,
26 | # since the phonemizer is no longer initialized on every call.
27 | # It might be less flexible, but it is much, much faster.
28 | global_phonemizer = phonemizer.backend.EspeakBackend(
29 | language="en-us",
30 | preserve_punctuation=True,
31 | with_stress=True,
32 | language_switch="remove-flags",
33 | logger=critical_logger,
34 | )
35 |
36 |
37 | # Regular expression matching whitespace:
38 | _whitespace_re = re.compile(r"\s+")
39 |
40 | # List of (regular expression, replacement) pairs for abbreviations:
41 | _abbreviations = [
42 | (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
43 | for x in [
44 | ("mrs", "misess"),
45 | ("mr", "mister"),
46 | ("dr", "doctor"),
47 | ("st", "saint"),
48 | ("co", "company"),
49 | ("jr", "junior"),
50 | ("maj", "major"),
51 | ("gen", "general"),
52 | ("drs", "doctors"),
53 | ("rev", "reverend"),
54 | ("lt", "lieutenant"),
55 | ("hon", "honorable"),
56 | ("sgt", "sergeant"),
57 | ("capt", "captain"),
58 | ("esq", "esquire"),
59 | ("ltd", "limited"),
60 | ("col", "colonel"),
61 | ("ft", "fort"),
62 | ]
63 | ]
64 |
65 |
66 | def expand_abbreviations(text):
67 | for regex, replacement in _abbreviations:
68 | text = re.sub(regex, replacement, text)
69 | return text
70 |
71 |
72 | def lowercase(text):
73 | return text.lower()
74 |
75 |
76 | def collapse_whitespace(text):
77 | return re.sub(_whitespace_re, " ", text)
78 |
79 |
80 | def convert_to_ascii(text):
81 | return unidecode(text)
82 |
83 |
84 | def basic_cleaners(text):
85 | """Basic pipeline that lowercases and collapses whitespace without transliteration."""
86 | text = lowercase(text)
87 | text = collapse_whitespace(text)
88 | return text
89 |
90 |
91 | def transliteration_cleaners(text):
92 | """Pipeline for non-English text that transliterates to ASCII."""
93 | text = convert_to_ascii(text)
94 | text = lowercase(text)
95 | text = collapse_whitespace(text)
96 | return text
97 |
98 |
99 | def english_cleaners2(text):
100 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
101 | text = convert_to_ascii(text)
102 | text = lowercase(text)
103 | text = expand_abbreviations(text)
104 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0]
105 | phonemes = collapse_whitespace(phonemes)
106 | return phonemes
107 |
108 |
109 | def english_cleaners_piper(text):
110 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress"""
111 | text = convert_to_ascii(text)
112 | text = lowercase(text)
113 | text = expand_abbreviations(text)
114 | phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0])
115 | phonemes = collapse_whitespace(phonemes)
116 | return phonemes
117 |
--------------------------------------------------------------------------------
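`text_to_sequence` maps cleaned text to symbol IDs using the cleaner names defined above. A minimal round-trip sketch using `basic_cleaners`, which only lowercases and collapses whitespace so no phonemization happens (importing `matcha.text` still pulls in `phonemizer` and `piper_phonemize`, so those packages must be installed):

```
from matcha.text import text_to_sequence, sequence_to_text

ids = text_to_sequence("Hello   world!", ["basic_cleaners"])
print(ids)                    # symbol IDs from matcha.text.symbols
print(sequence_to_text(ids))  # "hello world!"
```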
/matcha/text/numbers.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron """
2 |
3 | import re
4 |
5 | import inflect
6 |
7 | _inflect = inflect.engine()
8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13 | _number_re = re.compile(r"[0-9]+")
14 |
15 |
16 | def _remove_commas(m):
17 | return m.group(1).replace(",", "")
18 |
19 |
20 | def _expand_decimal_point(m):
21 | return m.group(1).replace(".", " point ")
22 |
23 |
24 | def _expand_dollars(m):
25 | match = m.group(1)
26 | parts = match.split(".")
27 | if len(parts) > 2:
28 | return match + " dollars"
29 | dollars = int(parts[0]) if parts[0] else 0
30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
31 | if dollars and cents:
32 | dollar_unit = "dollar" if dollars == 1 else "dollars"
33 | cent_unit = "cent" if cents == 1 else "cents"
34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}"
35 | elif dollars:
36 | dollar_unit = "dollar" if dollars == 1 else "dollars"
37 | return f"{dollars} {dollar_unit}"
38 | elif cents:
39 | cent_unit = "cent" if cents == 1 else "cents"
40 | return f"{cents} {cent_unit}"
41 | else:
42 | return "zero dollars"
43 |
44 |
45 | def _expand_ordinal(m):
46 | return _inflect.number_to_words(m.group(0))
47 |
48 |
49 | def _expand_number(m):
50 | num = int(m.group(0))
51 | if num > 1000 and num < 3000:
52 | if num == 2000:
53 | return "two thousand"
54 | elif num > 2000 and num < 2010:
55 | return "two thousand " + _inflect.number_to_words(num % 100)
56 | elif num % 100 == 0:
57 | return _inflect.number_to_words(num // 100) + " hundred"
58 | else:
59 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
60 | else:
61 | return _inflect.number_to_words(num, andword="")
62 |
63 |
64 | def normalize_numbers(text):
65 | text = re.sub(_comma_number_re, _remove_commas, text)
66 | text = re.sub(_pounds_re, r"\1 pounds", text)
67 | text = re.sub(_dollars_re, _expand_dollars, text)
68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text)
69 | text = re.sub(_ordinal_re, _expand_ordinal, text)
70 | text = re.sub(_number_re, _expand_number, text)
71 | return text
72 |
--------------------------------------------------------------------------------
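`normalize_numbers` chains the regex substitutions above so digits, currency amounts and ordinals come out as words. A small sketch (exact phrasing may vary slightly with the installed `inflect` version, and importing the `matcha.text` package requires its phonemizer dependencies):

```
from matcha.text.numbers import normalize_numbers

print(normalize_numbers("The ticket cost $1.50 in 1906."))
# -> roughly "The ticket cost one dollar, fifty cents in nineteen oh six."
```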
/matcha/text/symbols.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/keithito/tacotron
2 |
3 | Defines the set of symbols used in text input to the model.
4 | """
5 | _pad = "_"
6 | _punctuation = ';:,.!?¡¿—…"«»“” '
7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
8 | _letters_ipa = (
9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
10 | )
11 |
12 |
13 | # Export all symbols:
14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
15 |
16 | # Special symbol ids
17 | SPACE_ID = symbols.index(" ")
18 |
--------------------------------------------------------------------------------
/matcha/train.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List, Optional, Tuple
2 |
3 | import hydra
4 | import lightning as L
5 | import rootutils
6 | from lightning import Callback, LightningDataModule, LightningModule, Trainer
7 | from lightning.pytorch.loggers import Logger
8 | from omegaconf import DictConfig
9 |
10 | from matcha import utils
11 |
12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
13 | # ------------------------------------------------------------------------------------ #
14 | # the setup_root above is equivalent to:
15 | # - adding project root dir to PYTHONPATH
16 | # (so you don't need to force user to install project as a package)
17 | # (necessary before importing any local modules e.g. `from src import utils`)
18 | # - setting up PROJECT_ROOT environment variable
19 | # (which is used as a base for paths in "configs/paths/default.yaml")
20 | # (this way all filepaths are the same no matter where you run the code)
21 | # - loading environment variables from ".env" in root dir
22 | #
23 | # you can remove it if you:
24 | # 1. either install project as a package or move entry files to project root dir
25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml"
26 | #
27 | # more info: https://github.com/ashleve/rootutils
28 | # ------------------------------------------------------------------------------------ #
29 |
30 |
31 | log = utils.get_pylogger(__name__)
32 |
33 |
34 | @utils.task_wrapper
35 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
36 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
37 | training.
38 |
39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
40 | failure. Useful for multiruns, saving info about the crash, etc.
41 |
42 | :param cfg: A DictConfig configuration composed by Hydra.
43 | :return: A tuple with metrics and dict with all instantiated objects.
44 | """
45 | # set seed for random number generators in pytorch, numpy and python.random
46 | if cfg.get("seed"):
47 | L.seed_everything(cfg.seed, workers=True)
48 |
49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access
50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
51 |
52 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access
53 | model: LightningModule = hydra.utils.instantiate(cfg.model)
54 |
55 | log.info("Instantiating callbacks...")
56 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
57 |
58 | log.info("Instantiating loggers...")
59 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
60 |
61 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access
62 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)
63 |
64 | object_dict = {
65 | "cfg": cfg,
66 | "datamodule": datamodule,
67 | "model": model,
68 | "callbacks": callbacks,
69 | "logger": logger,
70 | "trainer": trainer,
71 | }
72 |
73 | if logger:
74 | log.info("Logging hyperparameters!")
75 | utils.log_hyperparameters(object_dict)
76 |
77 | if cfg.get("train"):
78 | log.info("Starting training!")
79 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
80 |
81 | train_metrics = trainer.callback_metrics
82 |
83 | if cfg.get("test"):
84 | log.info("Starting testing!")
85 | ckpt_path = trainer.checkpoint_callback.best_model_path
86 | if ckpt_path == "":
87 | log.warning("Best ckpt not found! Using current weights for testing...")
88 | ckpt_path = None
89 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
90 | log.info(f"Best ckpt path: {ckpt_path}")
91 |
92 | test_metrics = trainer.callback_metrics
93 |
94 | # merge train and test metrics
95 | metric_dict = {**train_metrics, **test_metrics}
96 |
97 | return metric_dict, object_dict
98 |
99 |
100 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
101 | def main(cfg: DictConfig) -> Optional[float]:
102 | """Main entry point for training.
103 |
104 | :param cfg: DictConfig configuration composed by Hydra.
105 | :return: Optional[float] with optimized metric value.
106 | """
107 | # apply extra utilities
108 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
109 | utils.extras(cfg)
110 |
111 | # train the model
112 | metric_dict, _ = train(cfg)
113 |
114 | # safely retrieve metric value for hydra-based hyperparameter optimization
115 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric"))
116 |
117 | # return optimized metric
118 | return metric_value
119 |
120 |
121 | if __name__ == "__main__":
122 | main() # pylint: disable=no-value-for-parameter
123 |
--------------------------------------------------------------------------------
/matcha/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers
2 | from matcha.utils.logging_utils import log_hyperparameters
3 | from matcha.utils.pylogger import get_pylogger
4 | from matcha.utils.rich_utils import enforce_tags, print_config_tree
5 | from matcha.utils.utils import extras, get_metric_value, task_wrapper
6 |
--------------------------------------------------------------------------------
/matcha/utils/audio.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data
4 | from librosa.filters import mel as librosa_mel_fn
5 | from scipy.io.wavfile import read
6 |
7 | MAX_WAV_VALUE = 32768.0
8 |
9 |
10 | def load_wav(full_path):
11 | sampling_rate, data = read(full_path)
12 | return data, sampling_rate
13 |
14 |
15 | def dynamic_range_compression(x, C=1, clip_val=1e-5):
16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17 |
18 |
19 | def dynamic_range_decompression(x, C=1):
20 | return np.exp(x) / C
21 |
22 |
23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24 | return torch.log(torch.clamp(x, min=clip_val) * C)
25 |
26 |
27 | def dynamic_range_decompression_torch(x, C=1):
28 | return torch.exp(x) / C
29 |
30 |
31 | def spectral_normalize_torch(magnitudes):
32 | output = dynamic_range_compression_torch(magnitudes)
33 | return output
34 |
35 |
36 | def spectral_de_normalize_torch(magnitudes):
37 | output = dynamic_range_decompression_torch(magnitudes)
38 | return output
39 |
40 |
41 | mel_basis = {}
42 | hann_window = {}
43 |
44 |
45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
46 | if torch.min(y) < -1.0:
47 | print("min value is ", torch.min(y))
48 | if torch.max(y) > 1.0:
49 | print("max value is ", torch.max(y))
50 |
51 | global mel_basis, hann_window # pylint: disable=global-statement
52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
56 |
57 | y = torch.nn.functional.pad(
58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
59 | )
60 | y = y.squeeze(1)
61 |
62 | spec = torch.view_as_real(
63 | torch.stft(
64 | y,
65 | n_fft,
66 | hop_length=hop_size,
67 | win_length=win_size,
68 | window=hann_window[str(y.device)],
69 | center=center,
70 | pad_mode="reflect",
71 | normalized=False,
72 | onesided=True,
73 | return_complex=True,
74 | )
75 | )
76 |
77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
78 |
79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
80 | spec = spectral_normalize_torch(spec)
81 |
82 | return spec
83 |
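A hedged usage sketch for mel_spectrogram; the STFT/mel settings below are typical 22.05 kHz values and are assumptions, not read from this repository's configs.

import torch

y = torch.rand(1, 22050) * 2 - 1        # one second of audio in [-1, 1)
mel = mel_spectrogram(
    y, n_fft=1024, num_mels=80, sampling_rate=22050,
    hop_size=256, win_size=1024, fmin=0, fmax=8000, center=False,
)
print(mel.shape)                         # torch.Size([1, 80, n_frames])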
--------------------------------------------------------------------------------
/matcha/utils/generate_data_statistics.py:
--------------------------------------------------------------------------------
1 | r"""
2 | This script computes the dataset's mel-spectrogram mean and standard deviation and stores them in a JSON file
3 | that the model can load for data normalisation when needed.
4 | 
5 | The data parameters are read from the Hydra config selected under configs/data.
6 | """
7 | import argparse
8 | import json
9 | import os
10 | import sys
11 | from pathlib import Path
12 |
13 | import rootutils
14 | import torch
15 | from hydra import compose, initialize
16 | from omegaconf import open_dict
17 | from tqdm.auto import tqdm
18 |
19 | from matcha.data.text_mel_datamodule import TextMelDataModule
20 | from matcha.utils.logging_utils import pylogger
21 |
22 | log = pylogger.get_pylogger(__name__)
23 |
24 |
25 | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int):
26 | """Generate data mean and standard deviation helpful in data normalisation
27 |
28 | Args:
29 | data_loader (torch.utils.data.Dataloader): _description_
30 | out_channels (int): mel spectrogram channels
31 | """
32 | total_mel_sum = 0
33 | total_mel_sq_sum = 0
34 | total_mel_len = 0
35 |
36 | for batch in tqdm(data_loader, leave=False):
37 | mels = batch["y"]
38 | mel_lengths = batch["y_lengths"]
39 |
40 | total_mel_len += torch.sum(mel_lengths)
41 | total_mel_sum += torch.sum(mels)
42 | total_mel_sq_sum += torch.sum(torch.pow(mels, 2))
43 |
44 | data_mean = total_mel_sum / (total_mel_len * out_channels)
45 | data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2))
46 |
47 | return {"mel_mean": data_mean.item(), "mel_std": data_std.item()}
48 |
49 |
50 | def main():
51 | parser = argparse.ArgumentParser()
52 |
53 | parser.add_argument(
54 | "-i",
55 | "--input-config",
56 | type=str,
57 | default="vctk.yaml",
58 | help="The name of the yaml config file under configs/data",
59 | )
60 |
61 | parser.add_argument(
62 | "-b",
63 | "--batch-size",
64 | type=int,
65 | default="256",
66 | help="Can have increased batch size for faster computation",
67 | )
68 |
69 | parser.add_argument(
70 | "-f",
71 | "--force",
72 | action="store_true",
73 | default=False,
74 | required=False,
75 | help="force overwrite the file",
76 | )
77 | args = parser.parse_args()
78 | output_file = Path(args.input_config).with_suffix(".json")
79 |
80 | if os.path.exists(output_file) and not args.force:
81 | print("File already exists. Use -f to force overwrite")
82 | sys.exit(1)
83 |
84 | with initialize(version_base="1.3", config_path="../../configs/data"):
85 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[])
86 |
87 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root")
88 |
89 | with open_dict(cfg):
90 | del cfg["hydra"]
91 | del cfg["_target_"]
92 | cfg["data_statistics"] = None
93 | cfg["seed"] = 1234
94 | cfg["batch_size"] = args.batch_size
95 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"]))
96 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"]))
97 |
98 | text_mel_datamodule = TextMelDataModule(**cfg)
99 | text_mel_datamodule.setup()
100 | data_loader = text_mel_datamodule.train_dataloader()
101 | log.info("Dataloader loaded! Now computing stats...")
102 | params = compute_data_statistics(data_loader, cfg["n_feats"])
103 | print(params)
104 | json.dump(
105 | params,
106 | open(output_file, "w"),
107 | )
108 |
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
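A minimal sketch of how the emitted statistics are consumed downstream: mels are standardised with the stored mean/std (cf. normalize/denormalize in matcha/utils/model.py). The numbers below are illustrative, not real dataset statistics.

import torch

stats = {"mel_mean": -5.52, "mel_std": 2.06}
mel = torch.randn(1, 80, 120)                                # fake mel batch [B, n_feats, T]
mel_norm = (mel - stats["mel_mean"]) / stats["mel_std"]      # applied before training
mel_back = mel_norm * stats["mel_std"] + stats["mel_mean"]   # undone at synthesis time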
--------------------------------------------------------------------------------
/matcha/utils/instantiators.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import hydra
4 | from lightning import Callback
5 | from lightning.pytorch.loggers import Logger
6 | from omegaconf import DictConfig
7 |
8 | from matcha.utils import pylogger
9 |
10 | log = pylogger.get_pylogger(__name__)
11 |
12 |
13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
14 | """Instantiates callbacks from config.
15 |
16 | :param callbacks_cfg: A DictConfig object containing callback configurations.
17 | :return: A list of instantiated callbacks.
18 | """
19 | callbacks: List[Callback] = []
20 |
21 | if not callbacks_cfg:
22 | log.warning("No callback configs found! Skipping..")
23 | return callbacks
24 |
25 | if not isinstance(callbacks_cfg, DictConfig):
26 | raise TypeError("Callbacks config must be a DictConfig!")
27 |
28 | for _, cb_conf in callbacks_cfg.items():
29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access
31 | callbacks.append(hydra.utils.instantiate(cb_conf))
32 |
33 | return callbacks
34 |
35 |
36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
37 | """Instantiates loggers from config.
38 |
39 | :param logger_cfg: A DictConfig object containing logger configurations.
40 | :return: A list of instantiated loggers.
41 | """
42 | logger: List[Logger] = []
43 |
44 | if not logger_cfg:
45 | log.warning("No logger configs found! Skipping...")
46 | return logger
47 |
48 | if not isinstance(logger_cfg, DictConfig):
49 | raise TypeError("Logger config must be a DictConfig!")
50 |
51 | for _, lg_conf in logger_cfg.items():
52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf:
53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access
54 | logger.append(hydra.utils.instantiate(lg_conf))
55 |
56 | return logger
57 |
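A minimal sketch of the config shape instantiate_callbacks expects; the callback target and its arguments are illustrative, not taken from this repository's configs.

from omegaconf import OmegaConf

callbacks_cfg = OmegaConf.create({
    "model_checkpoint": {
        "_target_": "lightning.pytorch.callbacks.ModelCheckpoint",
        "monitor": "val_loss",
        "save_top_k": 1,
    }
})
callbacks = instantiate_callbacks(callbacks_cfg)   # -> [ModelCheckpoint(...)]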
--------------------------------------------------------------------------------
/matcha/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict
2 |
3 | from lightning.pytorch.utilities import rank_zero_only
4 | from omegaconf import OmegaConf
5 |
6 | from matcha.utils import pylogger
7 |
8 | log = pylogger.get_pylogger(__name__)
9 |
10 |
11 | @rank_zero_only
12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
13 | """Controls which config parts are saved by Lightning loggers.
14 |
15 | Additionally saves:
16 | - Number of model parameters
17 |
18 | :param object_dict: A dictionary containing the following objects:
19 | - `"cfg"`: A DictConfig object containing the main config.
20 | - `"model"`: The Lightning model.
21 | - `"trainer"`: The Lightning trainer.
22 | """
23 | hparams = {}
24 |
25 | cfg = OmegaConf.to_container(object_dict["cfg"])
26 | model = object_dict["model"]
27 | trainer = object_dict["trainer"]
28 |
29 | if not trainer.logger:
30 | log.warning("Logger not found! Skipping hyperparameter logging...")
31 | return
32 |
33 | hparams["model"] = cfg["model"]
34 |
35 | # save number of model parameters
36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad)
38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad)
39 |
40 | hparams["data"] = cfg["data"]
41 | hparams["trainer"] = cfg["trainer"]
42 |
43 | hparams["callbacks"] = cfg.get("callbacks")
44 | hparams["extras"] = cfg.get("extras")
45 |
46 | hparams["task_name"] = cfg.get("task_name")
47 | hparams["tags"] = cfg.get("tags")
48 | hparams["ckpt_path"] = cfg.get("ckpt_path")
49 | hparams["seed"] = cfg.get("seed")
50 |
51 | # send hparams to all loggers
52 | for logger in trainer.loggers:
53 | logger.log_hyperparams(hparams)
54 |
--------------------------------------------------------------------------------
/matcha/utils/model.py:
--------------------------------------------------------------------------------
1 | """ from https://github.com/jaywalnut310/glow-tts """
2 |
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def sequence_mask(length, max_length=None):
8 | if max_length is None:
9 | max_length = length.max()
10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device)
11 | return x.unsqueeze(0) < length.unsqueeze(1)
12 |
13 |
14 | def fix_len_compatibility(length, num_downsamplings_in_unet=2):
15 | factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet)
16 | length = (length / factor).ceil() * factor
17 | if not torch.onnx.is_in_onnx_export():
18 | return length.int().item()
19 | else:
20 | return length
21 |
22 |
23 | def convert_pad_shape(pad_shape):
24 | inverted_shape = pad_shape[::-1]
25 | pad_shape = [item for sublist in inverted_shape for item in sublist]
26 | return pad_shape
27 |
28 |
29 | def generate_path(duration, mask):
30 | device = duration.device
31 |
32 | b, t_x, t_y = mask.shape
33 | cum_duration = torch.cumsum(duration, 1)
34 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device)
35 |
36 | cum_duration_flat = cum_duration.view(b * t_x)
37 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
38 | path = path.view(b, t_x, t_y)
39 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
40 | path = path * mask
41 | return path
42 |
43 |
44 | def duration_loss(logw, logw_, lengths):
45 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths)
46 | return loss
47 |
48 |
49 | def normalize(data, mu, std):
50 | if not isinstance(mu, (float, int)):
51 | if isinstance(mu, list):
52 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
53 | elif isinstance(mu, torch.Tensor):
54 | mu = mu.to(data.device)
55 | elif isinstance(mu, np.ndarray):
56 | mu = torch.from_numpy(mu).to(data.device)
57 | mu = mu.unsqueeze(-1)
58 |
59 | if not isinstance(std, (float, int)):
60 | if isinstance(std, list):
61 | std = torch.tensor(std, dtype=data.dtype, device=data.device)
62 | elif isinstance(std, torch.Tensor):
63 | std = std.to(data.device)
64 | elif isinstance(std, np.ndarray):
65 | std = torch.from_numpy(std).to(data.device)
66 | std = std.unsqueeze(-1)
67 |
68 | return (data - mu) / std
69 |
70 |
71 | def denormalize(data, mu, std):
72 | if not isinstance(mu, float):
73 | if isinstance(mu, list):
74 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
75 | elif isinstance(mu, torch.Tensor):
76 | mu = mu.to(data.device)
77 | elif isinstance(mu, np.ndarray):
78 | mu = torch.from_numpy(mu).to(data.device)
79 | mu = mu.unsqueeze(-1)
80 |
81 | if not isinstance(std, float):
82 | if isinstance(std, list):
83 | std = torch.tensor(std, dtype=data.dtype, device=data.device)
84 | elif isinstance(std, torch.Tensor):
85 | std = std.to(data.device)
86 | elif isinstance(std, np.ndarray):
87 | std = torch.from_numpy(std).to(data.device)
88 | std = std.unsqueeze(-1)
89 |
90 | return data * std + mu
91 |
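A quick illustration of sequence_mask, which several of the helpers above build on: one row per sequence, True up to that sequence's length.

import torch

lengths = torch.tensor([2, 4])
print(sequence_mask(lengths))
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True]])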
--------------------------------------------------------------------------------
/matcha/utils/monotonic_align/__init__.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from matcha.utils.monotonic_align.core import maximum_path_c
5 |
6 |
7 | def maximum_path(value, mask):
8 | """Cython optimised version.
9 | value: [b, t_x, t_y]
10 | mask: [b, t_x, t_y]
11 | """
12 | value = value * mask
13 | device = value.device
14 | dtype = value.dtype
15 | value = value.data.cpu().numpy().astype(np.float32)
16 | path = np.zeros_like(value).astype(np.int32)
17 | mask = mask.data.cpu().numpy()
18 |
19 | t_x_max = mask.sum(1)[:, 0].astype(np.int32)
20 | t_y_max = mask.sum(2)[:, 0].astype(np.int32)
21 | maximum_path_c(path, value, t_x_max, t_y_max)
22 | return torch.from_numpy(path).to(device=device, dtype=dtype)
23 |
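A hedged usage sketch (it requires the Cython extension built from core.pyx): aligning 3 text tokens to 5 mel frames for a batch of one.

import torch

value = torch.randn(1, 3, 5)       # log-likelihoods, [b, t_x, t_y]
mask = torch.ones(1, 3, 5)         # all positions valid
path = maximum_path(value, mask)   # hard monotonic alignment, same shape, 0/1 entries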
--------------------------------------------------------------------------------
/matcha/utils/monotonic_align/core.pyx:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | cimport cython
4 | cimport numpy as np
5 |
6 | from cython.parallel import prange
7 |
8 |
9 | @cython.boundscheck(False)
10 | @cython.wraparound(False)
11 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
12 | cdef int x
13 | cdef int y
14 | cdef float v_prev
15 | cdef float v_cur
16 | cdef float tmp
17 | cdef int index = t_x - 1
18 |
19 | for y in range(t_y):
20 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
21 | if x == y:
22 | v_cur = max_neg_val
23 | else:
24 | v_cur = value[x, y-1]
25 | if x == 0:
26 | if y == 0:
27 | v_prev = 0.
28 | else:
29 | v_prev = max_neg_val
30 | else:
31 | v_prev = value[x-1, y-1]
32 | value[x, y] = max(v_cur, v_prev) + value[x, y]
33 |
34 | for y in range(t_y - 1, -1, -1):
35 | path[index, y] = 1
36 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
37 | index = index - 1
38 |
39 |
40 | @cython.boundscheck(False)
41 | @cython.wraparound(False)
42 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
43 | cdef int b = values.shape[0]
44 |
45 | cdef int i
46 | for i in prange(b, nogil=True):
47 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
48 |
--------------------------------------------------------------------------------
/matcha/utils/monotonic_align/setup.py:
--------------------------------------------------------------------------------
1 | # from distutils.core import setup
2 | # from Cython.Build import cythonize
3 | # import numpy
4 |
5 | # setup(name='monotonic_align',
6 | # ext_modules=cythonize("core.pyx"),
7 | # include_dirs=[numpy.get_include()])
8 |
--------------------------------------------------------------------------------
/matcha/utils/pylogger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from lightning.pytorch.utilities import rank_zero_only
4 |
5 |
6 | def get_pylogger(name: str = __name__) -> logging.Logger:
7 | """Initializes a multi-GPU-friendly python command line logger.
8 |
9 | :param name: The name of the logger, defaults to ``__name__``.
10 |
11 | :return: A logger object.
12 | """
13 | logger = logging.getLogger(name)
14 |
15 | # this ensures all logging levels get marked with the rank zero decorator
16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup
17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical")
18 | for level in logging_levels:
19 | setattr(logger, level, rank_zero_only(getattr(logger, level)))
20 |
21 | return logger
22 |
--------------------------------------------------------------------------------
/matcha/utils/rich_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Sequence
3 |
4 | import rich
5 | import rich.syntax
6 | import rich.tree
7 | from hydra.core.hydra_config import HydraConfig
8 | from lightning.pytorch.utilities import rank_zero_only
9 | from omegaconf import DictConfig, OmegaConf, open_dict
10 | from rich.prompt import Prompt
11 |
12 | from matcha.utils import pylogger
13 |
14 | log = pylogger.get_pylogger(__name__)
15 |
16 |
17 | @rank_zero_only
18 | def print_config_tree(
19 | cfg: DictConfig,
20 | print_order: Sequence[str] = (
21 | "data",
22 | "model",
23 | "callbacks",
24 | "logger",
25 | "trainer",
26 | "paths",
27 | "extras",
28 | ),
29 | resolve: bool = False,
30 | save_to_file: bool = False,
31 | ) -> None:
32 | """Prints the contents of a DictConfig as a tree structure using the Rich library.
33 |
34 | :param cfg: A DictConfig composed by Hydra.
35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
36 | "callbacks", "logger", "trainer", "paths", "extras")``.
37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
39 | """
40 | style = "dim"
41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)
42 |
43 | queue = []
44 |
45 | # add fields from `print_order` to queue
46 | for field in print_order:
47 | _ = (
48 | queue.append(field)
49 | if field in cfg
50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...")
51 | )
52 |
53 | # add all the other fields to queue (not specified in `print_order`)
54 | for field in cfg:
55 | if field not in queue:
56 | queue.append(field)
57 |
58 | # generate config tree from queue
59 | for field in queue:
60 | branch = tree.add(field, style=style, guide_style=style)
61 |
62 | config_group = cfg[field]
63 | if isinstance(config_group, DictConfig):
64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve)
65 | else:
66 | branch_content = str(config_group)
67 |
68 | branch.add(rich.syntax.Syntax(branch_content, "yaml"))
69 |
70 | # print config tree
71 | rich.print(tree)
72 |
73 | # save config tree to file
74 | if save_to_file:
75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file:
76 | rich.print(tree, file=file)
77 |
78 |
79 | @rank_zero_only
80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
81 | """Prompts user to input tags from command line if no tags are provided in config.
82 |
83 | :param cfg: A DictConfig composed by Hydra.
84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
85 | """
86 | if not cfg.get("tags"):
87 | if "id" in HydraConfig().cfg.hydra.job:
88 | raise ValueError("Specify tags before launching a multirun!")
89 |
90 | log.warning("No tags provided in config. Prompting user to input tags...")
91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
92 | tags = [t.strip() for t in tags.split(",") if t != ""]
93 |
94 | with open_dict(cfg):
95 | cfg.tags = tags
96 |
97 | log.info(f"Tags: {cfg.tags}")
98 |
99 | if save_to_file:
100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
101 | rich.print(cfg.tags, file=file)
102 |
--------------------------------------------------------------------------------
/matcha/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import warnings
4 | from importlib.util import find_spec
5 | from pathlib import Path
6 | from typing import Any, Callable, Dict, Tuple
7 |
8 | import gdown
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | import torch
12 | import wget
13 | from omegaconf import DictConfig
14 |
15 | from matcha.utils import pylogger, rich_utils
16 |
17 | log = pylogger.get_pylogger(__name__)
18 |
19 |
20 | def extras(cfg: DictConfig) -> None:
21 | """Applies optional utilities before the task is started.
22 |
23 | Utilities:
24 | - Ignoring python warnings
25 | - Setting tags from command line
26 | - Rich config printing
27 |
28 | :param cfg: A DictConfig object containing the config tree.
29 | """
30 | # return if no `extras` config
31 | if not cfg.get("extras"):
32 | log.warning("Extras config not found! ")
33 | return
34 |
35 | # disable python warnings
36 | if cfg.extras.get("ignore_warnings"):
37 | log.info("Disabling python warnings! ")
38 | warnings.filterwarnings("ignore")
39 |
40 | # prompt user to input tags from command line if none are provided in the config
41 | if cfg.extras.get("enforce_tags"):
42 | log.info("Enforcing tags! ")
43 | rich_utils.enforce_tags(cfg, save_to_file=True)
44 |
45 | # pretty print config tree using Rich library
46 | if cfg.extras.get("print_config"):
47 | log.info("Printing config tree with Rich! ")
48 | rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
49 |
50 |
51 | def task_wrapper(task_func: Callable) -> Callable:
52 | """Optional decorator that controls the failure behavior when executing the task function.
53 |
54 | This wrapper can be used to:
55 | - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
56 | - save the exception to a `.log` file
57 | - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
58 | - etc. (adjust depending on your needs)
59 |
60 | Example:
61 | ```
62 | @utils.task_wrapper
63 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
64 | ...
65 | return metric_dict, object_dict
66 | ```
67 |
68 | :param task_func: The task function to be wrapped.
69 |
70 | :return: The wrapped task function.
71 | """
72 |
73 | def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
74 | # execute the task
75 | try:
76 | metric_dict, object_dict = task_func(cfg=cfg)
77 |
78 | # things to do if exception occurs
79 | except Exception as ex:
80 | # save exception to `.log` file
81 | log.exception("")
82 |
83 | # some hyperparameter combinations might be invalid or cause out-of-memory errors
84 | # so when using hparam search plugins like Optuna, you might want to disable
85 | # raising the below exception to avoid multirun failure
86 | raise ex
87 |
88 | # things to always do after either success or exception
89 | finally:
90 | # display output dir path in terminal
91 | log.info(f"Output dir: {cfg.paths.output_dir}")
92 |
93 | # always close wandb run (even if exception occurs so multirun won't fail)
94 | if find_spec("wandb"): # check if wandb is installed
95 | import wandb
96 |
97 | if wandb.run:
98 | log.info("Closing wandb!")
99 | wandb.finish()
100 |
101 | return metric_dict, object_dict
102 |
103 | return wrap
104 |
105 |
106 | def get_metric_value(metric_dict: Dict[str, Any], metric_name: str) -> float:
107 | """Safely retrieves value of the metric logged in LightningModule.
108 |
109 | :param metric_dict: A dict containing metric values.
110 | :param metric_name: The name of the metric to retrieve.
111 | :return: The value of the metric.
112 | """
113 | if not metric_name:
114 | log.info("Metric name is None! Skipping metric value retrieval...")
115 | return None
116 |
117 | if metric_name not in metric_dict:
118 | raise ValueError(
119 | f"Metric value not found! \n"
120 | "Make sure metric name logged in LightningModule is correct!\n"
121 | "Make sure `optimized_metric` name in `hparams_search` config is correct!"
122 | )
123 |
124 | metric_value = metric_dict[metric_name].item()
125 | log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")
126 |
127 | return metric_value
128 |
129 |
130 | def intersperse(lst, item):
131 | # Adds blank symbol
132 | result = [item] * (len(lst) * 2 + 1)
133 | result[1::2] = lst
134 | return result
135 |
136 |
137 | def save_figure_to_numpy(fig):
138 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
139 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
140 | return data
141 |
142 |
143 | def plot_tensor(tensor):
144 | plt.style.use("default")
145 | fig, ax = plt.subplots(figsize=(12, 3))
146 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
147 | plt.colorbar(im, ax=ax)
148 | plt.tight_layout()
149 | fig.canvas.draw()
150 | data = save_figure_to_numpy(fig)
151 | plt.close()
152 | return data
153 |
154 |
155 | def save_plot(tensor, savepath):
156 | plt.style.use("default")
157 | fig, ax = plt.subplots(figsize=(12, 3))
158 | im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
159 | plt.colorbar(im, ax=ax)
160 | plt.tight_layout()
161 | fig.canvas.draw()
162 | plt.savefig(savepath)
163 | plt.close()
164 |
165 |
166 | def to_numpy(tensor):
167 | if isinstance(tensor, np.ndarray):
168 | return tensor
169 | elif isinstance(tensor, torch.Tensor):
170 | return tensor.detach().cpu().numpy()
171 | elif isinstance(tensor, list):
172 | return np.array(tensor)
173 | else:
174 | raise TypeError("Unsupported type for conversion to numpy array")
175 |
176 |
177 | def get_user_data_dir(appname="matcha_tts"):
178 | """
179 | Args:
180 | appname (str): Name of application
181 |
182 | Returns:
183 | Path: path to user data directory
184 | """
185 |
186 | MATCHA_HOME = os.environ.get("MATCHA_HOME")
187 | if MATCHA_HOME is not None:
188 | ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
189 | elif sys.platform == "win32":
190 | import winreg # pylint: disable=import-outside-toplevel
191 |
192 | key = winreg.OpenKey(
193 | winreg.HKEY_CURRENT_USER,
194 | r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
195 | )
196 | dir_, _ = winreg.QueryValueEx(key, "Local AppData")
197 | ans = Path(dir_).resolve(strict=False)
198 | elif sys.platform == "darwin":
199 | ans = Path("~/Library/Application Support/").expanduser()
200 | else:
201 | ans = Path.home().joinpath(".local/share")
202 |
203 | final_path = ans.joinpath(appname)
204 | final_path.mkdir(parents=True, exist_ok=True)
205 | return final_path
206 |
207 |
208 | def assert_model_downloaded(checkpoint_path, url, use_wget=True):
209 | if Path(checkpoint_path).exists():
210 | log.debug(f"[+] Model already present at {checkpoint_path}!")
211 | print(f"[+] Model already present at {checkpoint_path}!")
212 | return
213 | log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
214 | print(f"[-] Model not found at {checkpoint_path}! Will download it")
215 | checkpoint_path = str(checkpoint_path)
216 | if not use_wget:
217 | gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
218 | else:
219 | wget.download(url=url, out=checkpoint_path)
220 |
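A hedged sketch combining get_user_data_dir and assert_model_downloaded; the checkpoint name and URL are placeholders, not real release assets.

ckpt_path = get_user_data_dir() / "matcha_ljspeech.ckpt"   # ~/.local/share/matcha_tts/... by default
assert_model_downloaded(ckpt_path, "https://example.com/matcha_ljspeech.ckpt", use_wget=True)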
--------------------------------------------------------------------------------
/nodes/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | now_dir = os.path.dirname(os.path.abspath(__file__))
4 | node_root = os.path.dirname(now_dir)
5 | sys.path.append(node_root)
--------------------------------------------------------------------------------
/nodes/sensevoice_nodes.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: SpenserCai
3 | Date: 2024-10-04 12:13:43
4 | version:
5 | LastEditors: SpenserCai
6 | LastEditTime: 2024-10-05 11:03:31
7 | Description: SenseVoice speech-to-text node for ComfyUI-FunAudioLLM
8 | '''
9 | import folder_paths
10 | import os
11 | import numpy as np
12 | from funasr import AutoModel
13 | from funaudio_utils.pre import FunAudioLLMTool
14 | from funaudio_utils.download_models import download_sensevoice_small
15 | from funasr.utils import postprocess_utils
16 |
17 | fAudioTool = FunAudioLLMTool()
18 |
19 | CATEGORY_NAME = "FunAudioLLM - SenseVoice"
20 |
21 | folder_paths.add_model_folder_path("SenseVoice", os.path.join(folder_paths.models_dir, "SenseVoice"))
22 |
23 | def patch_emoji(emoji_dict):
24 | t_emoji_dict_key = emoji_dict.keys()
25 | emoji_dict_new = {}
26 | for t_e_k in t_emoji_dict_key:
27 | emoji_dict_new[t_e_k.lower()] = emoji_dict[t_e_k]
28 | emoji_dict.update(emoji_dict_new)
29 | return emoji_dict
30 |
31 | class SenseVoiceNode:
32 | @classmethod
33 | def INPUT_TYPES(s):
34 | return {
35 | "required":{
36 | "audio":("AUDIO",),
37 | "use_fast_mode":("BOOLEAN",{
38 | "default": False
39 | }),
40 | "punc_segment":("BOOLEAN",{
41 | "default": False
42 | }),
43 | }
44 | }
45 |
46 | CATEGORY = CATEGORY_NAME
47 | RETURN_TYPES = ("STRING",)
48 |
49 | FUNCTION="generate"
50 |
51 | def generate(self,audio, use_fast_mode,punc_segment):
52 | sensevoice_code_path = os.path.join(folder_paths.base_path,"custom_nodes/ComfyUI-FunAudioLLM/sensevoice/model.py")
53 | speech = audio["waveform"]
54 | source_sr = audio["sample_rate"]
55 | speech = fAudioTool.audio_resample(speech, source_sr)
56 | speech = fAudioTool.postprocess(speech)
57 | # Check whether the audio is longer than 30 seconds
58 | if speech.shape[1] > 30 * 22050 and use_fast_mode:
59 | raise ValueError("Audio length is too long, please set use_fast_mode to False.")
60 | _, model_dir = download_sensevoice_small()
61 | model_arg = {
62 | "input":speech[0],
63 | "cache":{},
64 | "language":"auto",
65 | "batch_size_s":60,
66 | }
67 | model_use_arg = {
68 | "model":model_dir,
69 | "trust_remote_code":True,
70 | "remote_code":sensevoice_code_path,
71 | "device":"cuda:0",
72 | }
73 |
74 | if not use_fast_mode:
75 | model_use_arg["vad_model"] = "fsmn-vad"
76 | model_use_arg["vad_kwargs"] = {"max_single_segment_time":30000}
77 |
78 | model_arg["merge_vad"] = True
79 | model_arg["merge_length_s"] = 15
80 |
81 | if punc_segment:
82 | model_use_arg["punc_model"] = "ct-punc-c"
83 |
84 | model = AutoModel(**model_use_arg)
85 | output = model.generate(**model_arg)
86 | postprocess_utils.emoji_dict = patch_emoji(postprocess_utils.emoji_dict)
87 | return (postprocess_utils.rich_transcription_postprocess(output[0]["text"]),)
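For context, a ComfyUI node class like this is normally exposed through NODE_CLASS_MAPPINGS in the package's root __init__.py; a minimal sketch, assuming the import path used in this repository and an illustrative display name:

from nodes.sensevoice_nodes import SenseVoiceNode

NODE_CLASS_MAPPINGS = {"SenseVoiceNode": SenseVoiceNode}
NODE_DISPLAY_NAME_MAPPINGS = {"SenseVoiceNode": "FunAudioLLM - SenseVoice"}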
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comfyui-funaudiollm"
3 | description = "Comfyui custom node for [FunAudioLLM](https://funaudiollm.github.io/) include [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) and [SenseVoice](https://github.com/FunAudioLLM/SenseVoice)."
4 | version = "1.0.0"
5 | license = {file = "LICENSE"}
6 | dependencies = ["conformer", "deepspeed; sys_platform == 'linux'", "diffusers", "grpcio", "grpcio-tools", "hydra-core", "HyperPyYAML", "inflect", "librosa", "lightning", "matplotlib", "modelscope", "networkx", "omegaconf", "onnxruntime-gpu; sys_platform == 'linux'", "onnxruntime; sys_platform == 'darwin' or sys_platform == 'win32'", "openai-whisper", "protobuf", "pydantic", "rich", "soundfile", "tensorboard", "wget", "gdown", "pyarrow", "jieba", "pypinyin", "pydub", "audiosegment", "srt", "ffmpeg-python", "WeTextProcessing", "aliyun-python-sdk-core", "funasr>=1.1.3", "huggingface", "huggingface_hub"]
7 |
8 | [project.urls]
9 | Repository = "https://github.com/SpenserCai/ComfyUI-FunAudioLLM"
10 | # Used by Comfy Registry https://comfyregistry.org
11 |
12 | [tool.comfy]
13 | PublisherId = "spensercai"
14 | DisplayName = "ComfyUI-FunAudioLLM"
15 | Icon = ""
16 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | conformer
2 | deepspeed; sys_platform == 'linux'
3 | diffusers
4 | grpcio
5 | grpcio-tools
6 | hydra-core
7 | HyperPyYAML
8 | inflect
9 | librosa
10 | lightning
11 | matplotlib
12 | modelscope
13 | networkx
14 | omegaconf
15 | onnxruntime-gpu; sys_platform == 'linux'
16 | onnxruntime; sys_platform == 'darwin' or sys_platform == 'win32'
17 | openai-whisper
18 | protobuf
19 | pydantic
20 | rich
21 | soundfile
22 | tensorboard
23 | wget
24 | gdown
25 | pyarrow
26 | jieba
27 | pypinyin
28 | pydub
29 | audiosegment
30 | srt
31 | ffmpeg-python
32 | WeTextProcessing
33 | funasr>=1.1.3
34 | huggingface
35 | huggingface_hub
36 |
--------------------------------------------------------------------------------
/sensevoice/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/sensevoice/__init__.py
--------------------------------------------------------------------------------
/sensevoice/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/sensevoice/utils/__init__.py
--------------------------------------------------------------------------------
/sensevoice/utils/export_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 |
5 | def export(
6 | model, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
7 | ):
8 | model_scripts = model.export(**kwargs)
9 | export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
10 | os.makedirs(export_dir, exist_ok=True)
11 |
12 | if not isinstance(model_scripts, (list, tuple)):
13 | model_scripts = (model_scripts,)
14 | for m in model_scripts:
15 | m.eval()
16 | if type == "onnx":
17 | _onnx(
18 | m,
19 | quantize=quantize,
20 | opset_version=opset_version,
21 | export_dir=export_dir,
22 | **kwargs,
23 | )
24 | print("output dir: {}".format(export_dir))
25 |
26 | return export_dir
27 |
28 |
29 | def _onnx(
30 | model,
31 | quantize: bool = False,
32 | opset_version: int = 14,
33 | export_dir: str = None,
34 | **kwargs,
35 | ):
36 |
37 | dummy_input = model.export_dummy_inputs()
38 |
39 | verbose = kwargs.get("verbose", False)
40 |
41 | export_name = model.export_name()
42 | model_path = os.path.join(export_dir, export_name)
43 | torch.onnx.export(
44 | model,
45 | dummy_input,
46 | model_path,
47 | verbose=verbose,
48 | opset_version=opset_version,
49 | input_names=model.export_input_names(),
50 | output_names=model.export_output_names(),
51 | dynamic_axes=model.export_dynamic_axes(),
52 | )
53 |
54 | if quantize:
55 | from onnxruntime.quantization import QuantType, quantize_dynamic
56 | import onnx
57 |
58 | quant_model_path = model_path.replace(".onnx", "_quant.onnx")
59 | if not os.path.exists(quant_model_path):
60 | onnx_model = onnx.load(model_path)
61 | nodes = [n.name for n in onnx_model.graph.node]
62 | nodes_to_exclude = [
63 | m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
64 | ]
65 | quantize_dynamic(
66 | model_input=model_path,
67 | model_output=quant_model_path,
68 | op_types_to_quantize=["MatMul"],
69 | per_channel=True,
70 | reduce_range=False,
71 | weight_type=QuantType.QUInt8,
72 | nodes_to_exclude=nodes_to_exclude,
73 | )
74 |
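A minimal sketch of loading the exported (optionally quantized) model with onnxruntime; the file name matches the quantized output above, and the feed names come from the model's export_input_names(), so they are listed at run time rather than assumed here.

import onnxruntime as ort

sess = ort.InferenceSession("model_quant.onnx", providers=["CPUExecutionProvider"])
print([i.name for i in sess.get_inputs()])            # names defined by model.export_input_names()
# outputs = sess.run(None, {name: numpy_array, ...})  # feed arrays matching those names/shapes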
--------------------------------------------------------------------------------
/sensevoice/utils/model_bin.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- encoding: utf-8 -*-
3 | # Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
4 | # MIT License (https://opensource.org/licenses/MIT)
5 |
6 | import os.path
7 | from pathlib import Path
8 | from typing import List, Union, Tuple
9 | import torch
10 | import librosa
11 | import numpy as np
12 |
13 | from utils.infer_utils import (
14 | CharTokenizer,
15 | Hypothesis,
16 | ONNXRuntimeError,
17 | OrtInferSession,
18 | TokenIDConverter,
19 | get_logger,
20 | read_yaml,
21 | )
22 | from utils.frontend import WavFrontend
23 | from utils.infer_utils import pad_list
24 |
25 | logging = get_logger()
26 |
27 |
28 | class SenseVoiceSmallONNX:
29 | """
30 | Author: Speech Lab of DAMO Academy, Alibaba Group
31 | Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
32 | https://arxiv.org/abs/2206.08317
33 | """
34 |
35 | def __init__(
36 | self,
37 | model_dir: Union[str, Path] = None,
38 | batch_size: int = 1,
39 | device_id: Union[str, int] = "-1",
40 | plot_timestamp_to: str = "",
41 | quantize: bool = False,
42 | intra_op_num_threads: int = 4,
43 | cache_dir: str = None,
44 | **kwargs,
45 | ):
46 | if quantize:
47 | model_file = os.path.join(model_dir, "model_quant.onnx")
48 | else:
49 | model_file = os.path.join(model_dir, "model.onnx")
50 |
51 | config_file = os.path.join(model_dir, "config.yaml")
52 | cmvn_file = os.path.join(model_dir, "am.mvn")
53 | config = read_yaml(config_file)
54 | # token_list = os.path.join(model_dir, "tokens.json")
55 | # with open(token_list, "r", encoding="utf-8") as f:
56 | # token_list = json.load(f)
57 |
58 | # self.converter = TokenIDConverter(token_list)
59 | self.tokenizer = CharTokenizer()
60 | config["frontend_conf"]['cmvn_file'] = cmvn_file
61 | self.frontend = WavFrontend(**config["frontend_conf"])
62 | self.ort_infer = OrtInferSession(
63 | model_file, device_id, intra_op_num_threads=intra_op_num_threads
64 | )
65 | self.batch_size = batch_size
66 | self.blank_id = 0
67 |
68 | def __call__(self,
69 | wav_content: Union[str, np.ndarray, List[str]],
70 | language: List,
71 | textnorm: List,
72 | tokenizer=None,
73 | **kwargs) -> List:
74 | waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
75 | waveform_nums = len(waveform_list)
76 | asr_res = []
77 | for beg_idx in range(0, waveform_nums, self.batch_size):
78 | end_idx = min(waveform_nums, beg_idx + self.batch_size)
79 | feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
80 | ctc_logits, encoder_out_lens = self.infer(feats,
81 | feats_len,
82 | np.array(language, dtype=np.int32),
83 | np.array(textnorm, dtype=np.int32)
84 | )
85 | # back to torch.Tensor
86 | ctc_logits = torch.from_numpy(ctc_logits).float()
87 | # support batch_size=1 only currently
88 | x = ctc_logits[0, : encoder_out_lens[0].item(), :]
89 | yseq = x.argmax(dim=-1)
90 | yseq = torch.unique_consecutive(yseq, dim=-1)
91 |
92 | mask = yseq != self.blank_id
93 | token_int = yseq[mask].tolist()
94 |
95 | if tokenizer is not None:
96 | asr_res.append(tokenizer.tokens2text(token_int))
97 | else:
98 | asr_res.append(token_int)
99 | return asr_res
100 |
101 | def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
102 | def load_wav(path: str) -> np.ndarray:
103 | waveform, _ = librosa.load(path, sr=fs)
104 | return waveform
105 |
106 | if isinstance(wav_content, np.ndarray):
107 | return [wav_content]
108 |
109 | if isinstance(wav_content, str):
110 | return [load_wav(wav_content)]
111 |
112 | if isinstance(wav_content, list):
113 | return [load_wav(path) for path in wav_content]
114 |
115 | raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
116 |
117 | def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
118 | feats, feats_len = [], []
119 | for waveform in waveform_list:
120 | speech, _ = self.frontend.fbank(waveform)
121 | feat, feat_len = self.frontend.lfr_cmvn(speech)
122 | feats.append(feat)
123 | feats_len.append(feat_len)
124 |
125 | feats = self.pad_feats(feats, np.max(feats_len))
126 | feats_len = np.array(feats_len).astype(np.int32)
127 | return feats, feats_len
128 |
129 | @staticmethod
130 | def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
131 | def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
132 | pad_width = ((0, max_feat_len - cur_len), (0, 0))
133 | return np.pad(feat, pad_width, "constant", constant_values=0)
134 |
135 | feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
136 | feats = np.array(feat_res).astype(np.float32)
137 | return feats
138 |
139 | def infer(self,
140 | feats: np.ndarray,
141 | feats_len: np.ndarray,
142 | language: np.ndarray,
143 | textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]:
144 | outputs = self.ort_infer([feats, feats_len, language, textnorm])
145 | return outputs
146 |
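A hedged usage sketch for SenseVoiceSmallONNX; the model directory is a placeholder, and the integer ids passed for language and textnorm depend on the exported model's token list, so the values below are illustrative only.

model = SenseVoiceSmallONNX(model_dir="/path/to/SenseVoiceSmall-onnx")
# language/textnorm take one integer id per utterance, as defined by the exported model
token_ids = model("example.wav", language=[0], textnorm=[15])
print(token_ids)   # raw CTC token ids; pass tokenizer=... to get text instead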
--------------------------------------------------------------------------------
/web/PUT_WEB_JS_HERE:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpenserCai/ComfyUI-FunAudioLLM/d35cded867a2f55d3c56b314aae53997a3d68367/web/PUT_WEB_JS_HERE
--------------------------------------------------------------------------------