├── .github └── workflows │ └── publish.yml ├── .gitignore ├── README.md ├── __init__.py ├── cosyvoice ├── __init__.py ├── bin │ ├── __init__.py │ ├── average_model.py │ ├── export_jit.py │ ├── export_onnx.py │ ├── export_trt.sh │ ├── inference.py │ └── train.py ├── cli │ ├── __init__.py │ ├── cosyvoice.py │ ├── frontend.py │ └── model.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ └── processor.py ├── flow │ ├── __init__.py │ ├── decoder.py │ ├── flow.py │ ├── flow_matching.py │ └── length_regulator.py ├── hifigan │ ├── __init__.py │ ├── discriminator.py │ ├── f0_predictor.py │ ├── generator.py │ └── hifigan.py ├── llm │ ├── __init__.py │ └── llm.py ├── tokenizer │ ├── __init__.py │ ├── assets │ │ └── multilingual_zh_ja_yue_char_del.tiktoken │ └── tokenizer.py ├── transformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── decoder.py │ ├── decoder_layer.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ ├── subsampling.py │ └── upsample_encoder.py └── utils │ ├── __init__.py │ ├── class_utils.py │ ├── common.py │ ├── executor.py │ ├── file_utils.py │ ├── frontend_utils.py │ ├── losses.py │ ├── mask.py │ ├── scheduler.py │ └── train_utils.py ├── downloadmodel.py ├── examples ├── CrossLingual.json ├── Instruct2.json └── ZeroShot.json ├── pyproject.toml ├── requirements.txt └── third_party ├── Matcha-TTS ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── __init__.py ├── configs │ ├── __init__.py │ ├── callbacks │ │ ├── default.yaml │ │ ├── model_checkpoint.yaml │ │ ├── model_summary.yaml │ │ ├── none.yaml │ │ └── rich_progress_bar.yaml │ ├── data │ │ ├── hi-fi_en-US_female.yaml │ │ ├── ljspeech.yaml │ │ └── vctk.yaml │ ├── debug │ │ ├── default.yaml │ │ ├── fdr.yaml │ │ ├── limit.yaml │ │ ├── overfit.yaml │ │ └── profiler.yaml │ ├── eval.yaml │ ├── experiment │ │ ├── hifi_dataset_piper_phonemizer.yaml │ │ ├── ljspeech.yaml │ │ ├── ljspeech_min_memory.yaml │ │ └── multispeaker.yaml │ ├── extras │ │ └── default.yaml │ ├── hparams_search │ │ └── mnist_optuna.yaml │ ├── hydra │ │ └── default.yaml │ ├── local │ │ └── .gitkeep │ ├── logger │ │ ├── aim.yaml │ │ ├── comet.yaml │ │ ├── csv.yaml │ │ ├── many_loggers.yaml │ │ ├── mlflow.yaml │ │ ├── neptune.yaml │ │ ├── tensorboard.yaml │ │ └── wandb.yaml │ ├── model │ │ ├── cfm │ │ │ └── default.yaml │ │ ├── decoder │ │ │ └── default.yaml │ │ ├── encoder │ │ │ └── default.yaml │ │ ├── matcha.yaml │ │ └── optimizer │ │ │ └── adam.yaml │ ├── paths │ │ └── default.yaml │ ├── train.yaml │ └── trainer │ │ ├── cpu.yaml │ │ ├── ddp.yaml │ │ ├── ddp_sim.yaml │ │ ├── default.yaml │ │ ├── gpu.yaml │ │ └── mps.yaml ├── data ├── matcha │ ├── VERSION │ ├── __init__.py │ ├── app.py │ ├── cli.py │ ├── data │ │ ├── __init__.py │ │ ├── components │ │ │ └── __init__.py │ │ └── text_mel_datamodule.py │ ├── hifigan │ │ ├── LICENSE │ │ ├── README.md │ │ ├── __init__.py │ │ ├── config.py │ │ ├── denoiser.py │ │ ├── env.py │ │ ├── meldataset.py │ │ ├── models.py │ │ └── xutils.py │ ├── models │ │ ├── __init__.py │ │ ├── baselightningmodule.py │ │ ├── components │ │ │ ├── __init__.py │ │ │ ├── decoder.py │ │ │ ├── flow_matching.py │ │ │ ├── text_encoder.py │ │ │ └── transformer.py │ │ └── matcha_tts.py │ ├── onnx │ │ ├── __init__.py │ │ ├── export.py │ │ └── infer.py │ ├── text │ │ ├── __init__.py │ │ ├── cleaners.py │ │ ├── numbers.py │ │ └── symbols.py │ ├── train.py │ └── utils │ │ ├── __init__.py │ │ ├── audio.py │ │ ├── 
generate_data_statistics.py │ │ ├── instantiators.py │ │ ├── logging_utils.py │ │ ├── model.py │ │ ├── monotonic_align │ │ ├── __init__.py │ │ ├── core.pyx │ │ └── setup.py │ │ ├── pylogger.py │ │ ├── rich_utils.py │ │ └── utils.py ├── notebooks │ └── .gitkeep ├── pyproject.toml ├── requirements.txt ├── scripts │ └── schedule.sh ├── setup.py └── synthesis.ipynb └── __init__.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | jobs: 12 | publish-node: 13 | name: Publish Custom Node to registry 14 | runs-on: ubuntu-latest 15 | if: ${{ github.repository_owner == 'muxueChen' }} 16 | steps: 17 | - name: Check out code 18 | uses: actions/checkout@v4 19 | with: 20 | submodules: true 21 | - name: Publish Custom Node 22 | uses: Comfy-Org/publish-node-action@main 23 | with: 24 | ## Add your own personal access token to your GitHub repository secrets and reference it here. 25 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 26 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /idea 2 | /__pycache__ 3 | **/__pycache__ 4 | **/**/__pycache__ 5 | /pretrained_models 6 | .vscode 7 | .vs 8 | .idea 9 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CosyVoice2 for ComfyUI 2 | ComfyUI_NTCosyVoice is a ComfyUI plugin for CosyVoice2. 3 | ## Install the plugin 4 | ```bash 5 | git clone https://github.com/muxueChen/ComfyUI_NTCosyVoice.git 6 | ``` 7 | ## Install dependency packages 8 | ```bash 9 | cd ComfyUI_NTCosyVoice 10 | pip install -r requirements.txt 11 | ``` 12 | ## Download models 13 | ```bash 14 | python downloadmodel.py 15 | ``` 16 | ## Install ttsfrd (Optional) 17 | Note that this step is optional. If you do not install the ttsfrd package, WeTextProcessing is used by default. 18 | ```bash 19 | cd pretrained_models/CosyVoice-ttsfrd/ 20 | unzip resource.zip -d . 
21 | pip install ttsfrd_dependency-0.1-py3-none-any.whl 22 | pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl 23 | ``` -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | nor_dir = os.path.dirname(__file__) 4 | Matcha_path = os.path.join(nor_dir, 'third_party/Matcha-TTS') 5 | sys.path.append(nor_dir) 6 | sys.path.append(Matcha_path) 7 | 8 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 9 | from cosyvoice.utils.file_utils import load_wav 10 | import torchaudio 11 | import torch 12 | 13 | 14 | def nt_load_wav(speech, sample_rate, target_sr): 15 | speech = speech.mean(dim=0, keepdim=True) 16 | if sample_rate != target_sr: 17 | assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) 18 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) 19 | return speech 20 | 21 | 22 | class NTCosyVoiceZeroShotSampler: 23 | def __init__(self): 24 | self.__cosyvoice = None 25 | 26 | @property 27 | def cosyvoice(self): 28 | if self.__cosyvoice is None: 29 | model_path = os.path.join(nor_dir, 'pretrained_models/CosyVoice2-0.5B') 30 | self.__cosyvoice = CosyVoice2(model_path, load_jit=True, load_onnx=False, load_trt=False) 31 | return self.__cosyvoice 32 | 33 | @classmethod 34 | def INPUT_TYPES(s): 35 | return { 36 | "required": { 37 | "audio": ("AUDIO",), 38 | "speed": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 1.5, "step": 0.1}), 39 | "text": ("STRING", {"multiline": True}), 40 | "prompt_text": ("STRING", {"multiline": True}), 41 | }, 42 | } 43 | 44 | RETURN_TYPES = ("AUDIO",) 45 | RETURN_NAMES = ("tts_speech",) 46 | FUNCTION = "main_func" 47 | CATEGORY = "Nineton Nodes" 48 | 49 | def main_func(self, audio, speed, text, prompt_text): 50 | waveform = audio["waveform"].squeeze(0) 51 | sample_rate = audio["sample_rate"] 52 | print(f"waveform:{waveform}, sample_rate:{sample_rate}") 53 | prompt_speech_16k = nt_load_wav(waveform, sample_rate, 16000) 54 | speechs = [] 55 | for i, j in enumerate(self.cosyvoice.inference_zero_shot(tts_text=text, prompt_text=prompt_text, prompt_speech_16k=prompt_speech_16k, stream=False, speed=speed)): 56 | speechs.append(j['tts_speech']) 57 | 58 | tts_speech = torch.cat(speechs, dim=1) 59 | tts_speech = tts_speech.unsqueeze(0) 60 | outaudio = {"waveform": tts_speech, "sample_rate": self.cosyvoice.sample_rate} 61 | 62 | return (outaudio,) 63 | 64 | 65 | class NTCosyVoiceCrossLingualSampler: 66 | def __init__(self): 67 | self.__cosyvoice = None 68 | 69 | @property 70 | def cosyvoice(self): 71 | if self.__cosyvoice is None: 72 | model_path = os.path.join(nor_dir, 'pretrained_models/CosyVoice2-0.5B') 73 | self.__cosyvoice = CosyVoice2(model_path, load_jit=True, load_onnx=False, load_trt=False) 74 | return self.__cosyvoice 75 | 76 | @classmethod 77 | def INPUT_TYPES(s): 78 | return { 79 | "required": { 80 | "audio": ("AUDIO",), 81 | "speed": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 1.5, "step": 0.1}), 82 | "text": ("STRING", {"multiline": True}), 83 | }, 84 | } 85 | 86 | RETURN_TYPES = ("AUDIO",) 87 | RETURN_NAMES = ("tts_speech",) 88 | FUNCTION = "main_func" 89 | CATEGORY = "Nineton Nodes" 90 | 91 | def main_func(self, audio, speed, text): 92 | waveform = audio["waveform"].squeeze(0) 93 | sample_rate = audio["sample_rate"] 94 | 95 | prompt_speech_16k = nt_load_wav(waveform, sample_rate, 16000) 96 | speechs = [] 97 | for 
i, j in enumerate(self.cosyvoice.inference_cross_lingual(tts_text=text, 98 | prompt_speech_16k=prompt_speech_16k, stream=False, speed=speed)): 99 | speechs.append(j['tts_speech']) 100 | 101 | tts_speech = torch.cat(speechs, dim=1) 102 | tts_speech = tts_speech.unsqueeze(0) 103 | outaudio = {"waveform": tts_speech, "sample_rate": self.cosyvoice.sample_rate} 104 | 105 | return (outaudio,) 106 | 107 | 108 | class NTCosyVoiceInstruct2Sampler: 109 | def __init__(self): 110 | self.__cosyvoice = None 111 | 112 | @property 113 | def cosyvoice(self): 114 | if self.__cosyvoice is None: 115 | model_path = os.path.join(nor_dir, 'pretrained_models/CosyVoice2-0.5B') 116 | self.__cosyvoice = CosyVoice2(model_path, load_jit=True, load_onnx=False, load_trt=False) 117 | return self.__cosyvoice 118 | 119 | @classmethod 120 | def INPUT_TYPES(s): 121 | return { 122 | "required": { 123 | "audio": ("AUDIO",), 124 | "speed": ("FLOAT", {"default": 1.0, "min": 0.5, "max": 1.5, "step": 0.1}), 125 | "text": ("STRING", {"multiline": True}), 126 | "instruct": ("STRING", {"multiline": True}), 127 | }, 128 | } 129 | 130 | RETURN_TYPES = ("AUDIO",) 131 | RETURN_NAMES = ("tts_speech",) 132 | FUNCTION = "main_func" 133 | CATEGORY = "Nineton Nodes" 134 | 135 | def main_func(self, audio, speed, text, instruct): 136 | waveform = audio["waveform"].squeeze(0) 137 | sample_rate = audio["sample_rate"] 138 | 139 | prompt_speech_16k = nt_load_wav(waveform, sample_rate, 16000) 140 | 141 | speechs = [] 142 | for i, j in enumerate(self.cosyvoice.inference_instruct2(tts_text=text, instruct_text=instruct, prompt_speech_16k=prompt_speech_16k, stream=False, speed=speed)): 143 | speechs.append(j['tts_speech']) 144 | 145 | tts_speech = torch.cat(speechs, dim=1) 146 | tts_speech = tts_speech.unsqueeze(0) 147 | outaudio = {"waveform": tts_speech, "sample_rate": self.cosyvoice.sample_rate} 148 | 149 | return (outaudio,) 150 | 151 | 152 | NODE_CLASS_MAPPINGS = { 153 | "NTCosyVoiceZeroShotSampler": NTCosyVoiceZeroShotSampler, 154 | "NTCosyVoiceInstruct2Sampler": NTCosyVoiceInstruct2Sampler, 155 | "NTCosyVoiceCrossLingualSampler": NTCosyVoiceCrossLingualSampler 156 | } 157 | 158 | __all__ = ['NODE_CLASS_MAPPINGS'] -------------------------------------------------------------------------------- /cosyvoice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/__init__.py -------------------------------------------------------------------------------- /cosyvoice/bin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/bin/__init__.py -------------------------------------------------------------------------------- /cosyvoice/bin/average_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Di Wu) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import argparse 18 | import glob 19 | 20 | import yaml 21 | import torch 22 | 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser(description='average model') 26 | parser.add_argument('--dst_model', required=True, help='averaged model') 27 | parser.add_argument('--src_path', 28 | required=True, 29 | help='src model path for average') 30 | parser.add_argument('--val_best', 31 | action="store_true", 32 | help='averaged model') 33 | parser.add_argument('--num', 34 | default=5, 35 | type=int, 36 | help='nums for averaged model') 37 | 38 | args = parser.parse_args() 39 | print(args) 40 | return args 41 | 42 | 43 | def main(): 44 | args = get_args() 45 | val_scores = [] 46 | if args.val_best: 47 | yamls = glob.glob('{}/*.yaml'.format(args.src_path)) 48 | yamls = [ 49 | f for f in yamls 50 | if not (os.path.basename(f).startswith('train') 51 | or os.path.basename(f).startswith('init')) 52 | ] 53 | for y in yamls: 54 | with open(y, 'r') as f: 55 | dic_yaml = yaml.load(f, Loader=yaml.BaseLoader) 56 | loss = float(dic_yaml['loss_dict']['loss']) 57 | epoch = int(dic_yaml['epoch']) 58 | step = int(dic_yaml['step']) 59 | tag = dic_yaml['tag'] 60 | val_scores += [[epoch, step, loss, tag]] 61 | sorted_val_scores = sorted(val_scores, 62 | key=lambda x: x[2], 63 | reverse=False) 64 | print("best val (epoch, step, loss, tag) = " + 65 | str(sorted_val_scores[:args.num])) 66 | path_list = [ 67 | args.src_path + '/epoch_{}_whole.pt'.format(score[0]) 68 | for score in sorted_val_scores[:args.num] 69 | ] 70 | print(path_list) 71 | avg = {} 72 | num = args.num 73 | assert num == len(path_list) 74 | for path in path_list: 75 | print('Processing {}'.format(path)) 76 | states = torch.load(path, map_location=torch.device('cpu')) 77 | for k in states.keys(): 78 | if k not in avg.keys(): 79 | avg[k] = states[k].clone() 80 | else: 81 | avg[k] += states[k] 82 | # average 83 | for k in avg.keys(): 84 | if avg[k] is not None: 85 | # pytorch 1.6 use true_divide instead of /= 86 | avg[k] = torch.true_divide(avg[k], num) 87 | print('Saving to {}'.format(args.dst_model)) 88 | torch.save(avg, args.dst_model) 89 | 90 | 91 | if __name__ == '__main__': 92 | main() 93 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
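#
# Example invocation (illustrative; the model directory is an assumption: point it
# at a locally downloaded CosyVoice model folder):
#
#   python cosyvoice/bin/export_jit.py --model_dir pretrained_models/CosyVoice-300M
#
# main() below saves llm.text_encoder.fp16.zip, llm.llm.fp16.zip and
# flow.encoder.fp32.zip into --model_dir so they can be loaded with load_jit=True.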
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import sys 22 | import torch 23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.append('{}/../..'.format(ROOT_DIR)) 25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 26 | from cosyvoice.cli.cosyvoice import CosyVoice 27 | 28 | 29 | def get_args(): 30 | parser = argparse.ArgumentParser(description='export your model for deployment') 31 | parser.add_argument('--model_dir', 32 | type=str, 33 | default='pretrained_models/CosyVoice-300M', 34 | help='local path') 35 | args = parser.parse_args() 36 | print(args) 37 | return args 38 | 39 | 40 | def main(): 41 | args = get_args() 42 | logging.basicConfig(level=logging.DEBUG, 43 | format='%(asctime)s %(levelname)s %(message)s') 44 | 45 | torch._C._jit_set_fusion_strategy([('STATIC', 1)]) 46 | torch._C._jit_set_profiling_mode(False) 47 | torch._C._jit_set_profiling_executor(False) 48 | 49 | cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False) 50 | 51 | # 1. export llm text_encoder 52 | llm_text_encoder = cosyvoice.model.llm.text_encoder.half() 53 | script = torch.jit.script(llm_text_encoder) 54 | script = torch.jit.freeze(script) 55 | script = torch.jit.optimize_for_inference(script) 56 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) 57 | 58 | # 2. export llm llm 59 | llm_llm = cosyvoice.model.llm.llm.half() 60 | script = torch.jit.script(llm_llm) 61 | script = torch.jit.freeze(script, preserved_attrs=['forward_chunk']) 62 | script = torch.jit.optimize_for_inference(script) 63 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) 64 | 65 | # 3. export flow encoder 66 | flow_encoder = cosyvoice.model.flow.encoder 67 | script = torch.jit.script(flow_encoder) 68 | script = torch.jit.freeze(script) 69 | script = torch.jit.optimize_for_inference(script) 70 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
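#
# Example invocation (illustrative; the model directory is an assumption):
#
#   python cosyvoice/bin/export_onnx.py --model_dir pretrained_models/CosyVoice-300M
#
# main() below exports flow.decoder.estimator.fp32.onnx into --model_dir and then
# compares ONNX Runtime outputs against the PyTorch estimator on random dummy inputs.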
15 | 16 | from __future__ import print_function 17 | 18 | import argparse 19 | import logging 20 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 21 | import os 22 | import sys 23 | import onnxruntime 24 | import random 25 | import torch 26 | from tqdm import tqdm 27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | sys.path.append('{}/../..'.format(ROOT_DIR)) 29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 30 | from cosyvoice.cli.cosyvoice import CosyVoice 31 | 32 | 33 | def get_dummy_input(batch_size, seq_len, out_channels, device): 34 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 35 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) 36 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 37 | t = torch.rand((batch_size), dtype=torch.float32, device=device) 38 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) 39 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 40 | return x, mask, mu, t, spks, cond 41 | 42 | 43 | def get_args(): 44 | parser = argparse.ArgumentParser(description='export your model for deployment') 45 | parser.add_argument('--model_dir', 46 | type=str, 47 | default='pretrained_models/CosyVoice-300M', 48 | help='local path') 49 | args = parser.parse_args() 50 | print(args) 51 | return args 52 | 53 | 54 | def main(): 55 | args = get_args() 56 | logging.basicConfig(level=logging.DEBUG, 57 | format='%(asctime)s %(levelname)s %(message)s') 58 | 59 | cosyvoice = CosyVoice(args.model_dir, load_jit=False, load_onnx=False) 60 | 61 | # 1. export flow decoder estimator 62 | estimator = cosyvoice.model.flow.decoder.estimator 63 | 64 | device = cosyvoice.model.device 65 | batch_size, seq_len = 1, 256 66 | out_channels = cosyvoice.model.flow.decoder.estimator.out_channels 67 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) 68 | torch.onnx.export( 69 | estimator, 70 | (x, mask, mu, t, spks, cond), 71 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 72 | export_params=True, 73 | opset_version=18, 74 | do_constant_folding=True, 75 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], 76 | output_names=['estimator_out'], 77 | dynamic_axes={ 78 | 'x': {0: 'batch_size', 2: 'seq_len'}, 79 | 'mask': {0: 'batch_size', 2: 'seq_len'}, 80 | 'mu': {0: 'batch_size', 2: 'seq_len'}, 81 | 'cond': {0: 'batch_size', 2: 'seq_len'}, 82 | 't': {0: 'batch_size'}, 83 | 'spks': {0: 'batch_size'}, 84 | 'estimator_out': {0: 'batch_size', 2: 'seq_len'}, 85 | } 86 | ) 87 | 88 | # 2. 
test computation consistency 89 | option = onnxruntime.SessionOptions() 90 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 91 | option.intra_op_num_threads = 1 92 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] 93 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 94 | sess_options=option, providers=providers) 95 | 96 | for _ in tqdm(range(10)): 97 | x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device) 98 | output_pytorch = estimator(x, mask, mu, t, spks, cond) 99 | ort_inputs = { 100 | 'x': x.cpu().numpy(), 101 | 'mask': mask.cpu().numpy(), 102 | 'mu': mu.cpu().numpy(), 103 | 't': t.cpu().numpy(), 104 | 'spks': spks.cpu().numpy(), 105 | 'cond': cond.cpu().numpy() 106 | } 107 | output_onnx = estimator_onnx.run(None, ort_inputs)[0] 108 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_trt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | # download tensorrt from https://developer.nvidia.com/tensorrt/download/10x, check your system and cuda for compatibability 4 | # for example for linux + cuda12.4, you can download https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.0.1/tars/TensorRT-10.0.1.6.Linux.x86_64-gnu.cuda-12.4.tar.gz 5 | TRT_DIR= 6 | MODEL_DIR= 7 | 8 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TRT_DIR/lib:/usr/local/cuda/lib64 9 | $TRT_DIR/bin/trtexec --onnx=$MODEL_DIR/flow.decoder.estimator.fp32.onnx --saveEngine=$MODEL_DIR/flow.decoder.estimator.fp16.mygpu.plan --fp16 --minShapes=x:2x80x4,mask:2x1x4,mu:2x80x4,cond:2x80x4 --optShapes=x:2x80x193,mask:2x1x193,mu:2x80x193,cond:2x80x193 --maxShapes=x:2x80x6800,mask:2x1x6800,mu:2x80x6800,cond:2x80x6800 --inputIOFormats=fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw,fp16:chw --outputIOFormats=fp16:chw 10 | -------------------------------------------------------------------------------- /cosyvoice/bin/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
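#
# Example invocation (illustrative; every path below is a placeholder, not a shipped asset):
#
#   python cosyvoice/bin/inference.py --mode zero_shot \
#       --config conf/cosyvoice.yaml --prompt_data data/prompt.data \
#       --prompt_utt2data data/utt2data.json --tts_text data/tts_text.json \
#       --llm_model exp/llm.pt --flow_model exp/flow.pt --hifigan_model exp/hift.pt \
#       --gpu 0 --result_dir exp/results
#
# main() below writes one <utt>_<index>.wav per request plus a wav.scp index into --result_dir.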
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import torch 22 | from torch.utils.data import DataLoader 23 | import torchaudio 24 | from hyperpyyaml import load_hyperpyyaml 25 | from tqdm import tqdm 26 | from cosyvoice.cli.model import CosyVoiceModel 27 | from cosyvoice.dataset.dataset import Dataset 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='inference with your model') 32 | parser.add_argument('--config', required=True, help='config file') 33 | parser.add_argument('--prompt_data', required=True, help='prompt data file') 34 | parser.add_argument('--prompt_utt2data', required=True, help='prompt data file') 35 | parser.add_argument('--tts_text', required=True, help='tts input file') 36 | parser.add_argument('--llm_model', required=True, help='llm model file') 37 | parser.add_argument('--flow_model', required=True, help='flow model file') 38 | parser.add_argument('--hifigan_model', required=True, help='hifigan model file') 39 | parser.add_argument('--gpu', 40 | type=int, 41 | default=-1, 42 | help='gpu id for this rank, -1 for cpu') 43 | parser.add_argument('--mode', 44 | default='sft', 45 | choices=['sft', 'zero_shot'], 46 | help='inference mode') 47 | parser.add_argument('--result_dir', required=True, help='asr result file') 48 | args = parser.parse_args() 49 | print(args) 50 | return args 51 | 52 | 53 | def main(): 54 | args = get_args() 55 | logging.basicConfig(level=logging.DEBUG, 56 | format='%(asctime)s %(levelname)s %(message)s') 57 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 58 | 59 | # Init cosyvoice models from configs 60 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 61 | device = torch.device('cuda' if use_cuda else 'cpu') 62 | with open(args.config, 'r') as f: 63 | configs = load_hyperpyyaml(f) 64 | 65 | model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) 66 | model.load(args.llm_model, args.flow_model, args.hifigan_model) 67 | 68 | test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, 69 | tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data) 70 | test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) 71 | 72 | del configs 73 | os.makedirs(args.result_dir, exist_ok=True) 74 | fn = os.path.join(args.result_dir, 'wav.scp') 75 | f = open(fn, 'w') 76 | with torch.no_grad(): 77 | for _, batch in tqdm(enumerate(test_data_loader)): 78 | utts = batch["utts"] 79 | assert len(utts) == 1, "inference mode only support batchsize 1" 80 | text_token = batch["text_token"].to(device) 81 | text_token_len = batch["text_token_len"].to(device) 82 | tts_index = batch["tts_index"] 83 | tts_text_token = batch["tts_text_token"].to(device) 84 | tts_text_token_len = batch["tts_text_token_len"].to(device) 85 | speech_token = batch["speech_token"].to(device) 86 | speech_token_len = batch["speech_token_len"].to(device) 87 | speech_feat = batch["speech_feat"].to(device) 88 | speech_feat_len = batch["speech_feat_len"].to(device) 89 | utt_embedding = batch["utt_embedding"].to(device) 90 | spk_embedding = batch["spk_embedding"].to(device) 91 | if args.mode == 'sft': 92 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 93 | 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding} 94 | else: 95 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 96 | 
'prompt_text': text_token, 'prompt_text_len': text_token_len, 97 | 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len, 98 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len, 99 | 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len, 100 | 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding} 101 | tts_speeches = [] 102 | for model_output in model.tts(**model_input): 103 | tts_speeches.append(model_output['tts_speech']) 104 | tts_speeches = torch.concat(tts_speeches, dim=1) 105 | tts_key = '{}_{}'.format(utts[0], tts_index[0]) 106 | tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key)) 107 | torchaudio.save(tts_fn, tts_speeches, sample_rate=22050) 108 | f.write('{} {}\n'.format(tts_key, tts_fn)) 109 | f.flush() 110 | f.close() 111 | logging.info('Result wav.scp saved in {}'.format(fn)) 112 | 113 | 114 | if __name__ == '__main__': 115 | main() 116 | -------------------------------------------------------------------------------- /cosyvoice/bin/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
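#
# Example invocation (illustrative; paths are placeholders, and launching with torchrun
# is an assumption based on the torch.distributed / DDP setup used below):
#
#   torchrun --nnodes=1 --nproc_per_node=1 cosyvoice/bin/train.py \
#       --train_engine torch_ddp --model llm --config conf/cosyvoice.yaml \
#       --train_data data/train.data.list --cv_data data/dev.data.list \
#       --model_dir exp/llm --tensorboard_dir exp/llm/tensorboard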
14 | 15 | from __future__ import print_function 16 | import argparse 17 | import datetime 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | from copy import deepcopy 21 | import os 22 | import torch 23 | import torch.distributed as dist 24 | import deepspeed 25 | 26 | from hyperpyyaml import load_hyperpyyaml 27 | 28 | from torch.distributed.elastic.multiprocessing.errors import record 29 | 30 | from cosyvoice.utils.executor import Executor 31 | from cosyvoice.utils.train_utils import ( 32 | init_distributed, 33 | init_dataset_and_dataloader, 34 | init_optimizer_and_scheduler, 35 | init_summarywriter, save_model, 36 | wrap_cuda_model, check_modify_and_save_config) 37 | 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser(description='training your network') 41 | parser.add_argument('--train_engine', 42 | default='torch_ddp', 43 | choices=['torch_ddp', 'deepspeed'], 44 | help='Engine for paralleled training') 45 | parser.add_argument('--model', required=True, help='model which will be trained') 46 | parser.add_argument('--config', required=True, help='config file') 47 | parser.add_argument('--train_data', required=True, help='train data file') 48 | parser.add_argument('--cv_data', required=True, help='cv data file') 49 | parser.add_argument('--checkpoint', help='checkpoint model') 50 | parser.add_argument('--model_dir', required=True, help='save model dir') 51 | parser.add_argument('--tensorboard_dir', 52 | default='tensorboard', 53 | help='tensorboard log dir') 54 | parser.add_argument('--ddp.dist_backend', 55 | dest='dist_backend', 56 | default='nccl', 57 | choices=['nccl', 'gloo'], 58 | help='distributed backend') 59 | parser.add_argument('--num_workers', 60 | default=0, 61 | type=int, 62 | help='num of subprocess workers for reading') 63 | parser.add_argument('--prefetch', 64 | default=100, 65 | type=int, 66 | help='prefetch number') 67 | parser.add_argument('--pin_memory', 68 | action='store_true', 69 | default=False, 70 | help='Use pinned memory buffers used for reading') 71 | parser.add_argument('--use_amp', 72 | action='store_true', 73 | default=False, 74 | help='Use automatic mixed precision training') 75 | parser.add_argument('--deepspeed.save_states', 76 | dest='save_states', 77 | default='model_only', 78 | choices=['model_only', 'model+optimizer'], 79 | help='save model/optimizer states') 80 | parser.add_argument('--timeout', 81 | default=60, 82 | type=int, 83 | help='timeout (in seconds) of cosyvoice_join.') 84 | parser = deepspeed.add_config_arguments(parser) 85 | args = parser.parse_args() 86 | return args 87 | 88 | 89 | @record 90 | def main(): 91 | args = get_args() 92 | logging.basicConfig(level=logging.DEBUG, 93 | format='%(asctime)s %(levelname)s %(message)s') 94 | # gan train has some special initialization logic 95 | gan = True if args.model == 'hifigan' else False 96 | 97 | override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} 98 | if gan is True: 99 | override_dict.pop('hift') 100 | with open(args.config, 'r') as f: 101 | configs = load_hyperpyyaml(f, overrides=override_dict) 102 | if gan is True: 103 | configs['train_conf'] = configs['train_conf_gan'] 104 | configs['train_conf'].update(vars(args)) 105 | 106 | # Init env for ddp 107 | init_distributed(args) 108 | 109 | # Get dataset & dataloader 110 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ 111 | init_dataset_and_dataloader(args, configs, gan) 112 | 113 | # Do some sanity checks and save config to 
arsg.model_dir 114 | configs = check_modify_and_save_config(args, configs) 115 | 116 | # Tensorboard summary 117 | writer = init_summarywriter(args) 118 | 119 | # load checkpoint 120 | model = configs[args.model] 121 | start_step, start_epoch = 0, -1 122 | if args.checkpoint is not None: 123 | if os.path.exists(args.checkpoint): 124 | state_dict = torch.load(args.checkpoint, map_location='cpu') 125 | model.load_state_dict(state_dict, strict=False) 126 | if 'step' in state_dict: 127 | start_step = state_dict['step'] 128 | if 'epoch' in state_dict: 129 | start_epoch = state_dict['epoch'] 130 | else: 131 | logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint)) 132 | 133 | # Dispatch model from cpu to gpu 134 | model = wrap_cuda_model(args, model) 135 | 136 | # Get optimizer & scheduler 137 | model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan) 138 | scheduler.set_step(start_step) 139 | if scheduler_d is not None: 140 | scheduler_d.set_step(start_step) 141 | 142 | # Save init checkpoints 143 | info_dict = deepcopy(configs['train_conf']) 144 | info_dict['step'] = start_step 145 | info_dict['epoch'] = start_epoch 146 | save_model(model, 'init', info_dict) 147 | 148 | # Get executor 149 | executor = Executor(gan=gan) 150 | executor.step = start_step 151 | 152 | # Init scaler, used for pytorch amp mixed precision training 153 | scaler = torch.cuda.amp.GradScaler() if args.use_amp else None 154 | print('start step {} start epoch {}'.format(start_step, start_epoch)) 155 | # Start training loop 156 | for epoch in range(start_epoch + 1, info_dict['max_epoch']): 157 | executor.epoch = epoch 158 | train_dataset.set_epoch(epoch) 159 | dist.barrier() 160 | group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) 161 | if gan is True: 162 | executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, 163 | writer, info_dict, scaler, group_join) 164 | else: 165 | executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join) 166 | dist.destroy_process_group(group_join) 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /cosyvoice/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/cli/__init__.py -------------------------------------------------------------------------------- /cosyvoice/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/dataset/__init__.py -------------------------------------------------------------------------------- /cosyvoice/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import json 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch.utils.data import IterableDataset 24 | from cosyvoice.utils.file_utils import read_lists, read_json_lists 25 | 26 | 27 | class Processor(IterableDataset): 28 | 29 | def __init__(self, source, f, *args, **kw): 30 | assert callable(f) 31 | self.source = source 32 | self.f = f 33 | self.args = args 34 | self.kw = kw 35 | 36 | def set_epoch(self, epoch): 37 | self.source.set_epoch(epoch) 38 | 39 | def __iter__(self): 40 | """ Return an iterator over the source dataset processed by the 41 | given processor. 42 | """ 43 | assert self.source is not None 44 | assert callable(self.f) 45 | return self.f(iter(self.source), *self.args, **self.kw) 46 | 47 | def apply(self, f): 48 | assert callable(f) 49 | return Processor(self, f, *self.args, **self.kw) 50 | 51 | 52 | class DistributedSampler: 53 | 54 | def __init__(self, shuffle=True, partition=True): 55 | self.epoch = -1 56 | self.update() 57 | self.shuffle = shuffle 58 | self.partition = partition 59 | 60 | def update(self): 61 | assert dist.is_available() 62 | if dist.is_initialized(): 63 | self.rank = dist.get_rank() 64 | self.world_size = dist.get_world_size() 65 | else: 66 | self.rank = 0 67 | self.world_size = 1 68 | worker_info = torch.utils.data.get_worker_info() 69 | if worker_info is None: 70 | self.worker_id = 0 71 | self.num_workers = 1 72 | else: 73 | self.worker_id = worker_info.id 74 | self.num_workers = worker_info.num_workers 75 | return dict(rank=self.rank, 76 | world_size=self.world_size, 77 | worker_id=self.worker_id, 78 | num_workers=self.num_workers) 79 | 80 | def set_epoch(self, epoch): 81 | self.epoch = epoch 82 | 83 | def sample(self, data): 84 | """ Sample data according to rank/world_size/num_workers 85 | 86 | Args: 87 | data(List): input data list 88 | 89 | Returns: 90 | List: data list after sample 91 | """ 92 | data = list(range(len(data))) 93 | # force datalist even 94 | if self.partition: 95 | if self.shuffle: 96 | random.Random(self.epoch).shuffle(data) 97 | if len(data) < self.world_size: 98 | data = data * math.ceil(self.world_size / len(data)) 99 | data = data[:self.world_size] 100 | data = data[self.rank::self.world_size] 101 | if len(data) < self.num_workers: 102 | data = data * math.ceil(self.num_workers / len(data)) 103 | data = data[:self.num_workers] 104 | data = data[self.worker_id::self.num_workers] 105 | return data 106 | 107 | 108 | class DataList(IterableDataset): 109 | 110 | def __init__(self, lists, shuffle=True, partition=True): 111 | self.lists = lists 112 | self.sampler = DistributedSampler(shuffle, partition) 113 | 114 | def set_epoch(self, epoch): 115 | self.sampler.set_epoch(epoch) 116 | 117 | def __iter__(self): 118 | sampler_info = self.sampler.update() 119 | indexes = self.sampler.sample(self.lists) 120 | for index in indexes: 121 | data = dict(src=self.lists[index]) 122 | data.update(sampler_info) 123 | yield data 124 | 125 | 126 | def Dataset(data_list_file, 127 | data_pipeline, 128 | 
mode='train', 129 | gan=False, 130 | shuffle=True, 131 | partition=True, 132 | tts_file='', 133 | prompt_utt2data=''): 134 | """ Construct dataset from arguments 135 | 136 | We have two shuffle stage in the Dataset. The first is global 137 | shuffle at shards tar/raw file level. The second is global shuffle 138 | at training samples level. 139 | 140 | Args: 141 | data_type(str): raw/shard 142 | tokenizer (BaseTokenizer): tokenizer to tokenize 143 | partition(bool): whether to do data partition in terms of rank 144 | """ 145 | assert mode in ['train', 'inference'] 146 | lists = read_lists(data_list_file) 147 | if mode == 'inference': 148 | with open(tts_file) as f: 149 | tts_data = json.load(f) 150 | utt2lists = read_json_lists(prompt_utt2data) 151 | # filter unnecessary file in inference mode 152 | lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists}) 153 | dataset = DataList(lists, 154 | shuffle=shuffle, 155 | partition=partition) 156 | if mode == 'inference': 157 | # map partial arg to parquet_opener func in inference mode 158 | data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data) 159 | if gan is True: 160 | # map partial arg to padding func in gan mode 161 | data_pipeline[-1] = partial(data_pipeline[-1], gan=gan) 162 | for func in data_pipeline: 163 | dataset = Processor(dataset, func, mode=mode) 164 | return dataset 165 | -------------------------------------------------------------------------------- /cosyvoice/flow/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/flow/__init__.py -------------------------------------------------------------------------------- /cosyvoice/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
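#
# Shape sketch for InterpolateRegulator (illustrative; the channel count and lengths
# are assumed values, not taken from a shipped config):
#
#   regulator = InterpolateRegulator(channels=512, sampling_ratios=[1, 1, 1, 1])
#   x = torch.randn(2, 50, 512)         # (B, T, D) token-level features
#   ylens = torch.tensor([120, 100])    # target mel-frame lengths per batch item
#   out, olens = regulator(x, ylens)    # out: (2, 120, 512); positions past ylens are zeroed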
14 | from typing import Tuple 15 | import torch.nn as nn 16 | import torch 17 | from torch.nn import functional as F 18 | from cosyvoice.utils.mask import make_pad_mask 19 | 20 | 21 | class InterpolateRegulator(nn.Module): 22 | def __init__( 23 | self, 24 | channels: int, 25 | sampling_ratios: Tuple, 26 | out_channels: int = None, 27 | groups: int = 1, 28 | ): 29 | super().__init__() 30 | self.sampling_ratios = sampling_ratios 31 | out_channels = out_channels or channels 32 | model = nn.ModuleList([]) 33 | if len(sampling_ratios) > 0: 34 | for _ in sampling_ratios: 35 | module = nn.Conv1d(channels, channels, 3, 1, 1) 36 | norm = nn.GroupNorm(groups, channels) 37 | act = nn.Mish() 38 | model.extend([module, norm, act]) 39 | model.append( 40 | nn.Conv1d(channels, out_channels, 1, 1) 41 | ) 42 | self.model = nn.Sequential(*model) 43 | 44 | def forward(self, x, ylens=None): 45 | # x in (B, T, D) 46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') 48 | out = self.model(x).transpose(1, 2).contiguous() 49 | olens = ylens 50 | return out * mask, olens 51 | 52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): 53 | # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel 54 | # x in (B, T, D) 55 | if x2.shape[1] > 40: 56 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 57 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, 58 | mode='linear') 59 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 60 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) 61 | else: 62 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') 63 | if x1.shape[1] != 0: 64 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') 65 | x = torch.concat([x1, x2], dim=2) 66 | else: 67 | x = x2 68 | out = self.model(x).transpose(1, 2).contiguous() 69 | return out, mel_len1 + mel_len2 70 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/hifigan/__init__.py -------------------------------------------------------------------------------- /cosyvoice/hifigan/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils import weight_norm 4 | from typing import List, Optional, Tuple 5 | from einops import rearrange 6 | from torchaudio.transforms import Spectrogram 7 | 8 | 9 | class MultipleDiscriminator(nn.Module): 10 | def __init__( 11 | self, mpd: nn.Module, mrd: nn.Module 12 | ): 13 | super().__init__() 14 | self.mpd = mpd 15 | self.mrd = mrd 16 | 17 | def forward(self, y: torch.Tensor, y_hat: torch.Tensor): 18 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], [] 19 | this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1)) 20 | y_d_rs += this_y_d_rs 21 | y_d_gs += this_y_d_gs 22 | fmap_rs += this_fmap_rs 23 | fmap_gs += this_fmap_gs 24 | this_y_d_rs, 
this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat) 25 | y_d_rs += this_y_d_rs 26 | y_d_gs += this_y_d_gs 27 | fmap_rs += this_fmap_rs 28 | fmap_gs += this_fmap_gs 29 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 30 | 31 | 32 | class MultiResolutionDiscriminator(nn.Module): 33 | def __init__( 34 | self, 35 | fft_sizes: Tuple[int, ...] = (2048, 1024, 512), 36 | num_embeddings: Optional[int] = None, 37 | ): 38 | """ 39 | Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec. 40 | Additionally, it allows incorporating conditional information with a learned embeddings table. 41 | 42 | Args: 43 | fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512). 44 | num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator. 45 | Defaults to None. 46 | """ 47 | 48 | super().__init__() 49 | self.discriminators = nn.ModuleList( 50 | [DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes] 51 | ) 52 | 53 | def forward( 54 | self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None 55 | ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]: 56 | y_d_rs = [] 57 | y_d_gs = [] 58 | fmap_rs = [] 59 | fmap_gs = [] 60 | 61 | for d in self.discriminators: 62 | y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id) 63 | y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id) 64 | y_d_rs.append(y_d_r) 65 | fmap_rs.append(fmap_r) 66 | y_d_gs.append(y_d_g) 67 | fmap_gs.append(fmap_g) 68 | 69 | return y_d_rs, y_d_gs, fmap_rs, fmap_gs 70 | 71 | 72 | class DiscriminatorR(nn.Module): 73 | def __init__( 74 | self, 75 | window_length: int, 76 | num_embeddings: Optional[int] = None, 77 | channels: int = 32, 78 | hop_factor: float = 0.25, 79 | bands: Tuple[Tuple[float, float], ...] 
= ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)), 80 | ): 81 | super().__init__() 82 | self.window_length = window_length 83 | self.hop_factor = hop_factor 84 | self.spec_fn = Spectrogram( 85 | n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None 86 | ) 87 | n_fft = window_length // 2 + 1 88 | bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] 89 | self.bands = bands 90 | convs = lambda: nn.ModuleList( 91 | [ 92 | weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))), 93 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 94 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 95 | weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))), 96 | weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))), 97 | ] 98 | ) 99 | self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) 100 | 101 | if num_embeddings is not None: 102 | self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels) 103 | torch.nn.init.zeros_(self.emb.weight) 104 | 105 | self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))) 106 | 107 | def spectrogram(self, x): 108 | # Remove DC offset 109 | x = x - x.mean(dim=-1, keepdims=True) 110 | # Peak normalize the volume of input audio 111 | x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9) 112 | x = self.spec_fn(x) 113 | x = torch.view_as_real(x) 114 | x = rearrange(x, "b f t c -> b c t f") 115 | # Split into bands 116 | x_bands = [x[..., b[0]: b[1]] for b in self.bands] 117 | return x_bands 118 | 119 | def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None): 120 | x_bands = self.spectrogram(x) 121 | fmap = [] 122 | x = [] 123 | for band, stack in zip(x_bands, self.band_convs): 124 | for i, layer in enumerate(stack): 125 | band = layer(band) 126 | band = torch.nn.functional.leaky_relu(band, 0.1) 127 | if i > 0: 128 | fmap.append(band) 129 | x.append(band) 130 | x = torch.cat(x, dim=-1) 131 | if cond_embedding_id is not None: 132 | emb = self.emb(cond_embedding_id) 133 | h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True) 134 | else: 135 | h = 0 136 | x = self.conv_post(x) 137 | fmap.append(x) 138 | x += h 139 | 140 | return x, fmap 141 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
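#
# Shape sketch for ConvRNNF0Predictor (illustrative; batch and frame counts are assumed):
#
#   predictor = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
#   mel = torch.randn(2, 80, 200)   # (B, n_mels, T)
#   f0 = predictor(mel)             # (B, T), non-negative because of the final torch.abs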
14 | import torch 15 | import torch.nn as nn 16 | from torch.nn.utils import weight_norm 17 | 18 | 19 | class ConvRNNF0Predictor(nn.Module): 20 | def __init__(self, 21 | num_class: int = 1, 22 | in_channels: int = 80, 23 | cond_channels: int = 512 24 | ): 25 | super().__init__() 26 | 27 | self.num_class = num_class 28 | self.condnet = nn.Sequential( 29 | weight_norm( 30 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 31 | ), 32 | nn.ELU(), 33 | weight_norm( 34 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 35 | ), 36 | nn.ELU(), 37 | weight_norm( 38 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 39 | ), 40 | nn.ELU(), 41 | weight_norm( 42 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 43 | ), 44 | nn.ELU(), 45 | weight_norm( 46 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 47 | ), 48 | nn.ELU(), 49 | ) 50 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.condnet(x) 54 | x = x.transpose(1, 2) 55 | return torch.abs(self.classifier(x).squeeze(-1)) 56 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/hifigan.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss 6 | from cosyvoice.utils.losses import tpr_loss, mel_loss 7 | 8 | 9 | class HiFiGan(nn.Module): 10 | def __init__(self, generator, discriminator, mel_spec_transform, 11 | multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, 12 | tpr_loss_weight=1.0, tpr_loss_tau=0.04): 13 | super(HiFiGan, self).__init__() 14 | self.generator = generator 15 | self.discriminator = discriminator 16 | self.mel_spec_transform = mel_spec_transform 17 | self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight 18 | self.feat_match_loss_weight = feat_match_loss_weight 19 | self.tpr_loss_weight = tpr_loss_weight 20 | self.tpr_loss_tau = tpr_loss_tau 21 | 22 | def forward( 23 | self, 24 | batch: dict, 25 | device: torch.device, 26 | ) -> Dict[str, Optional[torch.Tensor]]: 27 | if batch['turn'] == 'generator': 28 | return self.forward_generator(batch, device) 29 | else: 30 | return self.forward_discriminator(batch, device) 31 | 32 | def forward_generator(self, batch, device): 33 | real_speech = batch['speech'].to(device) 34 | pitch_feat = batch['pitch_feat'].to(device) 35 | # 1. calculate generator outputs 36 | generated_speech, generated_f0 = self.generator(batch, device) 37 | # 2. calculate discriminator outputs 38 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) 39 | # 3. 
calculate generator losses, feature loss, mel loss, tpr losses [Optional] 40 | loss_gen, _ = generator_loss(y_d_gs) 41 | loss_fm = feature_loss(fmap_rs, fmap_gs) 42 | loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) 43 | if self.tpr_loss_weight != 0: 44 | loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) 45 | else: 46 | loss_tpr = torch.zeros(1).to(device) 47 | loss_f0 = F.l1_loss(generated_f0, pitch_feat) 48 | loss = loss_gen + self.feat_match_loss_weight * loss_fm + \ 49 | self.multi_mel_spectral_recon_loss_weight * loss_mel + \ 50 | self.tpr_loss_weight * loss_tpr + loss_f0 51 | return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} 52 | 53 | def forward_discriminator(self, batch, device): 54 | real_speech = batch['speech'].to(device) 55 | # 1. calculate generator outputs 56 | with torch.no_grad(): 57 | generated_speech, generated_f0 = self.generator(batch, device) 58 | # 2. calculate discriminator outputs 59 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) 60 | # 3. calculate discriminator losses, tpr losses [Optional] 61 | loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) 62 | if self.tpr_loss_weight != 0: 63 | loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) 64 | else: 65 | loss_tpr = torch.zeros(1).to(device) 66 | loss = loss_disc + self.tpr_loss_weight * loss_tpr 67 | return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr} 68 | -------------------------------------------------------------------------------- /cosyvoice/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/llm/__init__.py -------------------------------------------------------------------------------- /cosyvoice/tokenizer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/tokenizer/__init__.py -------------------------------------------------------------------------------- /cosyvoice/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/transformer/__init__.py -------------------------------------------------------------------------------- /cosyvoice/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
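#
# Quick reference (illustrative; the sizes below are assumed):
#
#   Swish(x) = x * sigmoid(x)
#   Snake(x) = x + (1 / alpha) * sin(alpha * x) ** 2   # alpha is learned per channel
#
#   snake = Snake(in_features=256)        # expects input shaped (B, C, T) with C == 256
#   y = snake(torch.randn(4, 256, 100))   # output has the same shape as the input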
17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct an Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | ''' 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | ''' 50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 51 | ''' 52 | Initialization. 53 | INPUT: 54 | - in_features: shape of the input 55 | - alpha: trainable parameter 56 | alpha is initialized to 1 by default, higher values = higher-frequency. 57 | alpha will be trained along with the rest of your model. 58 | ''' 59 | super(Snake, self).__init__() 60 | self.in_features = in_features 61 | 62 | # initialize alpha 63 | self.alpha_logscale = alpha_logscale 64 | if self.alpha_logscale: # log scale alphas initialized to zeros 65 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 66 | else: # linear scale alphas initialized to ones 67 | self.alpha = Parameter(torch.ones(in_features) * alpha) 68 | 69 | self.alpha.requires_grad = alpha_trainable 70 | 71 | self.no_div_by_zero = 0.000000001 72 | 73 | def forward(self, x): 74 | ''' 75 | Forward pass of the function. 76 | Applies the function to the input elementwise. 77 | Snake ∶= x + 1/a * sin^2 (xa) 78 | ''' 79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 80 | if self.alpha_logscale: 81 | alpha = torch.exp(alpha) 82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /cosyvoice/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
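# Illustrative sketch (not part of the original sources): minimal usage of the Snake
# activation defined in cosyvoice/transformer/activation.py above, assuming the
# documented (B, C, T) layout — alpha is reshaped to [1, C, 1], so the channel axis
# must match in_features. Shapes below are example assumptions.
import torch
from cosyvoice.transformer.activation import Snake

snake = Snake(in_features=80)        # one trainable alpha per channel
x = torch.randn(4, 80, 200)          # (batch, channels, time)
y = snake(x)                         # y = x + (1/alpha) * sin^2(alpha * x)
assert y.shape == x.shape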
15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ConvolutionModule(nn.Module): 25 | """ConvolutionModule in Conformer model.""" 26 | 27 | def __init__(self, 28 | channels: int, 29 | kernel_size: int = 15, 30 | activation: nn.Module = nn.ReLU(), 31 | norm: str = "batch_norm", 32 | causal: bool = False, 33 | bias: bool = True): 34 | """Construct an ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | """ 40 | super().__init__() 41 | 42 | self.pointwise_conv1 = nn.Conv1d( 43 | channels, 44 | 2 * channels, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=bias, 49 | ) 50 | # self.lorder is used to distinguish if it's a causal convolution, 51 | # if self.lorder > 0: it's a causal convolution, the input will be 52 | # padded with self.lorder frames on the left in forward. 53 | # else: it's a symmetrical convolution 54 | if causal: 55 | padding = 0 56 | self.lorder = kernel_size - 1 57 | else: 58 | # kernel_size should be an odd number for none causal convolution 59 | assert (kernel_size - 1) % 2 == 0 60 | padding = (kernel_size - 1) // 2 61 | self.lorder = 0 62 | self.depthwise_conv = nn.Conv1d( 63 | channels, 64 | channels, 65 | kernel_size, 66 | stride=1, 67 | padding=padding, 68 | groups=channels, 69 | bias=bias, 70 | ) 71 | 72 | assert norm in ['batch_norm', 'layer_norm'] 73 | if norm == "batch_norm": 74 | self.use_layer_norm = False 75 | self.norm = nn.BatchNorm1d(channels) 76 | else: 77 | self.use_layer_norm = True 78 | self.norm = nn.LayerNorm(channels) 79 | 80 | self.pointwise_conv2 = nn.Conv1d( 81 | channels, 82 | channels, 83 | kernel_size=1, 84 | stride=1, 85 | padding=0, 86 | bias=bias, 87 | ) 88 | self.activation = activation 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 94 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | """Compute convolution module. 97 | Args: 98 | x (torch.Tensor): Input tensor (#batch, time, channels). 99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 100 | (0, 0, 0) means fake mask. 101 | cache (torch.Tensor): left context cache, it is only 102 | used in causal convolution (#batch, channels, cache_t), 103 | (0, 0, 0) meas fake cache. 104 | Returns: 105 | torch.Tensor: Output tensor (#batch, time, channels). 106 | """ 107 | # exchange the temporal dimension and the feature dimension 108 | x = x.transpose(1, 2) # (#batch, channels, time) 109 | 110 | # mask batch padding 111 | if mask_pad.size(2) > 0: # time > 0 112 | x.masked_fill_(~mask_pad, 0.0) 113 | 114 | if self.lorder > 0: 115 | if cache.size(2) == 0: # cache_t == 0 116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 117 | else: 118 | assert cache.size(0) == x.size(0) # equal batch 119 | assert cache.size(1) == x.size(1) # equal channel 120 | x = torch.cat((cache, x), dim=2) 121 | assert (x.size(2) > self.lorder) 122 | new_cache = x[:, :, -self.lorder:] 123 | else: 124 | # It's better we just return None if no cache is required, 125 | # However, for JIT export, here we just fake one tensor instead of 126 | # None. 
127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 128 | 129 | # GLU mechanism 130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 132 | 133 | # 1D Depthwise Conv 134 | x = self.depthwise_conv(x) 135 | if self.use_layer_norm: 136 | x = x.transpose(1, 2) 137 | x = self.activation(self.norm(x)) 138 | if self.use_layer_norm: 139 | x = x.transpose(1, 2) 140 | x = self.pointwise_conv2(x) 141 | # mask batch padding 142 | if mask_pad.size(2) > 0: # time > 0 143 | x.masked_fill_(~mask_pad, 0.0) 144 | 145 | return x.transpose(1, 2), new_cache 146 | -------------------------------------------------------------------------------- /cosyvoice/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Decoder self-attention layer definition.""" 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | from torch import nn 20 | 21 | 22 | class DecoderLayer(nn.Module): 23 | """Single decoder layer module. 24 | 25 | Args: 26 | size (int): Input dimension. 27 | self_attn (torch.nn.Module): Self-attention module instance. 28 | `MultiHeadedAttention` instance can be used as the argument. 29 | src_attn (torch.nn.Module): Inter-attention module instance. 30 | `MultiHeadedAttention` instance can be used as the argument. 31 | If `None` is passed, Inter-attention is not used, such as 32 | CIF, GPT, and other decoder only model. 33 | feed_forward (torch.nn.Module): Feed-forward module instance. 34 | `PositionwiseFeedForward` instance can be used as the argument. 35 | dropout_rate (float): Dropout rate. 36 | normalize_before (bool): 37 | True: use layer_norm before each sub-block. 38 | False: to use layer_norm after each sub-block. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | size: int, 44 | self_attn: nn.Module, 45 | src_attn: Optional[nn.Module], 46 | feed_forward: nn.Module, 47 | dropout_rate: float, 48 | normalize_before: bool = True, 49 | ): 50 | """Construct an DecoderLayer object.""" 51 | super().__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.src_attn = src_attn 55 | self.feed_forward = feed_forward 56 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 57 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 58 | self.norm3 = nn.LayerNorm(size, eps=1e-5) 59 | self.dropout = nn.Dropout(dropout_rate) 60 | self.normalize_before = normalize_before 61 | 62 | def forward( 63 | self, 64 | tgt: torch.Tensor, 65 | tgt_mask: torch.Tensor, 66 | memory: torch.Tensor, 67 | memory_mask: torch.Tensor, 68 | cache: Optional[torch.Tensor] = None 69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 70 | """Compute decoded features. 71 | 72 | Args: 73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 
74 | tgt_mask (torch.Tensor): Mask for input tensor 75 | (#batch, maxlen_out). 76 | memory (torch.Tensor): Encoded memory 77 | (#batch, maxlen_in, size). 78 | memory_mask (torch.Tensor): Encoded memory mask 79 | (#batch, maxlen_in). 80 | cache (torch.Tensor): cached tensors. 81 | (#batch, maxlen_out - 1, size). 82 | 83 | Returns: 84 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 88 | 89 | """ 90 | residual = tgt 91 | if self.normalize_before: 92 | tgt = self.norm1(tgt) 93 | 94 | if cache is None: 95 | tgt_q = tgt 96 | tgt_q_mask = tgt_mask 97 | else: 98 | # compute only the last frame query keeping dim: max_time_out -> 1 99 | assert cache.shape == ( 100 | tgt.shape[0], 101 | tgt.shape[1] - 1, 102 | self.size, 103 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 104 | tgt_q = tgt[:, -1:, :] 105 | residual = residual[:, -1:, :] 106 | tgt_q_mask = tgt_mask[:, -1:, :] 107 | 108 | x = residual + self.dropout( 109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) 110 | if not self.normalize_before: 111 | x = self.norm1(x) 112 | 113 | if self.src_attn is not None: 114 | residual = x 115 | if self.normalize_before: 116 | x = self.norm2(x) 117 | x = residual + self.dropout( 118 | self.src_attn(x, memory, memory, memory_mask)[0]) 119 | if not self.normalize_before: 120 | x = self.norm2(x) 121 | 122 | residual = x 123 | if self.normalize_before: 124 | x = self.norm3(x) 125 | x = residual + self.dropout(self.feed_forward(x)) 126 | if not self.normalize_before: 127 | x = self.norm3(x) 128 | 129 | if cache is not None: 130 | x = torch.cat([cache, x], dim=1) 131 | 132 | return x, tgt_mask, memory, memory_mask 133 | -------------------------------------------------------------------------------- /cosyvoice/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothing version CE Loss,some probabilities 33 | are taken from the true label prob (1.0) and are divided 34 | among other labels. 35 | 36 | e.g. 
37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of class 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /cosyvoice/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward are appied on each position of the sequence. 24 | The output dim is same with the input dim. 25 | 26 | Args: 27 | idim (int): Input dimenstion. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 
30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of expert with Positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is same with the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of expert. 68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimenstion. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate. 72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Foward function. 
93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /cosyvoice/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/cosyvoice/utils/__init__.py -------------------------------------------------------------------------------- /cosyvoice/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
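# Illustrative sketch of the top-k routing used by MoEFFNLayer.forward above
# (cosyvoice/transformer/positionwise_feed_forward.py): each token scores all experts
# with the gate, keeps its best k, and blends those experts' outputs with
# softmax-normalised gate logits. The toy shapes below are assumptions for the example.
import torch

xs = torch.randn(2, 3, 8).view(-1, 8)          # (B*L, D) flattened tokens
gate = torch.nn.Linear(8, 4, bias=False)       # 4 experts
logits, indices = torch.topk(gate(xs), k=2)    # top-2 experts per token
weights = torch.softmax(logits, dim=1)         # (B*L, 2) mixing weights
# indices[t] lists which experts process token t; weights[t] says how their outputs are mixed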
15 | import torch 16 | 17 | from cosyvoice.transformer.activation import Swish 18 | from cosyvoice.transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from cosyvoice.transformer.embedding import (PositionalEncoding, 27 | RelPositionalEncoding, 28 | WhisperPositionalEncoding, 29 | LearnablePositionalEncoding, 30 | NoPositionalEncoding) 31 | from cosyvoice.transformer.attention import (MultiHeadedAttention, 32 | RelPositionMultiHeadedAttention) 33 | from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding 34 | from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling 35 | from cosyvoice.llm.llm import TransformerLM, Qwen2LM 36 | from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec 37 | from cosyvoice.hifigan.generator import HiFTGenerator 38 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model 39 | 40 | 41 | COSYVOICE_ACTIVATION_CLASSES = { 42 | "hardtanh": torch.nn.Hardtanh, 43 | "tanh": torch.nn.Tanh, 44 | "relu": torch.nn.ReLU, 45 | "selu": torch.nn.SELU, 46 | "swish": getattr(torch.nn, "SiLU", Swish), 47 | "gelu": torch.nn.GELU, 48 | } 49 | 50 | COSYVOICE_SUBSAMPLE_CLASSES = { 51 | "linear": LinearNoSubsampling, 52 | "linear_legacy": LegacyLinearNoSubsampling, 53 | "embed": EmbedinigNoSubsampling, 54 | "conv1d2": Conv1dSubsampling2, 55 | "conv2d": Conv2dSubsampling4, 56 | "conv2d6": Conv2dSubsampling6, 57 | "conv2d8": Conv2dSubsampling8, 58 | 'paraformer_dummy': torch.nn.Identity 59 | } 60 | 61 | COSYVOICE_EMB_CLASSES = { 62 | "embed": PositionalEncoding, 63 | "abs_pos": PositionalEncoding, 64 | "rel_pos": RelPositionalEncoding, 65 | "rel_pos_espnet": EspnetRelPositionalEncoding, 66 | "no_pos": NoPositionalEncoding, 67 | "abs_pos_whisper": WhisperPositionalEncoding, 68 | "embed_learnable_pe": LearnablePositionalEncoding, 69 | } 70 | 71 | COSYVOICE_ATTENTION_CLASSES = { 72 | "selfattn": MultiHeadedAttention, 73 | "rel_selfattn": RelPositionMultiHeadedAttention, 74 | } 75 | 76 | 77 | def get_model_type(configs): 78 | if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 79 | return CosyVoiceModel 80 | if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 81 | return CosyVoice2Model 82 | raise TypeError('No valid model type found!') 83 | -------------------------------------------------------------------------------- /cosyvoice/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
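# Illustrative sketch: the COSYVOICE_* registries in cosyvoice/utils/class_utils.py
# above map the string names used in YAML configs to concrete classes, so a config
# entry can be resolved without if/else chains. The keys below are taken from the
# dictionaries shown above.
from cosyvoice.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
)

attention_cls = COSYVOICE_ATTENTION_CLASSES["rel_selfattn"]   # RelPositionMultiHeadedAttention
activation_cls = COSYVOICE_ACTIVATION_CLASSES["swish"]        # torch.nn.SiLU, with Swish as fallback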
15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """Unility functions for Transformer.""" 17 | 18 | import random 19 | from typing import List 20 | 21 | import numpy as np 22 | import torch 23 | 24 | IGNORE_ID = -1 25 | 26 | 27 | def pad_list(xs: List[torch.Tensor], pad_value: int): 28 | """Perform padding for the list of tensors. 29 | 30 | Args: 31 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 32 | pad_value (float): Value for padding. 33 | 34 | Returns: 35 | Tensor: Padded tensor (B, Tmax, `*`). 36 | 37 | Examples: 38 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 39 | >>> x 40 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 41 | >>> pad_list(x, 0) 42 | tensor([[1., 1., 1., 1.], 43 | [1., 1., 0., 0.], 44 | [1., 0., 0., 0.]]) 45 | 46 | """ 47 | max_len = max([len(item) for item in xs]) 48 | batchs = len(xs) 49 | ndim = xs[0].ndim 50 | if ndim == 1: 51 | pad_res = torch.zeros(batchs, 52 | max_len, 53 | dtype=xs[0].dtype, 54 | device=xs[0].device) 55 | elif ndim == 2: 56 | pad_res = torch.zeros(batchs, 57 | max_len, 58 | xs[0].shape[1], 59 | dtype=xs[0].dtype, 60 | device=xs[0].device) 61 | elif ndim == 3: 62 | pad_res = torch.zeros(batchs, 63 | max_len, 64 | xs[0].shape[1], 65 | xs[0].shape[2], 66 | dtype=xs[0].dtype, 67 | device=xs[0].device) 68 | else: 69 | raise ValueError(f"Unsupported ndim: {ndim}") 70 | pad_res.fill_(pad_value) 71 | for i in range(batchs): 72 | pad_res[i, :len(xs[i])] = xs[i] 73 | return pad_res 74 | 75 | 76 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, 77 | ignore_label: int) -> torch.Tensor: 78 | """Calculate accuracy. 79 | 80 | Args: 81 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 82 | pad_targets (LongTensor): Target label tensors (B, Lmax). 83 | ignore_label (int): Ignore label id. 84 | 85 | Returns: 86 | torch.Tensor: Accuracy value (0.0 - 1.0). 87 | 88 | """ 89 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), 90 | pad_outputs.size(1)).argmax(2) 91 | mask = pad_targets != ignore_label 92 | numerator = torch.sum( 93 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) 94 | denominator = torch.sum(mask) 95 | return (numerator / denominator).detach() 96 | 97 | 98 | def get_padding(kernel_size, dilation=1): 99 | return int((kernel_size * dilation - dilation) / 2) 100 | 101 | 102 | def init_weights(m, mean=0.0, std=0.01): 103 | classname = m.__class__.__name__ 104 | if classname.find("Conv") != -1: 105 | m.weight.data.normal_(mean, std) 106 | 107 | 108 | # Repetition Aware Sampling in VALL-E 2 109 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): 110 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) 111 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() 112 | if rep_num >= win_size * tau_r: 113 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) 114 | return top_ids 115 | 116 | 117 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): 118 | prob, indices = [], [] 119 | cum_prob = 0.0 120 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) 121 | for i in range(len(sorted_idx)): 122 | # sampling both top-p and numbers. 
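# keep collecting candidates while the cumulative probability is still below top_p
# and fewer than top_k tokens have been kept; stop at the first token that would
# exceed either budget.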
123 | if cum_prob < top_p and len(prob) < top_k: 124 | cum_prob += sorted_value[i] 125 | prob.append(sorted_value[i]) 126 | indices.append(sorted_idx[i]) 127 | else: 128 | break 129 | prob = torch.tensor(prob).to(weighted_scores) 130 | indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) 131 | top_ids = indices[prob.multinomial(1, replacement=True)] 132 | return top_ids 133 | 134 | 135 | def random_sampling(weighted_scores, decoded_tokens, sampling): 136 | top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) 137 | return top_ids 138 | 139 | 140 | def fade_in_out(fade_in_mel, fade_out_mel, window): 141 | device = fade_in_mel.device 142 | fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() 143 | mel_overlap_len = int(window.shape[0] / 2) 144 | if fade_in_mel.device == torch.device('cpu'): 145 | fade_in_mel = fade_in_mel.clone() 146 | fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ 147 | fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] 148 | return fade_in_mel.to(device) 149 | 150 | 151 | def set_all_random_seed(seed): 152 | random.seed(seed) 153 | np.random.seed(seed) 154 | torch.manual_seed(seed) 155 | torch.cuda.manual_seed_all(seed) 156 | 157 | 158 | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: 159 | assert mask.dtype == torch.bool 160 | assert dtype in [torch.float32, torch.bfloat16, torch.float16] 161 | mask = mask.to(dtype) 162 | # attention mask bias 163 | # NOTE(Mddct): torch.finfo jit issues 164 | # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min 165 | mask = (1.0 - mask) * torch.finfo(dtype).min 166 | return mask 167 | -------------------------------------------------------------------------------- /cosyvoice/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
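# Illustrative sketch for mask_to_bias in cosyvoice/utils/common.py above: a boolean
# attention mask becomes an additive bias, True -> 0.0 and False -> the most negative
# finite value of the chosen dtype, so masked positions vanish after softmax.
import torch
from cosyvoice.utils.common import mask_to_bias

mask = torch.tensor([[True, True, False]])
bias = mask_to_bias(mask, torch.float32)   # tensor([[0., 0., ~-3.4e38]])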
15 | 16 | import json 17 | import torchaudio 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | logging.basicConfig(level=logging.DEBUG, 21 | format='%(asctime)s %(levelname)s %(message)s') 22 | 23 | 24 | def read_lists(list_file): 25 | lists = [] 26 | with open(list_file, 'r', encoding='utf8') as fin: 27 | for line in fin: 28 | lists.append(line.strip()) 29 | return lists 30 | 31 | 32 | def read_json_lists(list_file): 33 | lists = read_lists(list_file) 34 | results = {} 35 | for fn in lists: 36 | with open(fn, 'r', encoding='utf8') as fin: 37 | results.update(json.load(fin)) 38 | return results 39 | 40 | 41 | def load_wav(wav, target_sr): 42 | speech, sample_rate = torchaudio.load(wav) 43 | speech = speech.mean(dim=0, keepdim=True) 44 | if sample_rate != target_sr: 45 | assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) 46 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) 47 | return speech 48 | -------------------------------------------------------------------------------- /cosyvoice/utils/frontend_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | import regex 17 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') 18 | 19 | 20 | # whether contain chinese character 21 | def contains_chinese(text): 22 | return bool(chinese_char_pattern.search(text)) 23 | 24 | 25 | # replace special symbol 26 | def replace_corner_mark(text): 27 | text = text.replace('²', '平方') 28 | text = text.replace('³', '立方') 29 | return text 30 | 31 | 32 | # remove meaningless symbol 33 | def remove_bracket(text): 34 | text = text.replace('(', '').replace(')', '') 35 | text = text.replace('【', '').replace('】', '') 36 | text = text.replace('`', '').replace('`', '') 37 | text = text.replace("——", " ") 38 | return text 39 | 40 | 41 | # spell Arabic numerals 42 | def spell_out_number(text: str, inflect_parser): 43 | new_text = [] 44 | st = None 45 | for i, c in enumerate(text): 46 | if not c.isdigit(): 47 | if st is not None: 48 | num_str = inflect_parser.number_to_words(text[st: i]) 49 | new_text.append(num_str) 50 | st = None 51 | new_text.append(c) 52 | else: 53 | if st is None: 54 | st = i 55 | if st is not None and st < len(text): 56 | num_str = inflect_parser.number_to_words(text[st:]) 57 | new_text.append(num_str) 58 | return ''.join(new_text) 59 | 60 | 61 | # split paragrah logic: 62 | # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len 63 | # 2. cal sentence len according to lang 64 | # 3. 
split sentence according to puncatation 65 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): 66 | def calc_utt_length(_text: str): 67 | if lang == "zh": 68 | return len(_text) 69 | else: 70 | return len(tokenize(_text)) 71 | 72 | def should_merge(_text: str): 73 | if lang == "zh": 74 | return len(_text) < merge_len 75 | else: 76 | return len(tokenize(_text)) < merge_len 77 | 78 | if lang == "zh": 79 | pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] 80 | else: 81 | pounc = ['.', '?', '!', ';', ':'] 82 | if comma_split: 83 | pounc.extend([',', ',']) 84 | 85 | if text[-1] not in pounc: 86 | if lang == "zh": 87 | text += "。" 88 | else: 89 | text += "." 90 | 91 | st = 0 92 | utts = [] 93 | for i, c in enumerate(text): 94 | if c in pounc: 95 | if len(text[st: i]) > 0: 96 | utts.append(text[st: i] + c) 97 | if i + 1 < len(text) and text[i + 1] in ['"', '”']: 98 | tmp = utts.pop(-1) 99 | utts.append(tmp + text[i + 1]) 100 | st = i + 2 101 | else: 102 | st = i + 1 103 | 104 | final_utts = [] 105 | cur_utt = "" 106 | for utt in utts: 107 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: 108 | final_utts.append(cur_utt) 109 | cur_utt = "" 110 | cur_utt = cur_utt + utt 111 | if len(cur_utt) > 0: 112 | if should_merge(cur_utt) and len(final_utts) != 0: 113 | final_utts[-1] = final_utts[-1] + cur_utt 114 | else: 115 | final_utts.append(cur_utt) 116 | 117 | return final_utts 118 | 119 | 120 | # remove blank between chinese character 121 | def replace_blank(text: str): 122 | out_str = [] 123 | for i, c in enumerate(text): 124 | if c == " ": 125 | if ((text[i + 1].isascii() and text[i + 1] != " ") and 126 | (text[i - 1].isascii() and text[i - 1] != " ")): 127 | out_str.append(c) 128 | else: 129 | out_str.append(c) 130 | return "".join(out_str) 131 | 132 | 133 | def is_only_punctuation(text): 134 | # Regular expression: Match strings that consist only of punctuation marks or are empty. 
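# \p{P} (punctuation) and \p{S} (symbols) are Unicode property classes, which is why
# this check relies on the third-party `regex` module imported above rather than the
# standard-library `re`.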
135 | punctuation_pattern = r'^[\p{P}\p{S}]*$' 136 | return bool(regex.fullmatch(punctuation_pattern, text)) 137 | -------------------------------------------------------------------------------- /cosyvoice/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): 6 | loss = 0 7 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 8 | m_DG = torch.median((dr - dg)) 9 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) 10 | loss += tau - F.relu(tau - L_rel) 11 | return loss 12 | 13 | 14 | def mel_loss(real_speech, generated_speech, mel_transforms): 15 | loss = 0 16 | for transform in mel_transforms: 17 | mel_r = transform(real_speech) 18 | mel_g = transform(generated_speech) 19 | loss += F.l1_loss(mel_g, mel_r) 20 | return loss 21 | -------------------------------------------------------------------------------- /downloadmodel.py: -------------------------------------------------------------------------------- 1 | from modelscope import snapshot_download 2 | snapshot_download('chenmingyu/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B') 3 | # snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M') 4 | # snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz') 5 | # snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT') 6 | # snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct') 7 | snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd') -------------------------------------------------------------------------------- /examples/CrossLingual.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 11, 3 | "last_link_id": 11, 4 | "nodes": [ 5 | { 6 | "id": 1, 7 | "type": "LoadAudio", 8 | "pos": [ 9 | -199.86219787597656, 10 | -357.5240173339844 11 | ], 12 | "size": [ 13 | 338.2475891113281, 14 | 126.5556869506836 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [ 21 | { 22 | "name": "AUDIO", 23 | "type": "AUDIO", 24 | "links": [ 25 | 5 26 | ], 27 | "slot_index": 0 28 | } 29 | ], 30 | "properties": { 31 | "Node name for S&R": "LoadAudio" 32 | }, 33 | "widgets_values": [ 34 | "", 35 | null, 36 | "" 37 | ] 38 | }, 39 | { 40 | "id": 7, 41 | "type": "NTCosyVoiceCrossLingualSampler", 42 | "pos": [ 43 | 289.2545471191406, 44 | -318.1132507324219 45 | ], 46 | "size": [ 47 | 518.6781005859375, 48 | 169.7455596923828 49 | ], 50 | "flags": {}, 51 | "order": 1, 52 | "mode": 0, 53 | "inputs": [ 54 | { 55 | "name": "audio", 56 | "type": "AUDIO", 57 | "link": 5 58 | } 59 | ], 60 | "outputs": [ 61 | { 62 | "name": "tts_speech", 63 | "type": "AUDIO", 64 | "links": [ 65 | 6 66 | ], 67 | "slot_index": 0 68 | } 69 | ], 70 | "properties": { 71 | "Node name for S&R": "NTCosyVoiceCrossLingualSampler" 72 | }, 73 | "widgets_values": [ 74 | 1.1, 75 | "Logic is the order of thought and the reasoning of reason. Everything we see in the world has its internal order and order, and this is where logic lies. In the field of thinking, logic is our weapon to grasp truth and distinguish right from wrong. 
It teaches us to organize our thoughts in a clear way, to support our arguments with conclusive evidence, and to arrive at rational conclusions. Specifically, logic includes two categories: deductive reasoning and inductive reasoning. The deductive reasoner deduces special conclusions from general principles by general and individual methods. The process is like the derivation of unknown effects from known causes, step by step, impeccable. Inductive reasoning does the opposite, from the individual and the general, and extracts general laws from particular facts. These two complement each other and constitute the core of logic. Our scholars study logic in order to develop clear thinking, rigorous attitude and accurate expression. As individuals, logic enables us to better understand the world and grasp ourselves. For society, logic is the cornerstone of promoting civilization and progress and maintaining fairness and justice. Therefore, the study of logic is the indispensable wisdom of life." 76 | ] 77 | }, 78 | { 79 | "id": 8, 80 | "type": "PreviewAudio", 81 | "pos": [ 82 | 953.0675659179688, 83 | -304.42449951171875 84 | ], 85 | "size": [ 86 | 433.8606262207031, 87 | 167.7429962158203 88 | ], 89 | "flags": {}, 90 | "order": 2, 91 | "mode": 0, 92 | "inputs": [ 93 | { 94 | "name": "audio", 95 | "type": "AUDIO", 96 | "link": 6 97 | } 98 | ], 99 | "outputs": [], 100 | "properties": { 101 | "Node name for S&R": "PreviewAudio" 102 | }, 103 | "widgets_values": [ 104 | null 105 | ] 106 | } 107 | ], 108 | "links": [ 109 | [ 110 | 5, 111 | 1, 112 | 0, 113 | 7, 114 | 0, 115 | "AUDIO" 116 | ], 117 | [ 118 | 6, 119 | 7, 120 | 0, 121 | 8, 122 | 0, 123 | "AUDIO" 124 | ] 125 | ], 126 | "groups": [], 127 | "config": {}, 128 | "extra": { 129 | "ds": { 130 | "scale": 1.2839025177495025, 131 | "offset": [ 132 | 493.2809123871739, 133 | 820.31454385171 134 | ] 135 | } 136 | }, 137 | "version": 0.4 138 | } -------------------------------------------------------------------------------- /examples/Instruct2.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 5, 3 | "last_link_id": 4, 4 | "nodes": [ 5 | { 6 | "id": 1, 7 | "type": "LoadAudio", 8 | "pos": [ 9 | -81.06248474121094, 10 | -627.4505615234375 11 | ], 12 | "size": [ 13 | 338.2475891113281, 14 | 126.5556869506836 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "inputs": [], 20 | "outputs": [ 21 | { 22 | "name": "AUDIO", 23 | "type": "AUDIO", 24 | "links": [ 25 | 3 26 | ], 27 | "slot_index": 0 28 | } 29 | ], 30 | "properties": { 31 | "Node name for S&R": "LoadAudio" 32 | }, 33 | "widgets_values": [ 34 | "", 35 | null, 36 | "" 37 | ] 38 | }, 39 | { 40 | "id": 3, 41 | "type": "PreviewAudio", 42 | "pos": [ 43 | 888.1553344726562, 44 | -627.6224365234375 45 | ], 46 | "size": [ 47 | 433.8606262207031, 48 | 167.7429962158203 49 | ], 50 | "flags": {}, 51 | "order": 2, 52 | "mode": 0, 53 | "inputs": [ 54 | { 55 | "name": "audio", 56 | "type": "AUDIO", 57 | "link": 4 58 | } 59 | ], 60 | "outputs": [], 61 | "properties": { 62 | "Node name for S&R": "PreviewAudio" 63 | }, 64 | "widgets_values": [ 65 | null 66 | ] 67 | }, 68 | { 69 | "id": 5, 70 | "type": "NTCosyVoiceInstruct2Sampler", 71 | "pos": [ 72 | 357.6646728515625, 73 | -628.6502075195312 74 | ], 75 | "size": [ 76 | 483.8872985839844, 77 | 235.1954345703125 78 | ], 79 | "flags": {}, 80 | "order": 1, 81 | "mode": 0, 82 | "inputs": [ 83 | { 84 | "name": "audio", 85 | "type": "AUDIO", 86 | "link": 3 87 | } 88 | ], 89 | "outputs": [ 90 | 
{ 91 | "name": "tts_speech", 92 | "type": "AUDIO", 93 | "links": [ 94 | 4 95 | ], 96 | "slot_index": 0 97 | } 98 | ], 99 | "properties": { 100 | "Node name for S&R": "NTCosyVoiceInstruct2Sampler" 101 | }, 102 | "widgets_values": [ 103 | 1, 104 | "逻辑者,乃思维之秩序与理性之推理也。吾观世间万物,皆有其内在之条理与次序,此即逻辑之所在。在思辨之域,逻辑乃是我们把握真理、辨明是非之利器。它教我们以条理清晰之方式组织思想,以确凿无疑之论据支持论点,从而得出合乎理性之结论。具体而言,逻辑包括演绎推理与归纳推理两大范畴。演绎推理者,由一般而个别之方法也,自普遍之原则推导出特殊之结论。其过程犹如由已知之因推导出未知之果,步步为营,无懈可击。归纳推理则反其道而行之,由个别而一般,从特殊之事实中提炼出普遍之规律。此二者相辅相成,共同构成了逻辑学之核心。吾辈学者研习逻辑,旨在培养清晰之思维、严谨之态度与准确之表达能力。于个人而言,逻辑使我们能更好地理解世界、把握自身;于社会而言,逻辑则是促进文明进步、维护公平正义之基石。故逻辑之学,实乃人生不可或缺之智慧也。", 105 | "请用悲伤的语气阅读这段话" 106 | ] 107 | } 108 | ], 109 | "links": [ 110 | [ 111 | 3, 112 | 1, 113 | 0, 114 | 5, 115 | 0, 116 | "AUDIO" 117 | ], 118 | [ 119 | 4, 120 | 5, 121 | 0, 122 | 3, 123 | 0, 124 | "AUDIO" 125 | ] 126 | ], 127 | "groups": [], 128 | "config": {}, 129 | "extra": { 130 | "ds": { 131 | "scale": 1.2839025177495025, 132 | "offset": [ 133 | 156.24695224003256, 134 | 872.3227271653814 135 | ] 136 | } 137 | }, 138 | "version": 0.4 139 | } -------------------------------------------------------------------------------- /examples/ZeroShot.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 4, 3 | "last_link_id": 2, 4 | "nodes": [ 5 | { 6 | "id": 3, 7 | "type": "PreviewAudio", 8 | "pos": [ 9 | 934.6931762695312, 10 | -636.467041015625 11 | ], 12 | "size": [ 13 | 433.8606262207031, 14 | 167.7429962158203 15 | ], 16 | "flags": {}, 17 | "order": 2, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "audio", 22 | "type": "AUDIO", 23 | "link": 1 24 | } 25 | ], 26 | "outputs": [], 27 | "properties": { 28 | "Node name for S&R": "PreviewAudio" 29 | }, 30 | "widgets_values": [ 31 | null 32 | ] 33 | }, 34 | { 35 | "id": 1, 36 | "type": "LoadAudio", 37 | "pos": [ 38 | -81.06248474121094, 39 | -627.4505615234375 40 | ], 41 | "size": [ 42 | 384.96185302734375, 43 | 194.6464385986328 44 | ], 45 | "flags": {}, 46 | "order": 0, 47 | "mode": 0, 48 | "inputs": [], 49 | "outputs": [ 50 | { 51 | "name": "AUDIO", 52 | "type": "AUDIO", 53 | "links": [ 54 | 2 55 | ], 56 | "slot_index": 0 57 | } 58 | ], 59 | "properties": { 60 | "Node name for S&R": "LoadAudio" 61 | }, 62 | "widgets_values": [ 63 | "", 64 | null, 65 | "" 66 | ] 67 | }, 68 | { 69 | "id": 2, 70 | "type": "NTCosyVoiceZeroShotSampler", 71 | "pos": [ 72 | 353.0162048339844, 73 | -625.3529052734375 74 | ], 75 | "size": [ 76 | 492.1404113769531, 77 | 240.8211212158203 78 | ], 79 | "flags": {}, 80 | "order": 1, 81 | "mode": 0, 82 | "inputs": [ 83 | { 84 | "name": "audio", 85 | "type": "AUDIO", 86 | "link": 2 87 | } 88 | ], 89 | "outputs": [ 90 | { 91 | "name": "tts_speech", 92 | "type": "AUDIO", 93 | "links": [ 94 | 1 95 | ], 96 | "slot_index": 0 97 | } 98 | ], 99 | "properties": { 100 | "Node name for S&R": "NTCosyVoiceZeroShotSampler" 101 | }, 102 | "widgets_values": [ 103 | 1, 104 | "逻辑者,乃思维之秩序与理性之推理也。吾观世间万物,皆有其内在之条理与次序,此即逻辑之所在。在思辨之域,逻辑乃是我们把握真理、辨明是非之利器。它教我们以条理清晰之方式组织思想,以确凿无疑之论据支持论点,从而得出合乎理性之结论。具体而言,逻辑包括演绎推理与归纳推理两大范畴。演绎推理者,由一般而个别之方法也,自普遍之原则推导出特殊之结论。其过程犹如由已知之因推导出未知之果,步步为营,无懈可击。归纳推理则反其道而行之,由个别而一般,从特殊之事实中提炼出普遍之规律。此二者相辅相成,共同构成了逻辑学之核心。吾辈学者研习逻辑,旨在培养清晰之思维、严谨之态度与准确之表达能力。于个人而言,逻辑使我们能更好地理解世界、把握自身;于社会而言,逻辑则是促进文明进步、维护公平正义之基石。故逻辑之学,实乃人生不可或缺之智慧也。", 105 | "" 106 | ] 107 | } 108 | ], 109 | "links": [ 110 | [ 111 | 1, 112 | 2, 113 | 0, 114 | 3, 115 | 0, 116 | "AUDIO" 117 | ], 118 | [ 119 | 2, 120 | 1, 121 | 0, 122 | 2, 123 | 0, 124 | "AUDIO" 125 | ] 126 | ], 127 | "groups": [], 128 | "config": {}, 129 | 
"extra": { 130 | "ds": { 131 | "scale": 1.2839025177495025, 132 | "offset": [ 133 | 493.2809123871739, 134 | 820.31454385171 135 | ] 136 | } 137 | }, 138 | "version": 0.4 139 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ntcosyvoice" 3 | description = "ComfyUI_NTCosyVoice is a plugin of ComfyUI for Cosysvoice2" 4 | version = "1.0.0" 5 | license = {file = "LICENSE"} 6 | dependencies = ["conformer==0.3.2", "deepspeed==0.14.2", "diffusers==0.27.2", "gdown==5.1.0", "grpcio==1.57.0", "grpcio-tools==1.57.0", "huggingface-hub==0.25.2", "modelscope", "hydra-core==1.3.2", "HyperPyYAML==1.2.2", "inflect==7.3.1", "librosa==0.10.2", "lightning==2.2.4", "matplotlib==3.7.5", "modelscope==1.15.0", "networkx==3.1", "omegaconf==2.3.0", "openai-whisper", "protobuf==4.25", "pydantic==2.7.0", "rich==13.7.1", "soundfile==0.12.1", "tensorboard", "tensorrt-cu12", "tensorrt-cu12-bindings", "tensorrt-cu12-libs", "transformers==4.40.1", "wget==3.2", "WeTextProcessing==1.0.3", "onnxruntime-gpu", "torch", "torchvision", "torchaudio"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/muxueChen/ComfyUI_NTCosyVoice" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "" 14 | DisplayName = "ComfyUI_NTCosyVoice" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | conformer 2 | deepspeed 3 | diffusers 4 | gdown 5 | modelscope 6 | hydra-core 7 | HyperPyYAML 8 | inflect 9 | librosa 10 | pyarrow 11 | lightning 12 | matplotlib 13 | omegaconf 14 | openai-whisper 15 | rich 16 | soundfile 17 | tensorboard 18 | tensorrt-cu12 19 | tensorrt-cu12-bindings 20 | tensorrt-cu12-libs 21 | wget 22 | WeTextProcessing 23 | onnxruntime-gpu -------------------------------------------------------------------------------- /third_party/Matcha-TTS/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shivam Mehta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE.txt 3 | include requirements.*.txt 4 | include *.cff 5 | include requirements.txt 6 | include matcha/VERSION 7 | recursive-include matcha *.json 8 | recursive-include matcha *.html 9 | recursive-include matcha *.png 10 | recursive-include matcha *.md 11 | recursive-include matcha *.py 12 | recursive-include matcha *.pyx 13 | recursive-exclude tests * 14 | prune tests* 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | create-package: ## Create wheel and tar gz 17 | rm -rf dist/ 18 | python setup.py bdist_wheel --plat-name=manylinux1_x86_64 19 | python setup.py sdist 20 | python -m twine upload dist/* --verbose --skip-existing 21 | 22 | format: ## Run pre-commit hooks 23 | pre-commit run -a 24 | 25 | sync: ## Merge changes from main branch to your current branch 26 | git pull 27 | git pull origin main 28 | 29 | test: ## Run not slow tests 30 | pytest -k "not slow" 31 | 32 | test-full: ## Run all tests 33 | pytest 34 | 35 | train-ljspeech: ## Train the model 36 | python matcha/train.py experiment=ljspeech 37 | 38 | train-ljspeech-min: ## Train the model with minimum memory 39 | python matcha/train.py experiment=ljspeech_min_memory 40 | 41 | start_app: ## Start the app 42 | python matcha/app.py 43 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: 
${paths.output_dir}/checkpoints # directory to save the model file 6 | filename: checkpoint_{epoch:03d} # checkpoint filename 7 | monitor: epoch # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | save_top_k: 10 # save k best models (determined by above metric) 11 | mode: "max" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: 100 # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 3 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/configs/callbacks/none.yaml -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/data/hi-fi_en-US_female.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | # Dataset URL: https://ast-astrec.nict.go.jp/en/release/hi-fi-captain/ 6 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 7 | name: hi-fi_en-US_female 8 | train_filelist_path: data/filelists/hi-fi-captain-en-us-female_train.txt 9 | valid_filelist_path: data/filelists/hi-fi-captain-en-us-female_val.txt 10 | batch_size: 32 11 | cleaners: [english_cleaners_piper] 12 | data_statistics: # Computed for this dataset 13 | mel_mean: -6.38385 14 | mel_std: 2.541796 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/data/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 2 | name: ljspeech 3 | train_filelist_path: data/filelists/ljs_audio_text_train_filelist.txt 4 | valid_filelist_path: data/filelists/ljs_audio_text_val_filelist.txt 5 | batch_size: 32 6 | 
num_workers: 20 7 | pin_memory: True 8 | cleaners: [english_cleaners2] 9 | add_blank: True 10 | n_spks: 1 11 | n_fft: 1024 12 | n_feats: 80 13 | sample_rate: 22050 14 | hop_length: 256 15 | win_length: 1024 16 | f_min: 0 17 | f_max: 8000 18 | data_statistics: # Computed for ljspeech dataset 19 | mel_mean: -5.536622 20 | mel_std: 2.116101 21 | seed: ${seed} 22 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/data/vctk.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - ljspeech 3 | - _self_ 4 | 5 | _target_: matcha.data.text_mel_datamodule.TextMelDataModule 6 | name: vctk 7 | train_filelist_path: data/filelists/vctk_audio_sid_text_train_filelist.txt 8 | valid_filelist_path: data/filelists/vctk_audio_sid_text_val_filelist.txt 9 | batch_size: 32 10 | add_blank: True 11 | n_spks: 109 12 | data_statistics: # Computed for vctk dataset 13 | mel_mean: -6.630575 14 | mel_std: 2.482914 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | # callbacks: null 11 | # logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | 
callbacks: null 14 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | # profiler: "simple" 11 | profiler: "advanced" 12 | # profiler: "pytorch" 13 | accelerator: gpu 14 | 15 | limit_train_batches: 0.02 16 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: mnist # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 19 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: hi-fi_en-US_female.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"] 13 | 14 | run_name: hi-fi_en-US_female_piper_phonemizer 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech_min 15 | 16 | 17 | model: 18 | out_size: 172 19 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/multispeaker.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: vctk.yaml 8 | 9 | # all parameters below will be merged 
with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["multispeaker"] 13 | 14 | run_name: multispeaker 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 
| file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 20 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/local/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/configs/local/.gitkeep -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 
25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | # - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # 
name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/cfm/default.yaml: -------------------------------------------------------------------------------- 1 | name: CFM 2 | solver: euler 3 | sigma_min: 1e-4 4 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/decoder/default.yaml: -------------------------------------------------------------------------------- 1 | channels: [256, 256] 2 | dropout: 0.05 3 | attention_head_dim: 64 4 | n_blocks: 1 5 | num_mid_blocks: 2 6 | num_heads: 2 7 | act_fn: snakebeta 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/encoder/default.yaml: -------------------------------------------------------------------------------- 1 | encoder_type: RoPE Encoder 2 | encoder_params: 3 | n_feats: ${model.n_feats} 4 | n_channels: 192 5 | filter_channels: 768 6 | filter_channels_dp: 256 7 | n_heads: 2 8 | n_layers: 6 9 | kernel_size: 3 10 | p_dropout: 0.1 11 | spk_emb_dim: 64 12 | n_spks: 1 13 | prenet: true 14 | 15 | duration_predictor_params: 16 | filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} 17 | kernel_size: 3 18 | p_dropout: ${model.encoder.encoder_params.p_dropout} 19 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/matcha.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - encoder: default.yaml 4 | - decoder: default.yaml 5 | - cfm: default.yaml 6 | - optimizer: adam.yaml 7 | 8 | _target_: matcha.models.matcha_tts.MatchaTTS 9 | n_vocab: 178 10 | n_spks: ${data.n_spks} 11 | spk_emb_dim: 64 12 | n_feats: 80 13 | data_statistics: ${data.data_statistics} 14 | out_size: null # Must be divisible by 4 15 | prior_loss: true 16 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | _partial_: true 3 | lr: 1e-4 4 | weight_decay: 0.0 5 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." 
if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ljspeech 8 | - model: matcha 9 | - callbacks: default 10 | - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | run_name: ??? 
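The defaults list above is composed by Hydra, with later entries and command-line overrides taking precedence over earlier ones. As a rough sketch (not part of the repo), the merged training config can be previewed with Hydra's compose API; this assumes Hydra ≥ 1.2 and the project's plugins (e.g. hydra-colorlog) are installed, that it is run from the Matcha-TTS root so the relative `configs` directory resolves, and that the override values are placeholders:

```python
from hydra import compose, initialize
from omegaconf import OmegaConf

# Compose configs/train.yaml roughly the way the training entry point would.
with initialize(version_base="1.3", config_path="configs"):
    cfg = compose(
        config_name="train",
        overrides=["run_name=scratch", "experiment=ljspeech", "trainer=cpu"],
    )
    # run_name is mandatory (??? above), so it has to be supplied as an override.
    print(OmegaConf.to_yaml(cfg.trainer))
```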
34 | 35 | # tags to help you identify your experiments 36 | # you can overwrite this in experiment configs 37 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: True 46 | 47 | # simply provide checkpoint path to resume training 48 | ckpt_path: null 49 | 50 | # seed for random number generators in pytorch, numpy and python.random 51 | seed: 1234 52 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: [0,1] 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | max_epochs: -1 6 | 7 | accelerator: gpu 8 | devices: [0] 9 | 10 | # mixed precision for extra speed-up 11 | precision: 16-mixed 12 | 13 | # perform a validation loop every N training epochs 14 | check_val_every_n_epoch: 1 15 | 16 | # set True to to ensure deterministic results 17 | # makes training slower but gives more reproducibility than just setting seeds 18 | deterministic: False 19 | 20 | gradient_clip_val: 5.0 21 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/data: -------------------------------------------------------------------------------- 1 | /home/smehta/Projects/Speech-Backbones/Grad-TTS/data -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/VERSION: -------------------------------------------------------------------------------- 1 | 0.0.5.1 2 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/matcha/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/matcha/data/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/data/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/matcha/data/components/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | ## Pre-requisites 25 | 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | ## Training 33 | 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | 38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.
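As a hedged aside (not part of the original README): within this repo, the helpers vendored in `matcha/hifigan/xutils.py` can locate and load the newest generator checkpoint from that directory. The `g_` filename prefix below is an assumption based on the usual HiFi-GAN naming scheme.

```python
import torch
from matcha.hifigan.xutils import load_checkpoint, scan_checkpoint

# Find the newest checkpoint matching g_???????? in the default checkpoint directory.
cp_g = scan_checkpoint("cp_hifigan", "g_")
if cp_g is not None:
    state_dict = load_checkpoint(cp_g, torch.device("cpu"))
    print(list(state_dict.keys()))  # inspect what the checkpoint stores
```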
40 | You can change the path by adding the `--checkpoint_path` option. 41 | 42 | Validation loss during training with the V1 generator.
43 | ![validation loss](./validation_loss.png) 44 | 45 | ## Pretrained Model 46 | 47 | You can also use pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as in follows: 50 | 51 | | Folder Name | Generator | Dataset | Fine-Tuned | 52 | | ------------ | --------- | --------- | ------------------------------------------------------ | 53 | | LJ_V1 | V1 | LJSpeech | No | 54 | | LJ_V2 | V2 | LJSpeech | No | 55 | | LJ_V3 | V3 | LJSpeech | No | 56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 59 | | VCTK_V1 | V1 | VCTK | No | 60 | | VCTK_V2 | V2 | VCTK | No | 61 | | VCTK_V3 | V3 | VCTK | No | 62 | | UNIVERSAL_V1 | V1 | Universal | No | 63 | 64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 65 | 66 | ## Fine-Tuning 67 | 68 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.
69 | The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.
70 | Example: 71 | ` Audio File : LJ001-0001.wav 72 | Mel-Spectrogram File : LJ001-0001.npy` 73 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command. 75 | ``` 76 | python train.py --fine_tuning True --config config_v1.json 77 | ``` 78 | For other command line options, please refer to the training section. 79 | 80 | ## Inference from wav file 81 | 82 | 1. Make `test_files` directory and copy wav files into the directory. 83 | 2. Run the following command. 84 | ` python inference.py --checkpoint_file [generator checkpoint file path]` 85 | Generated wav files are saved in `generated_files` by default.
86 | You can change the path by adding `--output_dir` option. 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 90 | 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 93 | 2. Run the following command. 94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` 95 | Generated wav files are saved in `generated_files_from_mel` by default.
96 | You can change the path by adding `--output_dir` option. 97 | 98 | ## Acknowledgements 99 | 100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 102 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/matcha/hifigan/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/config.py: -------------------------------------------------------------------------------- 1 | v1 = { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0004, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | "upsample_rates": [8, 8, 2, 2], 11 | "upsample_kernel_sizes": [16, 16, 4, 4], 12 | "upsample_initial_channel": 512, 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 15 | "resblock_initial_channel": 256, 16 | "segment_size": 8192, 17 | "num_mels": 80, 18 | "num_freq": 1025, 19 | "n_fft": 1024, 20 | "hop_size": 256, 21 | "win_size": 1024, 22 | "sampling_rate": 22050, 23 | "fmin": 0, 24 | "fmax": 8000, 25 | "fmax_loss": None, 26 | "num_workers": 4, 27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, 28 | } 29 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/denoiser.py: -------------------------------------------------------------------------------- 1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py 2 | 3 | """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" 4 | import torch 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """Removes model bias from audio produced with waveglow""" 9 | 10 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): 11 | super().__init__() 12 | self.filter_length = filter_length 13 | self.hop_length = int(filter_length / n_overlap) 14 | self.win_length = win_length 15 | 16 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device 17 | self.device = device 18 | if mode == "zeros": 19 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) 20 | elif mode == "normal": 21 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) 22 | else: 23 | raise Exception(f"Mode {mode} if not supported") 24 | 25 | def stft_fn(audio, n_fft, hop_length, win_length, window): 26 | spec = torch.stft( 27 | audio, 28 | n_fft=n_fft, 29 | hop_length=hop_length, 30 | win_length=win_length, 31 | window=window, 32 | return_complex=True, 33 | ) 34 | spec = torch.view_as_real(spec) 35 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) 36 | 37 | self.stft = lambda x: stft_fn( 38 | audio=x, 39 | n_fft=self.filter_length, 40 | hop_length=self.hop_length, 41 | win_length=self.win_length, 42 | window=torch.hann_window(self.win_length, device=device), 43 | ) 44 | self.istft = lambda x, y: torch.istft( 45 | 
torch.complex(x * torch.cos(y), x * torch.sin(y)), 46 | n_fft=self.filter_length, 47 | hop_length=self.hop_length, 48 | win_length=self.win_length, 49 | window=torch.hann_window(self.win_length, device=device), 50 | ) 51 | 52 | with torch.no_grad(): 53 | bias_audio = vocoder(mel_input).float().squeeze(0) 54 | bias_spec, _ = self.stft(bias_audio) 55 | 56 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) 57 | 58 | @torch.inference_mode() 59 | def forward(self, audio, strength=0.0005): 60 | audio_spec, audio_angles = self.stft(audio) 61 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength 62 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 63 | audio_denoised = self.istft(audio_spec_denoised, audio_angles) 64 | return audio_denoised 65 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import os 4 | import shutil 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def build_env(config, config_name, path): 14 | t_path = os.path.join(path, config_name) 15 | if config != t_path: 16 | os.makedirs(path, exist_ok=True) 17 | shutil.copyfile(config, os.path.join(path, config_name)) 18 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/meldataset.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import math 4 | import os 5 | import random 6 | 7 | import numpy as np 8 | import torch 9 | import torch.utils.data 10 | from librosa.filters import mel as librosa_mel_fn 11 | from librosa.util import normalize 12 | from scipy.io.wavfile import read 13 | 14 | MAX_WAV_VALUE = 32768.0 15 | 16 | 17 | def load_wav(full_path): 18 | sampling_rate, data = read(full_path) 19 | return data, sampling_rate 20 | 21 | 22 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 23 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 24 | 25 | 26 | def dynamic_range_decompression(x, C=1): 27 | return np.exp(x) / C 28 | 29 | 30 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 31 | return torch.log(torch.clamp(x, min=clip_val) * C) 32 | 33 | 34 | def dynamic_range_decompression_torch(x, C=1): 35 | return torch.exp(x) / C 36 | 37 | 38 | def spectral_normalize_torch(magnitudes): 39 | output = dynamic_range_compression_torch(magnitudes) 40 | return output 41 | 42 | 43 | def spectral_de_normalize_torch(magnitudes): 44 | output = dynamic_range_decompression_torch(magnitudes) 45 | return output 46 | 47 | 48 | mel_basis = {} 49 | hann_window = {} 50 | 51 | 52 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 53 | if torch.min(y) < -1.0: 54 | print("min value is ", torch.min(y)) 55 | if torch.max(y) > 1.0: 56 | print("max value is ", torch.max(y)) 57 | 58 | global mel_basis, hann_window # pylint: disable=global-statement 59 | if fmax not in mel_basis: 60 | mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) 61 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 62 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 63 | 64 | y = torch.nn.functional.pad( 65 | 
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 66 | ) 67 | y = y.squeeze(1) 68 | 69 | spec = torch.view_as_real( 70 | torch.stft( 71 | y, 72 | n_fft, 73 | hop_length=hop_size, 74 | win_length=win_size, 75 | window=hann_window[str(y.device)], 76 | center=center, 77 | pad_mode="reflect", 78 | normalized=False, 79 | onesided=True, 80 | return_complex=True, 81 | ) 82 | ) 83 | 84 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 85 | 86 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 87 | spec = spectral_normalize_torch(spec) 88 | 89 | return spec 90 | 91 | 92 | def get_dataset_filelist(a): 93 | with open(a.input_training_file, encoding="utf-8") as fi: 94 | training_files = [ 95 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 96 | ] 97 | 98 | with open(a.input_validation_file, encoding="utf-8") as fi: 99 | validation_files = [ 100 | os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0 101 | ] 102 | return training_files, validation_files 103 | 104 | 105 | class MelDataset(torch.utils.data.Dataset): 106 | def __init__( 107 | self, 108 | training_files, 109 | segment_size, 110 | n_fft, 111 | num_mels, 112 | hop_size, 113 | win_size, 114 | sampling_rate, 115 | fmin, 116 | fmax, 117 | split=True, 118 | shuffle=True, 119 | n_cache_reuse=1, 120 | device=None, 121 | fmax_loss=None, 122 | fine_tuning=False, 123 | base_mels_path=None, 124 | ): 125 | self.audio_files = training_files 126 | random.seed(1234) 127 | if shuffle: 128 | random.shuffle(self.audio_files) 129 | self.segment_size = segment_size 130 | self.sampling_rate = sampling_rate 131 | self.split = split 132 | self.n_fft = n_fft 133 | self.num_mels = num_mels 134 | self.hop_size = hop_size 135 | self.win_size = win_size 136 | self.fmin = fmin 137 | self.fmax = fmax 138 | self.fmax_loss = fmax_loss 139 | self.cached_wav = None 140 | self.n_cache_reuse = n_cache_reuse 141 | self._cache_ref_count = 0 142 | self.device = device 143 | self.fine_tuning = fine_tuning 144 | self.base_mels_path = base_mels_path 145 | 146 | def __getitem__(self, index): 147 | filename = self.audio_files[index] 148 | if self._cache_ref_count == 0: 149 | audio, sampling_rate = load_wav(filename) 150 | audio = audio / MAX_WAV_VALUE 151 | if not self.fine_tuning: 152 | audio = normalize(audio) * 0.95 153 | self.cached_wav = audio 154 | if sampling_rate != self.sampling_rate: 155 | raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR") 156 | self._cache_ref_count = self.n_cache_reuse 157 | else: 158 | audio = self.cached_wav 159 | self._cache_ref_count -= 1 160 | 161 | audio = torch.FloatTensor(audio) 162 | audio = audio.unsqueeze(0) 163 | 164 | if not self.fine_tuning: 165 | if self.split: 166 | if audio.size(1) >= self.segment_size: 167 | max_audio_start = audio.size(1) - self.segment_size 168 | audio_start = random.randint(0, max_audio_start) 169 | audio = audio[:, audio_start : audio_start + self.segment_size] 170 | else: 171 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 172 | 173 | mel = mel_spectrogram( 174 | audio, 175 | self.n_fft, 176 | self.num_mels, 177 | self.sampling_rate, 178 | self.hop_size, 179 | self.win_size, 180 | self.fmin, 181 | self.fmax, 182 | center=False, 183 | ) 184 | else: 185 | mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy")) 186 | mel = 
torch.from_numpy(mel) 187 | 188 | if len(mel.shape) < 3: 189 | mel = mel.unsqueeze(0) 190 | 191 | if self.split: 192 | frames_per_seg = math.ceil(self.segment_size / self.hop_size) 193 | 194 | if audio.size(1) >= self.segment_size: 195 | mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1) 196 | mel = mel[:, :, mel_start : mel_start + frames_per_seg] 197 | audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size] 198 | else: 199 | mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant") 200 | audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant") 201 | 202 | mel_loss = mel_spectrogram( 203 | audio, 204 | self.n_fft, 205 | self.num_mels, 206 | self.sampling_rate, 207 | self.hop_size, 208 | self.win_size, 209 | self.fmin, 210 | self.fmax_loss, 211 | center=False, 212 | ) 213 | 214 | return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze()) 215 | 216 | def __len__(self): 217 | return len(self.audio_files) 218 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/xutils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import glob 4 | import os 5 | 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 17 | plt.colorbar(im, ax=ax) 18 | 19 | fig.canvas.draw() 20 | plt.close() 21 | 22 | return fig 23 | 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def apply_weight_norm(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Conv") != -1: 34 | weight_norm(m) 35 | 36 | 37 | def get_padding(kernel_size, dilation=1): 38 | return int((kernel_size * dilation - dilation) / 2) 39 | 40 | 41 | def load_checkpoint(filepath, device): 42 | assert os.path.isfile(filepath) 43 | print(f"Loading '{filepath}'") 44 | checkpoint_dict = torch.load(filepath, map_location=device) 45 | print("Complete.") 46 | return checkpoint_dict 47 | 48 | 49 | def save_checkpoint(filepath, obj): 50 | print(f"Saving checkpoint to {filepath}") 51 | torch.save(obj, filepath) 52 | print("Complete.") 53 | 54 | 55 | def scan_checkpoint(cp_dir, prefix): 56 | pattern = os.path.join(cp_dir, prefix + "????????") 57 | cp_list = glob.glob(pattern) 58 | if len(cp_list) == 0: 59 | return None 60 | return sorted(cp_list)[-1] 61 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/matcha/models/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/models/components/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/matcha/models/components/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/models/components/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from matcha.models.components.decoder import Decoder 7 | from matcha.utils.pylogger import get_pylogger 8 | 9 | log = get_pylogger(__name__) 10 | 11 | 12 | class BASECFM(torch.nn.Module, ABC): 13 | def __init__( 14 | self, 15 | n_feats, 16 | cfm_params, 17 | n_spks=1, 18 | spk_emb_dim=128, 19 | ): 20 | super().__init__() 21 | self.n_feats = n_feats 22 | self.n_spks = n_spks 23 | self.spk_emb_dim = spk_emb_dim 24 | self.solver = cfm_params.solver 25 | if hasattr(cfm_params, "sigma_min"): 26 | self.sigma_min = cfm_params.sigma_min 27 | else: 28 | self.sigma_min = 1e-4 29 | 30 | self.estimator = None 31 | 32 | @torch.inference_mode() 33 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 34 | """Forward diffusion 35 | 36 | Args: 37 | mu (torch.Tensor): output of encoder 38 | shape: (batch_size, n_feats, mel_timesteps) 39 | mask (torch.Tensor): output_mask 40 | shape: (batch_size, 1, mel_timesteps) 41 | n_timesteps (int): number of diffusion steps 42 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 43 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 44 | shape: (batch_size, spk_emb_dim) 45 | cond: Not used but kept for future purposes 46 | 47 | Returns: 48 | sample: generated mel-spectrogram 49 | shape: (batch_size, n_feats, mel_timesteps) 50 | """ 51 | z = torch.randn_like(mu) * temperature 52 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 53 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 54 | 55 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 56 | """ 57 | Fixed euler solver for ODEs. 58 | Args: 59 | x (torch.Tensor): random noise 60 | t_span (torch.Tensor): n_timesteps interpolated 61 | shape: (n_timesteps + 1,) 62 | mu (torch.Tensor): output of encoder 63 | shape: (batch_size, n_feats, mel_timesteps) 64 | mask (torch.Tensor): output_mask 65 | shape: (batch_size, 1, mel_timesteps) 66 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 67 | shape: (batch_size, spk_emb_dim) 68 | cond: Not used but kept for future purposes 69 | """ 70 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 71 | 72 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 73 | # Or in future might add like a return_all_steps flag 74 | sol = [] 75 | 76 | for step in range(1, len(t_span)): 77 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) 78 | 79 | x = x + dt * dphi_dt 80 | t = t + dt 81 | sol.append(x) 82 | if step < len(t_span) - 1: 83 | dt = t_span[step + 1] - t 84 | 85 | return sol[-1] 86 | 87 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 88 | """Computes diffusion loss 89 | 90 | Args: 91 | x1 (torch.Tensor): Target 92 | shape: (batch_size, n_feats, mel_timesteps) 93 | mask (torch.Tensor): target mask 94 | shape: (batch_size, 1, mel_timesteps) 95 | mu (torch.Tensor): output of encoder 96 | shape: (batch_size, n_feats, mel_timesteps) 97 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 
98 | shape: (batch_size, spk_emb_dim) 99 | 100 | Returns: 101 | loss: conditional flow matching loss 102 | y: conditional flow 103 | shape: (batch_size, n_feats, mel_timesteps) 104 | """ 105 | b, _, t = mu.shape 106 | 107 | # random timestep 108 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 109 | # sample noise p(x_0) 110 | z = torch.randn_like(x1) 111 | 112 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 113 | u = x1 - (1 - self.sigma_min) * z 114 | 115 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( 116 | torch.sum(mask) * u.shape[1] 117 | ) 118 | return loss, y 119 | 120 | 121 | class CFM(BASECFM): 122 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 123 | super().__init__( 124 | n_feats=in_channels, 125 | cfm_params=cfm_params, 126 | n_spks=n_spks, 127 | spk_emb_dim=spk_emb_dim, 128 | ) 129 | 130 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 131 | # Just change the architecture of the estimator here 132 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 133 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/onnx/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/matcha/onnx/__init__.py -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/onnx/export.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import torch 7 | from lightning import LightningModule 8 | 9 | from matcha.cli import VOCODER_URLS, load_matcha, load_vocoder 10 | 11 | DEFAULT_OPSET = 15 12 | 13 | SEED = 1234 14 | random.seed(SEED) 15 | np.random.seed(SEED) 16 | torch.manual_seed(SEED) 17 | torch.cuda.manual_seed(SEED) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | class MatchaWithVocoder(LightningModule): 23 | def __init__(self, matcha, vocoder): 24 | super().__init__() 25 | self.matcha = matcha 26 | self.vocoder = vocoder 27 | 28 | def forward(self, x, x_lengths, scales, spks=None): 29 | mel, mel_lengths = self.matcha(x, x_lengths, scales, spks) 30 | wavs = self.vocoder(mel).clamp(-1, 1) 31 | lengths = mel_lengths * 256 32 | return wavs.squeeze(1), lengths 33 | 34 | 35 | def get_exportable_module(matcha, vocoder, n_timesteps): 36 | """ 37 | Return an appropriate `LighteningModule` and output-node names 38 | based on whether the vocoder is embedded in the final graph 39 | """ 40 | 41 | def onnx_forward_func(x, x_lengths, scales, spks=None): 42 | """ 43 | Custom forward function for accepting 44 | scaler parameters as tensors 45 | """ 46 | # Extract scaler parameters from tensors 47 | temperature = scales[0] 48 | length_scale = scales[1] 49 | output = matcha.synthesise(x, x_lengths, n_timesteps, temperature, spks, length_scale) 50 | return output["mel"], output["mel_lengths"] 51 | 52 | # Monkey-patch Matcha's forward function 53 | matcha.forward = onnx_forward_func 54 | 55 | if vocoder is None: 56 | model, output_names = matcha, ["mel", "mel_lengths"] 57 | else: 58 | model = MatchaWithVocoder(matcha, vocoder) 59 | output_names = ["wav", "wav_lengths"] 60 | return model, 
output_names 61 | 62 | 63 | def get_inputs(is_multi_speaker): 64 | """ 65 | Create dummy inputs for tracing 66 | """ 67 | dummy_input_length = 50 68 | x = torch.randint(low=0, high=20, size=(1, dummy_input_length), dtype=torch.long) 69 | x_lengths = torch.LongTensor([dummy_input_length]) 70 | 71 | # Scales 72 | temperature = 0.667 73 | length_scale = 1.0 74 | scales = torch.Tensor([temperature, length_scale]) 75 | 76 | model_inputs = [x, x_lengths, scales] 77 | input_names = [ 78 | "x", 79 | "x_lengths", 80 | "scales", 81 | ] 82 | 83 | if is_multi_speaker: 84 | spks = torch.LongTensor([1]) 85 | model_inputs.append(spks) 86 | input_names.append("spks") 87 | 88 | return tuple(model_inputs), input_names 89 | 90 | 91 | def main(): 92 | parser = argparse.ArgumentParser(description="Export 🍵 Matcha-TTS to ONNX") 93 | 94 | parser.add_argument( 95 | "checkpoint_path", 96 | type=str, 97 | help="Path to the model checkpoint", 98 | ) 99 | parser.add_argument("output", type=str, help="Path to output `.onnx` file") 100 | parser.add_argument( 101 | "--n-timesteps", type=int, default=5, help="Number of steps to use for reverse diffusion in decoder (default 5)" 102 | ) 103 | parser.add_argument( 104 | "--vocoder-name", 105 | type=str, 106 | choices=list(VOCODER_URLS.keys()), 107 | default=None, 108 | help="Name of the vocoder to embed in the ONNX graph", 109 | ) 110 | parser.add_argument( 111 | "--vocoder-checkpoint-path", 112 | type=str, 113 | default=None, 114 | help="Vocoder checkpoint to embed in the ONNX graph for an `e2e` like experience", 115 | ) 116 | parser.add_argument("--opset", type=int, default=DEFAULT_OPSET, help="ONNX opset version to use (default 15") 117 | 118 | args = parser.parse_args() 119 | 120 | print(f"[🍵] Loading Matcha checkpoint from {args.checkpoint_path}") 121 | print(f"Setting n_timesteps to {args.n_timesteps}") 122 | 123 | checkpoint_path = Path(args.checkpoint_path) 124 | matcha = load_matcha(checkpoint_path.stem, checkpoint_path, "cpu") 125 | 126 | if args.vocoder_name or args.vocoder_checkpoint_path: 127 | assert ( 128 | args.vocoder_name and args.vocoder_checkpoint_path 129 | ), "Both vocoder_name and vocoder-checkpoint are required when embedding the vocoder in the ONNX graph." 
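        # With both vocoder arguments supplied, the vocoder is loaded on CPU and wrapped
        # with Matcha via MatchaWithVocoder, so the exported ONNX graph returns waveforms
        # ("wav", "wav_lengths") instead of mel-spectrograms.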
130 | vocoder, _ = load_vocoder(args.vocoder_name, args.vocoder_checkpoint_path, "cpu") 131 | else: 132 | vocoder = None 133 | 134 | is_multi_speaker = matcha.n_spks > 1 135 | 136 | dummy_input, input_names = get_inputs(is_multi_speaker) 137 | model, output_names = get_exportable_module(matcha, vocoder, args.n_timesteps) 138 | 139 | # Set dynamic shape for inputs/outputs 140 | dynamic_axes = { 141 | "x": {0: "batch_size", 1: "time"}, 142 | "x_lengths": {0: "batch_size"}, 143 | } 144 | 145 | if vocoder is None: 146 | dynamic_axes.update( 147 | { 148 | "mel": {0: "batch_size", 2: "time"}, 149 | "mel_lengths": {0: "batch_size"}, 150 | } 151 | ) 152 | else: 153 | print("Embedding the vocoder in the ONNX graph") 154 | dynamic_axes.update( 155 | { 156 | "wav": {0: "batch_size", 1: "time"}, 157 | "wav_lengths": {0: "batch_size"}, 158 | } 159 | ) 160 | 161 | if is_multi_speaker: 162 | dynamic_axes["spks"] = {0: "batch_size"} 163 | 164 | # Create the output directory (if not exists) 165 | Path(args.output).parent.mkdir(parents=True, exist_ok=True) 166 | 167 | model.to_onnx( 168 | args.output, 169 | dummy_input, 170 | input_names=input_names, 171 | output_names=output_names, 172 | dynamic_axes=dynamic_axes, 173 | opset_version=args.opset, 174 | export_params=True, 175 | do_constant_folding=True, 176 | ) 177 | print(f"[🍵] ONNX model exported to {args.output}") 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/onnx/infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from pathlib import Path 5 | from time import perf_counter 6 | 7 | import numpy as np 8 | import onnxruntime as ort 9 | import soundfile as sf 10 | import torch 11 | 12 | from matcha.cli import plot_spectrogram_to_numpy, process_text 13 | 14 | 15 | def validate_args(args): 16 | assert ( 17 | args.text or args.file 18 | ), "Either text or file must be provided Matcha-T(ea)TTS need sometext to whisk the waveforms." 
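    # Temperature scales the initial noise of the flow-matching sampler, and the speaking
    # rate is passed to the exported model as its length scale (higher values = slower speech).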
19 | assert args.temperature >= 0, "Sampling temperature cannot be negative" 20 | assert args.speaking_rate >= 0, "Speaking rate must be greater than 0" 21 | return args 22 | 23 | 24 | def write_wavs(model, inputs, output_dir, external_vocoder=None): 25 | if external_vocoder is None: 26 | print("The provided model has the vocoder embedded in the graph.\nGenerating waveform directly") 27 | t0 = perf_counter() 28 | wavs, wav_lengths = model.run(None, inputs) 29 | infer_secs = perf_counter() - t0 30 | mel_infer_secs = vocoder_infer_secs = None 31 | else: 32 | print("[🍵] Generating mel using Matcha") 33 | mel_t0 = perf_counter() 34 | mels, mel_lengths = model.run(None, inputs) 35 | mel_infer_secs = perf_counter() - mel_t0 36 | print("Generating waveform from mel using external vocoder") 37 | vocoder_inputs = {external_vocoder.get_inputs()[0].name: mels} 38 | vocoder_t0 = perf_counter() 39 | wavs = external_vocoder.run(None, vocoder_inputs)[0] 40 | vocoder_infer_secs = perf_counter() - vocoder_t0 41 | wavs = wavs.squeeze(1) 42 | wav_lengths = mel_lengths * 256 43 | infer_secs = mel_infer_secs + vocoder_infer_secs 44 | 45 | output_dir = Path(output_dir) 46 | output_dir.mkdir(parents=True, exist_ok=True) 47 | for i, (wav, wav_length) in enumerate(zip(wavs, wav_lengths)): 48 | output_filename = output_dir.joinpath(f"output_{i + 1}.wav") 49 | audio = wav[:wav_length] 50 | print(f"Writing audio to {output_filename}") 51 | sf.write(output_filename, audio, 22050, "PCM_24") 52 | 53 | wav_secs = wav_lengths.sum() / 22050 54 | print(f"Inference seconds: {infer_secs}") 55 | print(f"Generated wav seconds: {wav_secs}") 56 | rtf = infer_secs / wav_secs 57 | if mel_infer_secs is not None: 58 | mel_rtf = mel_infer_secs / wav_secs 59 | print(f"Matcha RTF: {mel_rtf}") 60 | if vocoder_infer_secs is not None: 61 | vocoder_rtf = vocoder_infer_secs / wav_secs 62 | print(f"Vocoder RTF: {vocoder_rtf}") 63 | print(f"Overall RTF: {rtf}") 64 | 65 | 66 | def write_mels(model, inputs, output_dir): 67 | t0 = perf_counter() 68 | mels, mel_lengths = model.run(None, inputs) 69 | infer_secs = perf_counter() - t0 70 | 71 | output_dir = Path(output_dir) 72 | output_dir.mkdir(parents=True, exist_ok=True) 73 | for i, mel in enumerate(mels): 74 | output_stem = output_dir.joinpath(f"output_{i + 1}") 75 | plot_spectrogram_to_numpy(mel.squeeze(), output_stem.with_suffix(".png")) 76 | np.save(output_stem.with_suffix(".numpy"), mel) 77 | 78 | wav_secs = (mel_lengths * 256).sum() / 22050 79 | print(f"Inference seconds: {infer_secs}") 80 | print(f"Generated wav seconds: {wav_secs}") 81 | rtf = infer_secs / wav_secs 82 | print(f"RTF: {rtf}") 83 | 84 | 85 | def main(): 86 | parser = argparse.ArgumentParser( 87 | description=" 🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching" 88 | ) 89 | parser.add_argument( 90 | "model", 91 | type=str, 92 | help="ONNX model to use", 93 | ) 94 | parser.add_argument("--vocoder", type=str, default=None, help="Vocoder to use (defaults to None)") 95 | parser.add_argument("--text", type=str, default=None, help="Text to synthesize") 96 | parser.add_argument("--file", type=str, default=None, help="Text file to synthesize") 97 | parser.add_argument("--spk", type=int, default=None, help="Speaker ID") 98 | parser.add_argument( 99 | "--temperature", 100 | type=float, 101 | default=0.667, 102 | help="Variance of the x0 noise (default: 0.667)", 103 | ) 104 | parser.add_argument( 105 | "--speaking-rate", 106 | type=float, 107 | default=1.0, 108 | help="change the speaking rate, a higher value means 
slower speaking rate (default: 1.0)", 109 | ) 110 | parser.add_argument("--gpu", action="store_true", help="Use GPU for inference (default: CPU)") 111 | parser.add_argument( 112 | "--output-dir", 113 | type=str, 114 | default=os.getcwd(), 115 | help="Output folder to save results (default: current dir)", 116 | ) 117 | 118 | args = parser.parse_args() 119 | args = validate_args(args) 120 | 121 | if args.gpu: 122 | providers = ["CUDAExecutionProvider"] 123 | else: 124 | providers = ["CPUExecutionProvider"] 125 | model = ort.InferenceSession(args.model, providers=providers) 126 | 127 | model_inputs = model.get_inputs() 128 | model_outputs = list(model.get_outputs()) 129 | 130 | if args.text: 131 | text_lines = args.text.splitlines() 132 | else: 133 | with open(args.file, encoding="utf-8") as file: 134 | text_lines = file.read().splitlines() 135 | 136 | processed_lines = [process_text(0, line, "cpu") for line in text_lines] 137 | x = [line["x"].squeeze() for line in processed_lines] 138 | # Pad to a common length 139 | x = torch.nn.utils.rnn.pad_sequence(x, batch_first=True) 140 | x = x.detach().cpu().numpy() 141 | x_lengths = np.array([line["x_lengths"].item() for line in processed_lines], dtype=np.int64) 142 | inputs = { 143 | "x": x, 144 | "x_lengths": x_lengths, 145 | "scales": np.array([args.temperature, args.speaking_rate], dtype=np.float32), 146 | } 147 | is_multi_speaker = len(model_inputs) == 4 148 | if is_multi_speaker: 149 | if args.spk is None: 150 | args.spk = 0 151 | warn = "[!] Speaker ID not provided! Using speaker ID 0" 152 | warnings.warn(warn, UserWarning) 153 | inputs["spks"] = np.repeat(args.spk, x.shape[0]).astype(np.int64) 154 | 155 | has_vocoder_embedded = model_outputs[0].name == "wav" 156 | if has_vocoder_embedded: 157 | write_wavs(model, inputs, args.output_dir) 158 | elif args.vocoder: 159 | external_vocoder = ort.InferenceSession(args.vocoder, providers=providers) 160 | write_wavs(model, inputs, args.output_dir, external_vocoder=external_vocoder) 161 | else: 162 | warn = "[!] No vocoder is embedded in the graph and no external vocoder was provided. The mel output will be written as numpy arrays to `*.npy` files in the output directory" 163 | warnings.warn(warn, UserWarning) 164 | write_mels(model, inputs, args.output_dir) 165 | 166 | 167 | if __name__ == "__main__": 168 | main() 169 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from matcha.text import cleaners 3 | from matcha.text.symbols import symbols 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension 8 | 9 | 10 | def text_to_sequence(text, cleaner_names): 11 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
12 | Args: 13 | text: string to convert to a sequence 14 | cleaner_names: names of the cleaner functions to run the text through 15 | Returns: 16 | List of integers corresponding to the symbols in the text 17 | """ 18 | sequence = [] 19 | 20 | clean_text = _clean_text(text, cleaner_names) 21 | for symbol in clean_text: 22 | symbol_id = _symbol_to_id[symbol] 23 | sequence += [symbol_id] 24 | return sequence 25 | 26 | 27 | def cleaned_text_to_sequence(cleaned_text): 28 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 29 | Args: 30 | text: string to convert to a sequence 31 | Returns: 32 | List of integers corresponding to the symbols in the text 33 | """ 34 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 35 | return sequence 36 | 37 | 38 | def sequence_to_text(sequence): 39 | """Converts a sequence of IDs back to a string""" 40 | result = "" 41 | for symbol_id in sequence: 42 | s = _id_to_symbol[symbol_id] 43 | result += s 44 | return result 45 | 46 | 47 | def _clean_text(text, cleaner_names): 48 | for name in cleaner_names: 49 | cleaner = getattr(cleaners, name) 50 | if not cleaner: 51 | raise Exception("Unknown cleaner: %s" % name) 52 | text = cleaner(text) 53 | return text 54 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 12 | """ 13 | 14 | import logging 15 | import re 16 | 17 | import phonemizer 18 | import piper_phonemize 19 | from unidecode import unidecode 20 | 21 | # To avoid excessive logging we set the log level of the phonemizer package to Critical 22 | critical_logger = logging.getLogger("phonemizer") 23 | critical_logger.setLevel(logging.CRITICAL) 24 | 25 | # Intializing the phonemizer globally significantly reduces the speed 26 | # now the phonemizer is not initialising at every call 27 | # Might be less flexible, but it is much-much faster 28 | global_phonemizer = phonemizer.backend.EspeakBackend( 29 | language="en-us", 30 | preserve_punctuation=True, 31 | with_stress=True, 32 | language_switch="remove-flags", 33 | logger=critical_logger, 34 | ) 35 | 36 | 37 | # Regular expression matching whitespace: 38 | _whitespace_re = re.compile(r"\s+") 39 | 40 | # List of (regular expression, replacement) pairs for abbreviations: 41 | _abbreviations = [ 42 | (re.compile("\\b%s\\." 
% x[0], re.IGNORECASE), x[1]) 43 | for x in [ 44 | ("mrs", "misess"), 45 | ("mr", "mister"), 46 | ("dr", "doctor"), 47 | ("st", "saint"), 48 | ("co", "company"), 49 | ("jr", "junior"), 50 | ("maj", "major"), 51 | ("gen", "general"), 52 | ("drs", "doctors"), 53 | ("rev", "reverend"), 54 | ("lt", "lieutenant"), 55 | ("hon", "honorable"), 56 | ("sgt", "sergeant"), 57 | ("capt", "captain"), 58 | ("esq", "esquire"), 59 | ("ltd", "limited"), 60 | ("col", "colonel"), 61 | ("ft", "fort"), 62 | ] 63 | ] 64 | 65 | 66 | def expand_abbreviations(text): 67 | for regex, replacement in _abbreviations: 68 | text = re.sub(regex, replacement, text) 69 | return text 70 | 71 | 72 | def lowercase(text): 73 | return text.lower() 74 | 75 | 76 | def collapse_whitespace(text): 77 | return re.sub(_whitespace_re, " ", text) 78 | 79 | 80 | def convert_to_ascii(text): 81 | return unidecode(text) 82 | 83 | 84 | def basic_cleaners(text): 85 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 86 | text = lowercase(text) 87 | text = collapse_whitespace(text) 88 | return text 89 | 90 | 91 | def transliteration_cleaners(text): 92 | """Pipeline for non-English text that transliterates to ASCII.""" 93 | text = convert_to_ascii(text) 94 | text = lowercase(text) 95 | text = collapse_whitespace(text) 96 | return text 97 | 98 | 99 | def english_cleaners2(text): 100 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 101 | text = convert_to_ascii(text) 102 | text = lowercase(text) 103 | text = expand_abbreviations(text) 104 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] 105 | phonemes = collapse_whitespace(phonemes) 106 | return phonemes 107 | 108 | 109 | def english_cleaners_piper(text): 110 | """Pipeline for English text, including abbreviation expansion. 
+ punctuation + stress""" 111 | text = convert_to_ascii(text) 112 | text = lowercase(text) 113 | text = expand_abbreviations(text) 114 | phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) 115 | phonemes = collapse_whitespace(phonemes) 116 | return phonemes 117 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | import inflect 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}" 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return f"{dollars} {dollar_unit}" 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return f"{cents} {cent_unit}" 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 60 | else: 61 | return _inflect.number_to_words(num, andword="") 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r"\1 pounds", text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Defines the set of symbols used in text input to the model. 
4 | """ 5 | _pad = "_" 6 | _punctuation = ';:,.!?¡¿—…"«»“” ' 7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 8 | _letters_ipa = ( 9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | ) 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") 18 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 7 | from lightning.pytorch.loggers import Logger 8 | from omegaconf import DictConfig 9 | 10 | from matcha import utils 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | 31 | log = utils.get_pylogger(__name__) 32 | 33 | 34 | @utils.task_wrapper 35 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 36 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 37 | training. 38 | 39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 40 | failure. Useful for multiruns, saving info about the crash, etc. 41 | 42 | :param cfg: A DictConfig configuration composed by Hydra. 43 | :return: A tuple with metrics and dict with all instantiated objects. 
44 | """ 45 | # set seed for random number generators in pytorch, numpy and python.random 46 | if cfg.get("seed"): 47 | L.seed_everything(cfg.seed, workers=True) 48 | 49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access 50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 51 | 52 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access 53 | model: LightningModule = hydra.utils.instantiate(cfg.model) 54 | 55 | log.info("Instantiating callbacks...") 56 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 57 | 58 | log.info("Instantiating loggers...") 59 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 60 | 61 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access 62 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 63 | 64 | object_dict = { 65 | "cfg": cfg, 66 | "datamodule": datamodule, 67 | "model": model, 68 | "callbacks": callbacks, 69 | "logger": logger, 70 | "trainer": trainer, 71 | } 72 | 73 | if logger: 74 | log.info("Logging hyperparameters!") 75 | utils.log_hyperparameters(object_dict) 76 | 77 | if cfg.get("train"): 78 | log.info("Starting training!") 79 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 80 | 81 | train_metrics = trainer.callback_metrics 82 | 83 | if cfg.get("test"): 84 | log.info("Starting testing!") 85 | ckpt_path = trainer.checkpoint_callback.best_model_path 86 | if ckpt_path == "": 87 | log.warning("Best ckpt not found! Using current weights for testing...") 88 | ckpt_path = None 89 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 90 | log.info(f"Best ckpt path: {ckpt_path}") 91 | 92 | test_metrics = trainer.callback_metrics 93 | 94 | # merge train and test metrics 95 | metric_dict = {**train_metrics, **test_metrics} 96 | 97 | return metric_dict, object_dict 98 | 99 | 100 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 101 | def main(cfg: DictConfig) -> Optional[float]: 102 | """Main entry point for training. 103 | 104 | :param cfg: DictConfig configuration composed by Hydra. 105 | :return: Optional[float] with optimized metric value. 106 | """ 107 | # apply extra utilities 108 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 
109 | utils.extras(cfg) 110 | 111 | # train the model 112 | metric_dict, _ = train(cfg) 113 | 114 | # safely retrieve metric value for hydra-based hyperparameter optimization 115 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 116 | 117 | # return optimized metric 118 | return metric_value 119 | 120 | 121 | if __name__ == "__main__": 122 | main() # pylint: disable=no-value-for-parameter 123 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from matcha.utils.logging_utils import log_hyperparameters 3 | from matcha.utils.pylogger import get_pylogger 4 | from matcha.utils.rich_utils import enforce_tags, print_config_tree 5 | from matcha.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 46 | if torch.min(y) < -1.0: 47 | print("min value is ", torch.min(y)) 48 | if torch.max(y) > 1.0: 49 | print("max value is ", torch.max(y)) 50 | 51 | global mel_basis, hann_window # pylint: disable=global-statement 52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 56 | 57 | y = torch.nn.functional.pad( 58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.view_as_real( 63 | torch.stft( 64 | y, 65 | n_fft, 66 | hop_length=hop_size, 67 | win_length=win_size, 68 | window=hann_window[str(y.device)], 69 | center=center, 70 | pad_mode="reflect", 71 | normalized=False, 72 | onesided=True, 73 | return_complex=True, 74 | ) 75 | ) 76 | 77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 78 | 79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 80 | spec = 
spectral_normalize_torch(spec) 81 | 82 | return spec 83 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it 3 | when needed. 4 | 5 | Parameters from hparam.py will be used 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | from pathlib import Path 12 | 13 | import rootutils 14 | import torch 15 | from hydra import compose, initialize 16 | from omegaconf import open_dict 17 | from tqdm.auto import tqdm 18 | 19 | from matcha.data.text_mel_datamodule import TextMelDataModule 20 | from matcha.utils.logging_utils import pylogger 21 | 22 | log = pylogger.get_pylogger(__name__) 23 | 24 | 25 | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): 26 | """Generate data mean and standard deviation helpful in data normalisation 27 | 28 | Args: 29 | data_loader (torch.utils.data.Dataloader): _description_ 30 | out_channels (int): mel spectrogram channels 31 | """ 32 | total_mel_sum = 0 33 | total_mel_sq_sum = 0 34 | total_mel_len = 0 35 | 36 | for batch in tqdm(data_loader, leave=False): 37 | mels = batch["y"] 38 | mel_lengths = batch["y_lengths"] 39 | 40 | total_mel_len += torch.sum(mel_lengths) 41 | total_mel_sum += torch.sum(mels) 42 | total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) 43 | 44 | data_mean = total_mel_sum / (total_mel_len * out_channels) 45 | data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) 46 | 47 | return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | 53 | parser.add_argument( 54 | "-i", 55 | "--input-config", 56 | type=str, 57 | default="vctk.yaml", 58 | help="The name of the yaml config file under configs/data", 59 | ) 60 | 61 | parser.add_argument( 62 | "-b", 63 | "--batch-size", 64 | type=int, 65 | default="256", 66 | help="Can have increased batch size for faster computation", 67 | ) 68 | 69 | parser.add_argument( 70 | "-f", 71 | "--force", 72 | action="store_true", 73 | default=False, 74 | required=False, 75 | help="force overwrite the file", 76 | ) 77 | args = parser.parse_args() 78 | output_file = Path(args.input_config).with_suffix(".json") 79 | 80 | if os.path.exists(output_file) and not args.force: 81 | print("File already exists. Use -f to force overwrite") 82 | sys.exit(1) 83 | 84 | with initialize(version_base="1.3", config_path="../../configs/data"): 85 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 86 | 87 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 88 | 89 | with open_dict(cfg): 90 | del cfg["hydra"] 91 | del cfg["_target_"] 92 | cfg["data_statistics"] = None 93 | cfg["seed"] = 1234 94 | cfg["batch_size"] = args.batch_size 95 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 96 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 97 | 98 | text_mel_datamodule = TextMelDataModule(**cfg) 99 | text_mel_datamodule.setup() 100 | data_loader = text_mel_datamodule.train_dataloader() 101 | log.info("Dataloader loaded! 
Now computing stats...") 102 | params = compute_data_statistics(data_loader, cfg["n_feats"]) 103 | print(params) 104 | json.dump( 105 | params, 106 | open(output_file, "w"), 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from matcha.utils import pylogger 9 | 10 | log = pylogger.get_pylogger(__name__) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from matcha.utils import pylogger 7 | 8 | log = pylogger.get_pylogger(__name__) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! 
Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/model.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def sequence_mask(length, max_length=None): 8 | if max_length is None: 9 | max_length = length.max() 10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 11 | return x.unsqueeze(0) < length.unsqueeze(1) 12 | 13 | 14 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 15 | factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) 16 | length = (length / factor).ceil() * factor 17 | if not torch.onnx.is_in_onnx_export(): 18 | return length.int().item() 19 | else: 20 | return length 21 | 22 | 23 | def convert_pad_shape(pad_shape): 24 | inverted_shape = pad_shape[::-1] 25 | pad_shape = [item for sublist in inverted_shape for item in sublist] 26 | return pad_shape 27 | 28 | 29 | def generate_path(duration, mask): 30 | device = duration.device 31 | 32 | b, t_x, t_y = mask.shape 33 | cum_duration = torch.cumsum(duration, 1) 34 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 35 | 36 | cum_duration_flat = cum_duration.view(b * t_x) 37 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 38 | path = path.view(b, t_x, t_y) 39 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 40 | path = path * mask 41 | return path 42 | 43 | 44 | def duration_loss(logw, logw_, lengths): 45 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 46 | return loss 47 | 48 | 49 | def normalize(data, mu, std): 50 | if not isinstance(mu, (float, int)): 51 | if isinstance(mu, list): 52 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 53 | elif isinstance(mu, torch.Tensor): 54 | mu = mu.to(data.device) 55 | elif isinstance(mu, np.ndarray): 56 | mu = torch.from_numpy(mu).to(data.device) 57 | mu = mu.unsqueeze(-1) 58 | 59 | if not isinstance(std, (float, int)): 60 | if isinstance(std, list): 61 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 62 | elif isinstance(std, torch.Tensor): 63 | std = std.to(data.device) 64 | elif isinstance(std, np.ndarray): 65 | std = torch.from_numpy(std).to(data.device) 66 | std = std.unsqueeze(-1) 67 | 68 | return (data - mu) / std 69 | 70 | 71 | def denormalize(data, mu, std): 72 | if not isinstance(mu, float): 73 | if isinstance(mu, list): 74 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 75 | elif isinstance(mu, 
torch.Tensor): 76 | mu = mu.to(data.device) 77 | elif isinstance(mu, np.ndarray): 78 | mu = torch.from_numpy(mu).to(data.device) 79 | mu = mu.unsqueeze(-1) 80 | 81 | if not isinstance(std, float): 82 | if isinstance(std, list): 83 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 84 | elif isinstance(std, torch.Tensor): 85 | std = std.to(data.device) 86 | elif isinstance(std, np.ndarray): 87 | std = torch.from_numpy(std).to(data.device) 88 | std = std.unsqueeze(-1) 89 | 90 | return data * std + mu 91 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from matcha.utils.monotonic_align.core import maximum_path_c 5 | 6 | 7 | def maximum_path(value, mask): 8 | """Cython optimised version. 9 | value: [b, t_x, t_y] 10 | mask: [b, t_x, t_y] 11 | """ 12 | value = value * mask 13 | device = value.device 14 | dtype = value.dtype 15 | value = value.data.cpu().numpy().astype(np.float32) 16 | path = np.zeros_like(value).astype(np.int32) 17 | mask = mask.data.cpu().numpy() 18 | 19 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 20 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 21 | maximum_path_c(path, value, t_x_max, t_y_max) 22 | return torch.from_numpy(path).to(device=device, dtype=dtype) 23 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cimport cython 4 | cimport numpy as np 5 | 6 | from cython.parallel import prange 7 | 8 | 9 | @cython.boundscheck(False) 10 | @cython.wraparound(False) 11 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 12 | cdef int x 13 | cdef int y 14 | cdef float v_prev 15 | cdef float v_cur 16 | cdef float tmp 17 | cdef int index = t_x - 1 18 | 19 | for y in range(t_y): 20 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 21 | if x == y: 22 | v_cur = max_neg_val 23 | else: 24 | v_cur = value[x, y-1] 25 | if x == 0: 26 | if y == 0: 27 | v_prev = 0. 
28 | else: 29 | v_prev = max_neg_val 30 | else: 31 | v_prev = value[x-1, y-1] 32 | value[x, y] = max(v_cur, v_prev) + value[x, y] 33 | 34 | for y in range(t_y - 1, -1, -1): 35 | path[index, y] = 1 36 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 37 | index = index - 1 38 | 39 | 40 | @cython.boundscheck(False) 41 | @cython.wraparound(False) 42 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 43 | cdef int b = values.shape[0] 44 | 45 | cdef int i 46 | for i in prange(b, nogil=True): 47 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 48 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | # from distutils.core import setup 2 | # from Cython.Build import cythonize 3 | # import numpy 4 | 5 | # setup(name='monotonic_align', 6 | # ext_modules=cythonize("core.pyx"), 7 | # include_dirs=[numpy.get_include()]) 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name: str = __name__) -> logging.Logger: 7 | """Initializes a multi-GPU-friendly python command line logger. 8 | 9 | :param name: The name of the logger, defaults to ``__name__``. 10 | 11 | :return: A logger object. 12 | """ 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from matcha.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. 
Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | _ = ( 48 | queue.append(field) 49 | if field in cfg 50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") 51 | ) 52 | 53 | # add all the other fields to queue (not specified in `print_order`) 54 | for field in cfg: 55 | if field not in queue: 56 | queue.append(field) 57 | 58 | # generate config tree from queue 59 | for field in queue: 60 | branch = tree.add(field, style=style, guide_style=style) 61 | 62 | config_group = cfg[field] 63 | if isinstance(config_group, DictConfig): 64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 65 | else: 66 | branch_content = str(config_group) 67 | 68 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 69 | 70 | # print config tree 71 | rich.print(tree) 72 | 73 | # save config tree to file 74 | if save_to_file: 75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 76 | rich.print(tree, file=file) 77 | 78 | 79 | @rank_zero_only 80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 81 | """Prompts user to input tags from command line if no tags are provided in config. 82 | 83 | :param cfg: A DictConfig composed by Hydra. 84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 85 | """ 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 101 | rich.print(cfg.tags, file=file) 102 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/notebooks/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/Matcha-TTS/notebooks/.gitkeep -------------------------------------------------------------------------------- /third_party/Matcha-TTS/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"] 3 | 4 | [tool.black] 5 | line-length = 120 6 | target-version = ['py310'] 7 | exclude = ''' 8 | 9 | ( 10 | /( 11 | \.eggs # exclude a few common directories in the 12 | | \.git # root of the project 13 | | \.hg 14 | | \.mypy_cache 15 | | \.tox 16 | | \.venv 17 | | _build 18 | | buck-out 19 | | build 20 | | dist 21 | )/ 22 | | foo.py # also separately exclude a file named foo.py in 23 | # the root of the project 24 | ) 25 | ''' 26 | 27 | [tool.pytest.ini_options] 28 | addopts = [ 29 | "--color=yes", 30 | "--durations=0", 31 | "--strict-markers", 32 | "--doctest-modules", 33 | ] 34 | filterwarnings = [ 35 | "ignore::DeprecationWarning", 36 | "ignore::UserWarning", 37 | ] 38 | log_cli = "True" 39 | markers = [ 40 | "slow: slow tests", 41 | ] 42 | minversion = "6.0" 43 
| testpaths = "tests/" 44 | 45 | [tool.coverage.report] 46 | exclude_lines = [ 47 | "pragma: nocover", 48 | "raise NotImplementedError", 49 | "raise NotImplementedError()", 50 | "if __name__ == .__main__.:", 51 | ] 52 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | torchvision>=0.15.0 4 | lightning>=2.0.0 5 | torchmetrics>=0.11.4 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.3.2 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 -------------------------------------------------------------------------------- /third_party/Matcha-TTS/scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | import numpy 5 | from Cython.Build import cythonize 6 | from setuptools import Extension, find_packages, setup 7 | 8 | exts = [ 9 | Extension( 10 | name="matcha.utils.monotonic_align.core", 11 | sources=["matcha/utils/monotonic_align/core.pyx"], 12 | ) 13 | ] 14 | 15 | with open("README.md", encoding="utf-8") as readme_file: 16 | README = readme_file.read() 17 | 18 | cwd = os.path.dirname(os.path.abspath(__file__)) 19 | with open(os.path.join(cwd, "matcha", "VERSION")) as fin: 20 | version = fin.read().strip() 21 | 22 | setup( 23 | name="matcha-tts", 24 | version=version, 25 | description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching", 26 | long_description=README, 27 | long_description_content_type="text/markdown", 28 | include_dirs=[numpy.get_include()], 29 | include_package_data=True, 30 | packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), 31 | # use this to customize global commands available in the terminal after installing the package 32 | entry_points={ 33 | "console_scripts": [ 34 | "matcha-data-stats=matcha.utils.generate_data_statistics:main", 35 | "matcha-tts=matcha.cli:cli", 36 | "matcha-tts-app=matcha.app:main", 37 | ] 38 | }, 39 | ext_modules=cythonize(exts, language_level=3), 40 | python_requires=">=3.9.0", 41 | ) 42 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/muxueChen/ComfyUI_NTCosyVoice/f5e55835df4e0038c9b87e23cbca6fbe033c0915/third_party/__init__.py --------------------------------------------------------------------------------