├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── lint.yml │ └── stale-issues.yml ├── .gitignore ├── .gitmodules ├── CODE_OF_CONDUCT.md ├── FAQ.md ├── LICENSE ├── README.md ├── asset ├── cross_lingual_prompt.wav ├── dingding.png └── zero_shot_prompt.wav ├── cosyvoice ├── __init__.py ├── bin │ ├── average_model.py │ ├── export_jit.py │ ├── export_onnx.py │ ├── inference.py │ ├── train.py │ └── train_dpo.py ├── cli │ ├── __init__.py │ ├── cosyvoice.py │ ├── frontend.py │ └── model.py ├── dataset │ ├── __init__.py │ ├── dataset.py │ ├── processor.py │ └── processor_dpo.py ├── flow │ ├── decoder.py │ ├── flow.py │ ├── flow_matching.py │ └── length_regulator.py ├── hifigan │ ├── discriminator.py │ ├── f0_predictor.py │ ├── generator.py │ └── hifigan.py ├── llm │ ├── llm.py │ └── llm_dpo.py ├── tokenizer │ ├── assets │ │ └── multilingual_zh_ja_yue_char_del.tiktoken │ └── tokenizer.py ├── transformer │ ├── __init__.py │ ├── activation.py │ ├── attention.py │ ├── convolution.py │ ├── decoder.py │ ├── decoder_layer.py │ ├── embedding.py │ ├── encoder.py │ ├── encoder_layer.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ ├── subsampling.py │ └── upsample_encoder.py ├── utils │ ├── __init__.py │ ├── class_utils.py │ ├── common.py │ ├── executor.py │ ├── executor_dpo.py │ ├── file_utils.py │ ├── frontend_utils.py │ ├── losses.py │ ├── losses_dpo.py │ ├── mask.py │ ├── scheduler.py │ ├── train_utils.py │ └── train_utils_dpo.py └── vllm │ └── cosyvoice2.py ├── docker └── Dockerfile ├── examples ├── libritts │ ├── cosyvoice │ │ ├── conf │ │ │ ├── cosyvoice.fromscratch.yaml │ │ │ ├── cosyvoice.yaml │ │ │ ├── cosyvoice_dpo.yaml │ │ │ └── ds_stage2.json │ │ ├── cosyvoice │ │ ├── local │ │ │ ├── download_and_untar.sh │ │ │ └── prepare_data.py │ │ ├── path.sh │ │ ├── run.sh │ │ ├── tools │ │ └── tts_text.json │ └── cosyvoice2 │ │ ├── conf │ │ ├── cosyvoice2.yaml │ │ └── ds_stage2.json │ │ ├── cosyvoice │ │ ├── local │ │ ├── path.sh │ │ ├── run.sh │ │ ├── tools │ │ └── tts_text.json └── magicdata-read │ └── cosyvoice │ ├── conf │ ├── cosyvoice │ ├── local │ ├── download_and_untar.sh │ └── prepare_data.py │ ├── path.sh │ ├── run.sh │ ├── tools │ └── tts_text.json ├── requirements.txt ├── runtime └── python │ ├── Dockerfile │ ├── fastapi │ ├── client.py │ └── server.py │ └── grpc │ ├── client.py │ ├── cosyvoice.proto │ └── server.py ├── tools ├── extract_embedding.py ├── extract_speech_token.py ├── make_parquet_list.py └── make_parquet_list_dpo.py ├── vllm_example.py └── webui.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. 
iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | 7 | jobs: 8 | quick-checks: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Fetch CosyVoice 12 | uses: actions/checkout@v1 13 | - name: Checkout PR tip 14 | run: | 15 | set -eux 16 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then 17 | # We are on a PR, so actions/checkout leaves us on a merge commit. 18 | # Check out the actual tip of the branch. 19 | git checkout ${{ github.event.pull_request.head.sha }} 20 | fi 21 | echo ::set-output name=commit_sha::$(git rev-parse HEAD) 22 | id: get_pr_tip 23 | - name: Ensure no tabs 24 | run: | 25 | (! git grep -I -l $'\t' -- . ':(exclude)*.txt' ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false)) 26 | - name: Ensure no trailing whitespace 27 | run: | 28 | (! git grep -I -n $' $' -- . ':(exclude)*.txt' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false)) 29 | 30 | flake8-py3: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - name: Setup Python 34 | uses: actions/setup-python@v1 35 | with: 36 | python-version: 3.9 37 | architecture: x64 38 | - name: Fetch CosyVoice 39 | uses: actions/checkout@v1 40 | - name: Checkout PR tip 41 | run: | 42 | set -eux 43 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then 44 | # We are on a PR, so actions/checkout leaves us on a merge commit. 45 | # Check out the actual tip of the branch. 46 | git checkout ${{ github.event.pull_request.head.sha }} 47 | fi 48 | echo ::set-output name=commit_sha::$(git rev-parse HEAD) 49 | id: get_pr_tip 50 | - name: Run flake8 51 | run: | 52 | set -eux 53 | pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0 54 | flake8 --version 55 | flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py 56 | if [ $? 
!= 0 ]; then exit 1; fi -------------------------------------------------------------------------------- /.github/workflows/stale-issues.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v5 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 14 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 19 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Visual Studio Code files 7 | .vscode 8 | .vs 9 | 10 | # PyCharm files 11 | .idea 12 | 13 | # Eclipse Project settings 14 | *.*project 15 | .settings 16 | 17 | # Sublime Text settings 18 | *.sublime-workspace 19 | *.sublime-project 20 | 21 | # Editor temporaries 22 | *.swn 23 | *.swo 24 | *.swp 25 | *.swm 26 | *~ 27 | 28 | # IPython notebook checkpoints 29 | .ipynb_checkpoints 30 | 31 | # macOS dir files 32 | .DS_Store 33 | 34 | exp 35 | data 36 | raw_wav 37 | tensorboard 38 | **/*build* 39 | 40 | # Clangd files 41 | .cache 42 | compile_commands.json 43 | 44 | # train/inference files 45 | *.wav 46 | *.m4a 47 | *.aac 48 | *.pt 49 | pretrained_models/* 50 | *_pb2_grpc.py 51 | *_pb2.py 52 | *.tar -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/Matcha-TTS"] 2 | path = third_party/Matcha-TTS 3 | url = https://github.com/shivammehta25/Matcha-TTS.git 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at mikelei@mobvoi.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | ## ModuleNotFoundError: No module named 'matcha' 2 | 3 | Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`. 
4 | 5 | run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script. 6 | 7 | ## cannot find resource.zip or cannot unzip resource.zip 8 | 9 | Please make sure you have git-lfs installed. Execute 10 | 11 | ```sh 12 | git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd 13 | cd pretrained_models/CosyVoice-ttsfrd/ 14 | unzip resource.zip -d . 15 | pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl 16 | ``` 17 | -------------------------------------------------------------------------------- /asset/cross_lingual_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/asset/cross_lingual_prompt.wav -------------------------------------------------------------------------------- /asset/dingding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/asset/dingding.png -------------------------------------------------------------------------------- /asset/zero_shot_prompt.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/asset/zero_shot_prompt.wav -------------------------------------------------------------------------------- /cosyvoice/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/__init__.py -------------------------------------------------------------------------------- /cosyvoice/bin/average_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Di Wu) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | import argparse 18 | import glob 19 | 20 | import yaml 21 | import torch 22 | 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser(description='average model') 26 | parser.add_argument('--dst_model', required=True, help='averaged model') 27 | parser.add_argument('--src_path', 28 | required=True, 29 | help='src model path for average') 30 | parser.add_argument('--val_best', 31 | action="store_true", 32 | help='averaged model') 33 | parser.add_argument('--num', 34 | default=5, 35 | type=int, 36 | help='nums for averaged model') 37 | 38 | args = parser.parse_args() 39 | print(args) 40 | return args 41 | 42 | 43 | def main(): 44 | args = get_args() 45 | val_scores = [] 46 | if args.val_best: 47 | yamls = glob.glob('{}/*.yaml'.format(args.src_path)) 48 | yamls = [ 49 | f for f in yamls 50 | if not (os.path.basename(f).startswith('train') 51 | or os.path.basename(f).startswith('init')) 52 | ] 53 | for y in yamls: 54 | with open(y, 'r') as f: 55 | dic_yaml = yaml.load(f, Loader=yaml.BaseLoader) 56 | loss = float(dic_yaml['loss_dict']['loss']) 57 | epoch = int(dic_yaml['epoch']) 58 | step = int(dic_yaml['step']) 59 | tag = dic_yaml['tag'] 60 | val_scores += [[epoch, step, loss, tag]] 61 | sorted_val_scores = sorted(val_scores, 62 | key=lambda x: x[2], 63 | reverse=False) 64 | print("best val (epoch, step, loss, tag) = " + 65 | str(sorted_val_scores[:args.num])) 66 | path_list = [ 67 | args.src_path + '/epoch_{}_whole.pt'.format(score[0]) 68 | for score in sorted_val_scores[:args.num] 69 | ] 70 | print(path_list) 71 | avg = {} 72 | num = args.num 73 | assert num == len(path_list) 74 | for path in path_list: 75 | print('Processing {}'.format(path)) 76 | states = torch.load(path, map_location=torch.device('cpu')) 77 | for k in states.keys(): 78 | if k not in ['step', 'epoch']: 79 | if k not in avg.keys(): 80 | avg[k] = states[k].clone() 81 | else: 82 | avg[k] += states[k] 83 | # average 84 | for k in avg.keys(): 85 | if avg[k] is not None: 86 | # pytorch 1.6 use true_divide instead of /= 87 | avg[k] = torch.true_divide(avg[k], num) 88 | print('Saving to {}'.format(args.dst_model)) 89 | torch.save(avg, args.dst_model) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import sys 22 | import torch 23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.append('{}/../..'.format(ROOT_DIR)) 25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 26 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 27 | from cosyvoice.utils.file_utils import logging 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='export your model for deployment') 32 | parser.add_argument('--model_dir', 33 | type=str, 34 | default='pretrained_models/CosyVoice-300M', 35 | help='local path') 36 | args = parser.parse_args() 37 | print(args) 38 | return args 39 | 40 | 41 | def get_optimized_script(model, preserved_attrs=[]): 42 | script = torch.jit.script(model) 43 | if preserved_attrs != []: 44 | script = torch.jit.freeze(script, preserved_attrs=preserved_attrs) 45 | else: 46 | script = torch.jit.freeze(script) 47 | script = torch.jit.optimize_for_inference(script) 48 | return script 49 | 50 | 51 | def main(): 52 | args = get_args() 53 | logging.basicConfig(level=logging.DEBUG, 54 | format='%(asctime)s %(levelname)s %(message)s') 55 | 56 | torch._C._jit_set_fusion_strategy([('STATIC', 1)]) 57 | torch._C._jit_set_profiling_mode(False) 58 | torch._C._jit_set_profiling_executor(False) 59 | 60 | try: 61 | model = CosyVoice(args.model_dir) 62 | except Exception: 63 | try: 64 | model = CosyVoice2(args.model_dir) 65 | except Exception: 66 | raise TypeError('no valid model_type!') 67 | 68 | if not isinstance(model, CosyVoice2): 69 | # 1. export llm text_encoder 70 | llm_text_encoder = model.model.llm.text_encoder 71 | script = get_optimized_script(llm_text_encoder) 72 | script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir)) 73 | script = get_optimized_script(llm_text_encoder.half()) 74 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) 75 | logging.info('successfully export llm_text_encoder') 76 | 77 | # 2. export llm llm 78 | llm_llm = model.model.llm.llm 79 | script = get_optimized_script(llm_llm, ['forward_chunk']) 80 | script.save('{}/llm.llm.fp32.zip'.format(args.model_dir)) 81 | script = get_optimized_script(llm_llm.half(), ['forward_chunk']) 82 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) 83 | logging.info('successfully export llm_llm') 84 | 85 | # 3. export flow encoder 86 | flow_encoder = model.model.flow.encoder 87 | script = get_optimized_script(flow_encoder) 88 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 89 | script = get_optimized_script(flow_encoder.half()) 90 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) 91 | logging.info('successfully export flow_encoder') 92 | else: 93 | # 3. 
export flow encoder 94 | flow_encoder = model.model.flow.encoder 95 | script = get_optimized_script(flow_encoder) 96 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 97 | script = get_optimized_script(flow_encoder.half()) 98 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) 99 | logging.info('successfully export flow_encoder') 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import print_function 17 | 18 | import argparse 19 | import logging 20 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 21 | import os 22 | import sys 23 | import onnxruntime 24 | import random 25 | import torch 26 | from tqdm import tqdm 27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | sys.path.append('{}/../..'.format(ROOT_DIR)) 29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 30 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 31 | from cosyvoice.utils.file_utils import logging 32 | 33 | 34 | def get_dummy_input(batch_size, seq_len, out_channels, device): 35 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 36 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) 37 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 38 | t = torch.rand((batch_size), dtype=torch.float32, device=device) 39 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) 40 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 41 | return x, mask, mu, t, spks, cond 42 | 43 | 44 | def get_args(): 45 | parser = argparse.ArgumentParser(description='export your model for deployment') 46 | parser.add_argument('--model_dir', 47 | type=str, 48 | default='pretrained_models/CosyVoice-300M', 49 | help='local path') 50 | args = parser.parse_args() 51 | print(args) 52 | return args 53 | 54 | 55 | @torch.no_grad() 56 | def main(): 57 | args = get_args() 58 | logging.basicConfig(level=logging.DEBUG, 59 | format='%(asctime)s %(levelname)s %(message)s') 60 | 61 | try: 62 | model = CosyVoice(args.model_dir) 63 | except Exception: 64 | try: 65 | model = CosyVoice2(args.model_dir) 66 | except Exception: 67 | raise TypeError('no valid model_type!') 68 | 69 | # 1. 
export flow decoder estimator 70 | estimator = model.model.flow.decoder.estimator 71 | estimator.eval() 72 | 73 | device = model.model.device 74 | batch_size, seq_len = 2, 256 75 | out_channels = model.model.flow.decoder.estimator.out_channels 76 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) 77 | torch.onnx.export( 78 | estimator, 79 | (x, mask, mu, t, spks, cond), 80 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 81 | export_params=True, 82 | opset_version=18, 83 | do_constant_folding=True, 84 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], 85 | output_names=['estimator_out'], 86 | dynamic_axes={ 87 | 'x': {2: 'seq_len'}, 88 | 'mask': {2: 'seq_len'}, 89 | 'mu': {2: 'seq_len'}, 90 | 'cond': {2: 'seq_len'}, 91 | 'estimator_out': {2: 'seq_len'}, 92 | } 93 | ) 94 | 95 | # 2. test computation consistency 96 | option = onnxruntime.SessionOptions() 97 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 98 | option.intra_op_num_threads = 1 99 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] 100 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 101 | sess_options=option, providers=providers) 102 | 103 | for _ in tqdm(range(10)): 104 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) 105 | output_pytorch = estimator(x, mask, mu, t, spks, cond) 106 | ort_inputs = { 107 | 'x': x.cpu().numpy(), 108 | 'mask': mask.cpu().numpy(), 109 | 'mu': mu.cpu().numpy(), 110 | 't': t.cpu().numpy(), 111 | 'spks': spks.cpu().numpy(), 112 | 'cond': cond.cpu().numpy() 113 | } 114 | output_onnx = estimator_onnx.run(None, ort_inputs)[0] 115 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) 116 | logging.info('successfully export estimator') 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /cosyvoice/bin/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import torch 22 | from torch.utils.data import DataLoader 23 | import torchaudio 24 | from hyperpyyaml import load_hyperpyyaml 25 | from tqdm import tqdm 26 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model 27 | from cosyvoice.dataset.dataset import Dataset 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='inference with your model') 32 | parser.add_argument('--config', required=True, help='config file') 33 | parser.add_argument('--prompt_data', required=True, help='prompt data file') 34 | parser.add_argument('--prompt_utt2data', required=True, help='prompt data file') 35 | parser.add_argument('--tts_text', required=True, help='tts input file') 36 | parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path') 37 | parser.add_argument('--llm_model', required=True, help='llm model file') 38 | parser.add_argument('--flow_model', required=True, help='flow model file') 39 | parser.add_argument('--hifigan_model', required=True, help='hifigan model file') 40 | parser.add_argument('--gpu', 41 | type=int, 42 | default=-1, 43 | help='gpu id for this rank, -1 for cpu') 44 | parser.add_argument('--mode', 45 | default='sft', 46 | choices=['sft', 'zero_shot'], 47 | help='inference mode') 48 | parser.add_argument('--result_dir', required=True, help='asr result file') 49 | args = parser.parse_args() 50 | print(args) 51 | return args 52 | 53 | 54 | def main(): 55 | args = get_args() 56 | logging.basicConfig(level=logging.DEBUG, 57 | format='%(asctime)s %(levelname)s %(message)s') 58 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 59 | 60 | # Init cosyvoice models from configs 61 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 62 | device = torch.device('cuda' if use_cuda else 'cpu') 63 | try: 64 | with open(args.config, 'r') as f: 65 | configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path}) 66 | model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift']) 67 | except Exception: 68 | try: 69 | with open(args.config, 'r') as f: 70 | configs = load_hyperpyyaml(f) 71 | model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) 72 | except Exception: 73 | raise TypeError('no valid model_type!') 74 | 75 | model.load(args.llm_model, args.flow_model, args.hifigan_model) 76 | 77 | test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, 78 | tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data) 79 | test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) 80 | 81 | sample_rate = configs['sample_rate'] 82 | del configs 83 | os.makedirs(args.result_dir, exist_ok=True) 84 | fn = os.path.join(args.result_dir, 'wav.scp') 85 | f = open(fn, 'w') 86 | with torch.no_grad(): 87 | for _, batch in tqdm(enumerate(test_data_loader)): 88 | utts = batch["utts"] 89 | assert len(utts) == 1, "inference mode only support batchsize 1" 90 | text_token = batch["text_token"].to(device) 91 | text_token_len = batch["text_token_len"].to(device) 92 | tts_index = batch["tts_index"] 93 | tts_text_token = batch["tts_text_token"].to(device) 94 | tts_text_token_len = batch["tts_text_token_len"].to(device) 95 | speech_token = batch["speech_token"].to(device) 96 | speech_token_len = batch["speech_token_len"].to(device) 97 | 
speech_feat = batch["speech_feat"].to(device) 98 | speech_feat_len = batch["speech_feat_len"].to(device) 99 | utt_embedding = batch["utt_embedding"].to(device) 100 | spk_embedding = batch["spk_embedding"].to(device) 101 | if args.mode == 'sft': 102 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 103 | 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding} 104 | else: 105 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 106 | 'prompt_text': text_token, 'prompt_text_len': text_token_len, 107 | 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len, 108 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len, 109 | 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len, 110 | 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding} 111 | tts_speeches = [] 112 | for model_output in model.tts(**model_input): 113 | tts_speeches.append(model_output['tts_speech']) 114 | tts_speeches = torch.concat(tts_speeches, dim=1) 115 | tts_key = '{}_{}'.format(utts[0], tts_index[0]) 116 | tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key)) 117 | torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile') 118 | f.write('{} {}\n'.format(tts_key, tts_fn)) 119 | f.flush() 120 | f.close() 121 | logging.info('Result wav.scp saved in {}'.format(fn)) 122 | 123 | 124 | if __name__ == '__main__': 125 | main() 126 | -------------------------------------------------------------------------------- /cosyvoice/bin/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import print_function 16 | import argparse 17 | import datetime 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | from copy import deepcopy 21 | import os 22 | import torch 23 | import torch.distributed as dist 24 | import deepspeed 25 | 26 | from hyperpyyaml import load_hyperpyyaml 27 | 28 | from torch.distributed.elastic.multiprocessing.errors import record 29 | 30 | from cosyvoice.utils.executor import Executor 31 | from cosyvoice.utils.train_utils import ( 32 | init_distributed, 33 | init_dataset_and_dataloader, 34 | init_optimizer_and_scheduler, 35 | init_summarywriter, save_model, 36 | wrap_cuda_model, check_modify_and_save_config) 37 | 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser(description='training your network') 41 | parser.add_argument('--train_engine', 42 | default='torch_ddp', 43 | choices=['torch_ddp', 'deepspeed'], 44 | help='Engine for paralleled training') 45 | parser.add_argument('--model', required=True, help='model which will be trained') 46 | parser.add_argument('--config', required=True, help='config file') 47 | parser.add_argument('--train_data', required=True, help='train data file') 48 | parser.add_argument('--cv_data', required=True, help='cv data file') 49 | parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path') 50 | parser.add_argument('--checkpoint', help='checkpoint model') 51 | parser.add_argument('--model_dir', required=True, help='save model dir') 52 | parser.add_argument('--tensorboard_dir', 53 | default='tensorboard', 54 | help='tensorboard log dir') 55 | parser.add_argument('--ddp.dist_backend', 56 | dest='dist_backend', 57 | default='nccl', 58 | choices=['nccl', 'gloo'], 59 | help='distributed backend') 60 | parser.add_argument('--num_workers', 61 | default=0, 62 | type=int, 63 | help='num of subprocess workers for reading') 64 | parser.add_argument('--prefetch', 65 | default=100, 66 | type=int, 67 | help='prefetch number') 68 | parser.add_argument('--pin_memory', 69 | action='store_true', 70 | default=False, 71 | help='Use pinned memory buffers used for reading') 72 | parser.add_argument('--use_amp', 73 | action='store_true', 74 | default=False, 75 | help='Use automatic mixed precision training') 76 | parser.add_argument('--deepspeed.save_states', 77 | dest='save_states', 78 | default='model_only', 79 | choices=['model_only', 'model+optimizer'], 80 | help='save model/optimizer states') 81 | parser.add_argument('--timeout', 82 | default=60, 83 | type=int, 84 | help='timeout (in seconds) of cosyvoice_join.') 85 | parser = deepspeed.add_config_arguments(parser) 86 | args = parser.parse_args() 87 | return args 88 | 89 | 90 | @record 91 | def main(): 92 | args = get_args() 93 | logging.basicConfig(level=logging.DEBUG, 94 | format='%(asctime)s %(levelname)s %(message)s') 95 | # gan train has some special initialization logic 96 | gan = True if args.model == 'hifigan' else False 97 | 98 | override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} 99 | if gan is True: 100 | override_dict.pop('hift') 101 | try: 102 | with open(args.config, 'r') as f: 103 | configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path}) 104 | except Exception: 105 | with open(args.config, 'r') as f: 106 | configs = load_hyperpyyaml(f, overrides=override_dict) 107 | if gan is True: 108 | configs['train_conf'] = configs['train_conf_gan'] 109 | configs['train_conf'].update(vars(args)) 110 | 111 | 
# Init env for ddp 112 | init_distributed(args) 113 | 114 | # Get dataset & dataloader 115 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ 116 | init_dataset_and_dataloader(args, configs, gan) 117 | 118 | # Do some sanity checks and save config to args.model_dir 119 | configs = check_modify_and_save_config(args, configs) 120 | 121 | # Tensorboard summary 122 | writer = init_summarywriter(args) 123 | 124 | # load checkpoint 125 | model = configs[args.model] 126 | start_step, start_epoch = 0, -1 127 | if args.checkpoint is not None: 128 | if os.path.exists(args.checkpoint): 129 | state_dict = torch.load(args.checkpoint, map_location='cpu') 130 | model.load_state_dict(state_dict, strict=False) 131 | if 'step' in state_dict: 132 | start_step = state_dict['step'] 133 | if 'epoch' in state_dict: 134 | start_epoch = state_dict['epoch'] 135 | else: 136 | logging.warning('checkpoint {} does not exist!'.format(args.checkpoint)) 137 | 138 | # Dispatch model from cpu to gpu 139 | model = wrap_cuda_model(args, model) 140 | 141 | # Get optimizer & scheduler 142 | model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan) 143 | scheduler.set_step(start_step) 144 | if scheduler_d is not None: 145 | scheduler_d.set_step(start_step) 146 | 147 | # Save init checkpoints 148 | info_dict = deepcopy(configs['train_conf']) 149 | info_dict['step'] = start_step 150 | info_dict['epoch'] = start_epoch 151 | save_model(model, 'init', info_dict) 152 | 153 | # Get executor 154 | executor = Executor(gan=gan) 155 | executor.step = start_step 156 | 157 | # Init scaler, used for pytorch amp mixed precision training 158 | scaler = torch.cuda.amp.GradScaler() if args.use_amp else None 159 | print('start step {} start epoch {}'.format(start_step, start_epoch)) 160 | # Start training loop 161 | for epoch in range(start_epoch + 1, info_dict['max_epoch']): 162 | executor.epoch = epoch 163 | train_dataset.set_epoch(epoch) 164 | dist.barrier() 165 | group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) 166 | if gan is True: 167 | executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, 168 | writer, info_dict, scaler, group_join) 169 | else: 170 | executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join) 171 | dist.destroy_process_group(group_join) 172 | 173 | 174 | if __name__ == '__main__': 175 | main() 176 | -------------------------------------------------------------------------------- /cosyvoice/bin/train_dpo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import print_function 16 | import argparse 17 | import datetime 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | from copy import deepcopy 21 | import os 22 | import torch 23 | import torch.distributed as dist 24 | import deepspeed 25 | 26 | from hyperpyyaml import load_hyperpyyaml 27 | 28 | from torch.distributed.elastic.multiprocessing.errors import record 29 | 30 | from cosyvoice.utils.executor_dpo import Executor 31 | from cosyvoice.utils.train_utils_dpo import ( 32 | init_distributed, 33 | init_dataset_and_dataloader, 34 | init_optimizer_and_scheduler, 35 | init_summarywriter, save_model, 36 | wrap_cuda_model, check_modify_and_save_config) 37 | 38 | 39 | def get_args(): 40 | parser = argparse.ArgumentParser(description='training your network') 41 | parser.add_argument('--train_engine', 42 | default='torch_ddp', 43 | choices=['torch_ddp', 'deepspeed'], 44 | help='Engine for paralleled training') 45 | parser.add_argument('--model', required=True, help='model which will be trained') 46 | parser.add_argument('--config', required=True, help='config file') 47 | parser.add_argument('--train_data', required=True, help='train data file') 48 | parser.add_argument('--cv_data', required=True, help='cv data file') 49 | parser.add_argument('--checkpoint', help='checkpoint model') 50 | parser.add_argument('--model_dir', required=True, help='save model dir') 51 | parser.add_argument('--tensorboard_dir', 52 | default='tensorboard', 53 | help='tensorboard log dir') 54 | parser.add_argument('--ddp.dist_backend', 55 | dest='dist_backend', 56 | default='nccl', 57 | choices=['nccl', 'gloo'], 58 | help='distributed backend') 59 | parser.add_argument('--num_workers', 60 | default=0, 61 | type=int, 62 | help='num of subprocess workers for reading') 63 | parser.add_argument('--prefetch', 64 | default=100, 65 | type=int, 66 | help='prefetch number') 67 | parser.add_argument('--pin_memory', 68 | action='store_true', 69 | default=False, 70 | help='Use pinned memory buffers used for reading') 71 | parser.add_argument('--use_amp', 72 | action='store_true', 73 | default=False, 74 | help='Use automatic mixed precision training') 75 | parser.add_argument('--deepspeed.save_states', 76 | dest='save_states', 77 | default='model_only', 78 | choices=['model_only', 'model+optimizer'], 79 | help='save model/optimizer states') 80 | parser.add_argument('--timeout', 81 | default=60, 82 | type=int, 83 | help='timeout (in seconds) of cosyvoice_join.') 84 | parser.add_argument('--dpo', 85 | action='store_true', 86 | default=False, 87 | help='Use Direct Preference Optimization') 88 | parser.add_argument('--beta', 89 | default=0.01, 90 | type=float, 91 | help='beta of dpo training') 92 | parser = deepspeed.add_config_arguments(parser) 93 | args = parser.parse_args() 94 | return args 95 | 96 | 97 | @record 98 | def main(): 99 | args = get_args() 100 | logging.basicConfig(level=logging.DEBUG, 101 | format='%(asctime)s %(levelname)s %(message)s') 102 | # gan train has some special initialization logic 103 | gan = True if args.model == 'hifigan' else False 104 | 105 | override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model} 106 | if gan is True: 107 | override_dict.pop('hift') 108 | with open(args.config, 'r') as f: 109 | configs = load_hyperpyyaml(f, overrides=override_dict) 110 | if gan is True: 111 | configs['train_conf'] = configs['train_conf_gan'] 112 | configs['train_conf'].update(vars(args)) 113 | 114 | # Init env for ddp 115 | 
init_distributed(args) 116 | 117 | # Get dataset & dataloader 118 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \ 119 | init_dataset_and_dataloader(args, configs, gan) 120 | 121 | # Do some sanity checks and save config to args.model_dir 122 | configs = check_modify_and_save_config(args, configs) 123 | 124 | # Tensorboard summary 125 | writer = init_summarywriter(args) 126 | 127 | # load checkpoint 128 | model = configs[args.model] 129 | ref_model = None 130 | if args.dpo: 131 | ref_model = deepcopy(model) 132 | start_step, start_epoch = 0, -1 133 | if args.checkpoint is not None: 134 | if os.path.exists(args.checkpoint): 135 | state_dict = torch.load(args.checkpoint, map_location='cpu') 136 | model.load_state_dict(state_dict, strict=False) 137 | if args.dpo: 138 | ref_model.load_state_dict(state_dict, strict=False) 139 | if 'step' in state_dict: 140 | start_step = state_dict['step'] 141 | if 'epoch' in state_dict: 142 | start_epoch = state_dict['epoch'] 143 | else: 144 | logging.warning('checkpoint {} does not exist!'.format(args.checkpoint)) 145 | 146 | # Dispatch model from cpu to gpu 147 | model = wrap_cuda_model(args, model) 148 | if args.dpo: 149 | ref_model = wrap_cuda_model(args, ref_model) 150 | 151 | # Get optimizer & scheduler 152 | model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan) 153 | if args.dpo: 154 | ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan) 155 | scheduler.set_step(start_step) 156 | if scheduler_d is not None: 157 | scheduler_d.set_step(start_step) 158 | 159 | # Save init checkpoints 160 | info_dict = deepcopy(configs['train_conf']) 161 | info_dict['step'] = start_step 162 | info_dict['epoch'] = start_epoch 163 | save_model(model, 'init', info_dict) 164 | 165 | # Get executor 166 | executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta) 167 | executor.step = start_step 168 | 169 | # Init scaler, used for pytorch amp mixed precision training 170 | scaler = torch.cuda.amp.GradScaler() if args.use_amp else None 171 | print('start step {} start epoch {}'.format(start_step, start_epoch)) 172 | # Start training loop 173 | for epoch in range(start_epoch + 1, info_dict['max_epoch']): 174 | executor.epoch = epoch 175 | train_dataset.set_epoch(epoch) 176 | dist.barrier() 177 | group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout)) 178 | if gan is True: 179 | executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, 180 | writer, info_dict, scaler, group_join) 181 | else: 182 | executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model) 183 | dist.destroy_process_group(group_join) 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | -------------------------------------------------------------------------------- /cosyvoice/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/cli/__init__.py -------------------------------------------------------------------------------- /cosyvoice/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/dataset/__init__.py 
-------------------------------------------------------------------------------- /cosyvoice/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import json 18 | import math 19 | from functools import partial 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch.utils.data import IterableDataset 24 | from cosyvoice.utils.file_utils import read_lists, read_json_lists 25 | 26 | 27 | class Processor(IterableDataset): 28 | 29 | def __init__(self, source, f, *args, **kw): 30 | assert callable(f) 31 | self.source = source 32 | self.f = f 33 | self.args = args 34 | self.kw = kw 35 | 36 | def set_epoch(self, epoch): 37 | self.source.set_epoch(epoch) 38 | 39 | def __iter__(self): 40 | """ Return an iterator over the source dataset processed by the 41 | given processor. 42 | """ 43 | assert self.source is not None 44 | assert callable(self.f) 45 | return self.f(iter(self.source), *self.args, **self.kw) 46 | 47 | def apply(self, f): 48 | assert callable(f) 49 | return Processor(self, f, *self.args, **self.kw) 50 | 51 | 52 | class DistributedSampler: 53 | 54 | def __init__(self, shuffle=True, partition=True): 55 | self.epoch = -1 56 | self.update() 57 | self.shuffle = shuffle 58 | self.partition = partition 59 | 60 | def update(self): 61 | assert dist.is_available() 62 | if dist.is_initialized(): 63 | self.rank = dist.get_rank() 64 | self.world_size = dist.get_world_size() 65 | else: 66 | self.rank = 0 67 | self.world_size = 1 68 | worker_info = torch.utils.data.get_worker_info() 69 | if worker_info is None: 70 | self.worker_id = 0 71 | self.num_workers = 1 72 | else: 73 | self.worker_id = worker_info.id 74 | self.num_workers = worker_info.num_workers 75 | return dict(rank=self.rank, 76 | world_size=self.world_size, 77 | worker_id=self.worker_id, 78 | num_workers=self.num_workers) 79 | 80 | def set_epoch(self, epoch): 81 | self.epoch = epoch 82 | 83 | def sample(self, data): 84 | """ Sample data according to rank/world_size/num_workers 85 | 86 | Args: 87 | data(List): input data list 88 | 89 | Returns: 90 | List: data list after sample 91 | """ 92 | data = list(range(len(data))) 93 | # force datalist even 94 | if self.partition: 95 | if self.shuffle: 96 | random.Random(self.epoch).shuffle(data) 97 | if len(data) < self.world_size: 98 | data = data * math.ceil(self.world_size / len(data)) 99 | data = data[:self.world_size] 100 | data = data[self.rank::self.world_size] 101 | if len(data) < self.num_workers: 102 | data = data * math.ceil(self.num_workers / len(data)) 103 | data = data[:self.num_workers] 104 | data = data[self.worker_id::self.num_workers] 105 | return data 106 | 107 | 108 | class DataList(IterableDataset): 109 | 110 | def __init__(self, lists, shuffle=True, partition=True): 111 | self.lists = lists 112 | 
self.sampler = DistributedSampler(shuffle, partition) 113 | 114 | def set_epoch(self, epoch): 115 | self.sampler.set_epoch(epoch) 116 | 117 | def __iter__(self): 118 | sampler_info = self.sampler.update() 119 | indexes = self.sampler.sample(self.lists) 120 | for index in indexes: 121 | data = dict(src=self.lists[index]) 122 | data.update(sampler_info) 123 | yield data 124 | 125 | 126 | def Dataset(data_list_file, 127 | data_pipeline, 128 | mode='train', 129 | gan=False, 130 | shuffle=True, 131 | partition=True, 132 | tts_file='', 133 | prompt_utt2data=''): 134 | """ Construct dataset from arguments 135 | 136 | We have two shuffle stage in the Dataset. The first is global 137 | shuffle at shards tar/raw file level. The second is global shuffle 138 | at training samples level. 139 | 140 | Args: 141 | data_type(str): raw/shard 142 | tokenizer (BaseTokenizer): tokenizer to tokenize 143 | partition(bool): whether to do data partition in terms of rank 144 | """ 145 | assert mode in ['train', 'inference'] 146 | lists = read_lists(data_list_file) 147 | if mode == 'inference': 148 | with open(tts_file) as f: 149 | tts_data = json.load(f) 150 | utt2lists = read_json_lists(prompt_utt2data) 151 | # filter unnecessary file in inference mode 152 | lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists}) 153 | dataset = DataList(lists, 154 | shuffle=shuffle, 155 | partition=partition) 156 | if mode == 'inference': 157 | # map partial arg to parquet_opener func in inference mode 158 | data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data) 159 | if gan is True: 160 | # map partial arg to padding func in gan mode 161 | data_pipeline[-1] = partial(data_pipeline[-1], gan=gan) 162 | for func in data_pipeline: 163 | dataset = Processor(dataset, func, mode=mode) 164 | return dataset 165 | -------------------------------------------------------------------------------- /cosyvoice/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Tuple 15 | import torch.nn as nn 16 | import torch 17 | from torch.nn import functional as F 18 | from cosyvoice.utils.mask import make_pad_mask 19 | 20 | 21 | class InterpolateRegulator(nn.Module): 22 | def __init__( 23 | self, 24 | channels: int, 25 | sampling_ratios: Tuple, 26 | out_channels: int = None, 27 | groups: int = 1, 28 | ): 29 | super().__init__() 30 | self.sampling_ratios = sampling_ratios 31 | out_channels = out_channels or channels 32 | model = nn.ModuleList([]) 33 | if len(sampling_ratios) > 0: 34 | for _ in sampling_ratios: 35 | module = nn.Conv1d(channels, channels, 3, 1, 1) 36 | norm = nn.GroupNorm(groups, channels) 37 | act = nn.Mish() 38 | model.extend([module, norm, act]) 39 | model.append( 40 | nn.Conv1d(channels, out_channels, 1, 1) 41 | ) 42 | self.model = nn.Sequential(*model) 43 | 44 | def forward(self, x, ylens=None): 45 | # x in (B, T, D) 46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') 48 | out = self.model(x).transpose(1, 2).contiguous() 49 | olens = ylens 50 | return out * mask, olens 51 | 52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): 53 | # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel 54 | # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py 55 | # x in (B, T, D) 56 | if x2.shape[1] > 40: 57 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 58 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, 59 | mode='linear') 60 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 61 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) 62 | else: 63 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') 64 | if x1.shape[1] != 0: 65 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') 66 | x = torch.concat([x1, x2], dim=2) 67 | else: 68 | x = x2 69 | out = self.model(x).transpose(1, 2).contiguous() 70 | return out, mel_len1 + mel_len2 71 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch 15 | import torch.nn as nn 16 | try: 17 | from torch.nn.utils.parametrizations import weight_norm 18 | except ImportError: 19 | from torch.nn.utils import weight_norm 20 | 21 | 22 | class ConvRNNF0Predictor(nn.Module): 23 | def __init__(self, 24 | num_class: int = 1, 25 | in_channels: int = 80, 26 | cond_channels: int = 512 27 | ): 28 | super().__init__() 29 | 30 | self.num_class = num_class 31 | self.condnet = nn.Sequential( 32 | weight_norm( 33 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 34 | ), 35 | nn.ELU(), 36 | weight_norm( 37 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 38 | ), 39 | nn.ELU(), 40 | weight_norm( 41 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 42 | ), 43 | nn.ELU(), 44 | weight_norm( 45 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 46 | ), 47 | nn.ELU(), 48 | weight_norm( 49 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 50 | ), 51 | nn.ELU(), 52 | ) 53 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 54 | 55 | def forward(self, x: torch.Tensor) -> torch.Tensor: 56 | x = self.condnet(x) 57 | x = x.transpose(1, 2) 58 | return torch.abs(self.classifier(x).squeeze(-1)) 59 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/hifigan.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss 6 | from cosyvoice.utils.losses import tpr_loss, mel_loss 7 | 8 | 9 | class HiFiGan(nn.Module): 10 | def __init__(self, generator, discriminator, mel_spec_transform, 11 | multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, 12 | tpr_loss_weight=1.0, tpr_loss_tau=0.04): 13 | super(HiFiGan, self).__init__() 14 | self.generator = generator 15 | self.discriminator = discriminator 16 | self.mel_spec_transform = mel_spec_transform 17 | self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight 18 | self.feat_match_loss_weight = feat_match_loss_weight 19 | self.tpr_loss_weight = tpr_loss_weight 20 | self.tpr_loss_tau = tpr_loss_tau 21 | 22 | def forward( 23 | self, 24 | batch: dict, 25 | device: torch.device, 26 | ) -> Dict[str, Optional[torch.Tensor]]: 27 | if batch['turn'] == 'generator': 28 | return self.forward_generator(batch, device) 29 | else: 30 | return self.forward_discriminator(batch, device) 31 | 32 | def forward_generator(self, batch, device): 33 | real_speech = batch['speech'].to(device) 34 | pitch_feat = batch['pitch_feat'].to(device) 35 | # 1. calculate generator outputs 36 | generated_speech, generated_f0 = self.generator(batch, device) 37 | # 2. calculate discriminator outputs 38 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) 39 | # 3. 
calculate generator losses, feature loss, mel loss, tpr losses [Optional] 40 | loss_gen, _ = generator_loss(y_d_gs) 41 | loss_fm = feature_loss(fmap_rs, fmap_gs) 42 | loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) 43 | if self.tpr_loss_weight != 0: 44 | loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau) 45 | else: 46 | loss_tpr = torch.zeros(1).to(device) 47 | loss_f0 = F.l1_loss(generated_f0, pitch_feat) 48 | loss = loss_gen + self.feat_match_loss_weight * loss_fm + \ 49 | self.multi_mel_spectral_recon_loss_weight * loss_mel + \ 50 | self.tpr_loss_weight * loss_tpr + loss_f0 51 | return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} 52 | 53 | def forward_discriminator(self, batch, device): 54 | real_speech = batch['speech'].to(device) 55 | # 1. calculate generator outputs 56 | with torch.no_grad(): 57 | generated_speech, generated_f0 = self.generator(batch, device) 58 | # 2. calculate discriminator outputs 59 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach()) 60 | # 3. calculate discriminator losses, tpr losses [Optional] 61 | loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) 62 | if self.tpr_loss_weight != 0: 63 | loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) 64 | else: 65 | loss_tpr = torch.zeros(1).to(device) 66 | loss = loss_disc + self.tpr_loss_weight * loss_tpr 67 | return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr} 68 | -------------------------------------------------------------------------------- /cosyvoice/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | from functools import lru_cache 4 | from typing import Optional 5 | import torch 6 | from transformers import AutoTokenizer 7 | from whisper.tokenizer import Tokenizer 8 | 9 | import tiktoken 10 | 11 | LANGUAGES = { 12 | "en": "english", 13 | "zh": "chinese", 14 | "de": "german", 15 | "es": "spanish", 16 | "ru": "russian", 17 | "ko": "korean", 18 | "fr": "french", 19 | "ja": "japanese", 20 | "pt": "portuguese", 21 | "tr": "turkish", 22 | "pl": "polish", 23 | "ca": "catalan", 24 | "nl": "dutch", 25 | "ar": "arabic", 26 | "sv": "swedish", 27 | "it": "italian", 28 | "id": "indonesian", 29 | "hi": "hindi", 30 | "fi": "finnish", 31 | "vi": "vietnamese", 32 | "he": "hebrew", 33 | "uk": "ukrainian", 34 | "el": "greek", 35 | "ms": "malay", 36 | "cs": "czech", 37 | "ro": "romanian", 38 | "da": "danish", 39 | "hu": "hungarian", 40 | "ta": "tamil", 41 | "no": "norwegian", 42 | "th": "thai", 43 | "ur": "urdu", 44 | "hr": "croatian", 45 | "bg": "bulgarian", 46 | "lt": "lithuanian", 47 | "la": "latin", 48 | "mi": "maori", 49 | "ml": "malayalam", 50 | "cy": "welsh", 51 | "sk": "slovak", 52 | "te": "telugu", 53 | "fa": "persian", 54 | "lv": "latvian", 55 | "bn": "bengali", 56 | "sr": "serbian", 57 | "az": "azerbaijani", 58 | "sl": "slovenian", 59 | "kn": "kannada", 60 | "et": "estonian", 61 | "mk": "macedonian", 62 | "br": "breton", 63 | "eu": "basque", 64 | "is": "icelandic", 65 | "hy": "armenian", 66 | "ne": "nepali", 67 | "mn": "mongolian", 68 | "bs": "bosnian", 69 | "kk": "kazakh", 70 | "sq": "albanian", 71 | "sw": "swahili", 72 | "gl": "galician", 73 | "mr": "marathi", 74 | "pa": "punjabi", 75 | "si": "sinhala", 76 | "km": "khmer", 77 | "sn": "shona", 78 | "yo": "yoruba", 79 | "so": "somali", 80 | "af": "afrikaans", 81 | "oc": "occitan", 82 | "ka": "georgian", 83 | "be": 
"belarusian", 84 | "tg": "tajik", 85 | "sd": "sindhi", 86 | "gu": "gujarati", 87 | "am": "amharic", 88 | "yi": "yiddish", 89 | "lo": "lao", 90 | "uz": "uzbek", 91 | "fo": "faroese", 92 | "ht": "haitian creole", 93 | "ps": "pashto", 94 | "tk": "turkmen", 95 | "nn": "nynorsk", 96 | "mt": "maltese", 97 | "sa": "sanskrit", 98 | "lb": "luxembourgish", 99 | "my": "myanmar", 100 | "bo": "tibetan", 101 | "tl": "tagalog", 102 | "mg": "malagasy", 103 | "as": "assamese", 104 | "tt": "tatar", 105 | "haw": "hawaiian", 106 | "ln": "lingala", 107 | "ha": "hausa", 108 | "ba": "bashkir", 109 | "jw": "javanese", 110 | "su": "sundanese", 111 | "yue": "cantonese", 112 | "minnan": "minnan", 113 | "wuyu": "wuyu", 114 | "dialect": "dialect", 115 | "zh/en": "zh/en", 116 | "en/zh": "en/zh", 117 | } 118 | 119 | # language code lookup by name, with a few language aliases 120 | TO_LANGUAGE_CODE = { 121 | **{language: code for code, language in LANGUAGES.items()}, 122 | "burmese": "my", 123 | "valencian": "ca", 124 | "flemish": "nl", 125 | "haitian": "ht", 126 | "letzeburgesch": "lb", 127 | "pushto": "ps", 128 | "panjabi": "pa", 129 | "moldavian": "ro", 130 | "moldovan": "ro", 131 | "sinhalese": "si", 132 | "castilian": "es", 133 | "mandarin": "zh", 134 | } 135 | 136 | AUDIO_EVENT = { 137 | "ASR": "ASR", 138 | "AED": "AED", 139 | "SER": "SER", 140 | "Speech": "Speech", 141 | "/Speech": "/Speech", 142 | "BGM": "BGM", 143 | "/BGM": "/BGM", 144 | "Laughter": "Laughter", 145 | "/Laughter": "/Laughter", 146 | "Applause": "Applause", 147 | "/Applause": "/Applause", 148 | } 149 | 150 | EMOTION = { 151 | "HAPPY": "HAPPY", 152 | "SAD": "SAD", 153 | "ANGRY": "ANGRY", 154 | "NEUTRAL": "NEUTRAL", 155 | } 156 | 157 | TTS_Vocal_Token = { 158 | "TTS/B": "TTS/B", 159 | "TTS/O": "TTS/O", 160 | "TTS/Q": "TTS/Q", 161 | "TTS/A": "TTS/A", 162 | "TTS/CO": "TTS/CO", 163 | "TTS/CL": "TTS/CL", 164 | "TTS/H": "TTS/H", 165 | **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)} 166 | } 167 | 168 | 169 | @lru_cache(maxsize=None) 170 | def get_encoding(name: str = "gpt2", num_languages: int = 99): 171 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") 172 | ranks = { 173 | base64.b64decode(token): int(rank) 174 | for token, rank in (line.split() for line in open(vocab_path) if line) 175 | } 176 | n_vocab = len(ranks) 177 | special_tokens = {} 178 | 179 | specials = [ 180 | "<|endoftext|>", 181 | "<|startoftranscript|>", 182 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], 183 | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())], 184 | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())], 185 | "<|translate|>", 186 | "<|transcribe|>", 187 | "<|startoflm|>", 188 | "<|startofprev|>", 189 | "<|nospeech|>", 190 | "<|notimestamps|>", 191 | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR 192 | *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS 193 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], 194 | ] 195 | 196 | for token in specials: 197 | special_tokens[token] = n_vocab 198 | n_vocab += 1 199 | 200 | return tiktoken.Encoding( 201 | name=os.path.basename(vocab_path), 202 | explicit_n_vocab=n_vocab, 203 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", 204 | mergeable_ranks=ranks, 205 | special_tokens=special_tokens, 206 | ) 207 | 208 | 209 | @lru_cache(maxsize=None) 210 | def get_tokenizer( 211 | multilingual: bool, 212 | *, 213 | num_languages: int 
= 99, 214 | language: Optional[str] = None, 215 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 216 | ) -> Tokenizer: 217 | if language is not None: 218 | language = language.lower() 219 | if language not in LANGUAGES: 220 | if language in TO_LANGUAGE_CODE: 221 | language = TO_LANGUAGE_CODE[language] 222 | else: 223 | raise ValueError(f"Unsupported language: {language}") 224 | 225 | if multilingual: 226 | encoding_name = "multilingual_zh_ja_yue_char_del" 227 | language = language or "en" 228 | task = task or "transcribe" 229 | else: 230 | encoding_name = "gpt2" 231 | language = None 232 | task = None 233 | 234 | encoding = get_encoding(name=encoding_name, num_languages=num_languages) 235 | 236 | return Tokenizer( 237 | encoding=encoding, num_languages=num_languages, language=language, task=task 238 | ) 239 | 240 | 241 | class QwenTokenizer(): 242 | def __init__(self, token_path, skip_special_tokens=True): 243 | super().__init__() 244 | # NOTE: non-chat model, all these special tokens keep randomly initialized. 245 | special_tokens = { 246 | 'eos_token': '<|endoftext|>', 247 | 'pad_token': '<|endoftext|>', 248 | 'additional_special_tokens': [ 249 | '<|im_start|>', '<|im_end|>', '<|endofprompt|>', 250 | '[breath]', '', '', '[noise]', 251 | '[laughter]', '[cough]', '[clucking]', '[accent]', 252 | '[quick_breath]', 253 | "", "", 254 | "[hissing]", "[sigh]", "[vocalized-noise]", 255 | "[lipsmack]", "[mn]" 256 | ] 257 | } 258 | self.special_tokens = special_tokens 259 | self.tokenizer = AutoTokenizer.from_pretrained(token_path) 260 | self.tokenizer.add_special_tokens(special_tokens) 261 | self.skip_special_tokens = skip_special_tokens 262 | 263 | def encode(self, text, **kwargs): 264 | tokens = self.tokenizer([text], return_tensors="pt") 265 | tokens = tokens["input_ids"][0].cpu().tolist() 266 | return tokens 267 | 268 | def decode(self, tokens): 269 | tokens = torch.tensor(tokens, dtype=torch.int64) 270 | text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0] 271 | return text 272 | 273 | 274 | @lru_cache(maxsize=None) 275 | def get_qwen_tokenizer( 276 | token_path: str, 277 | skip_special_tokens: bool 278 | ) -> QwenTokenizer: 279 | return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens) 280 | -------------------------------------------------------------------------------- /cosyvoice/transformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/transformer/__init__.py -------------------------------------------------------------------------------- /cosyvoice/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
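# Usage sketch for the tiktoken helpers in tokenizer.py above (assumes the
# repo root is on PYTHONPATH and the tiktoken dependency from requirements.txt
# is installed). get_encoding loads the bundled
# multilingual_zh_ja_yue_char_del vocabulary and appends the language,
# audio-event, emotion, TTS and timestamp special tokens on top of it.
from cosyvoice.tokenizer.tokenizer import get_encoding

enc = get_encoding(name="multilingual_zh_ja_yue_char_del")
ids = enc.encode("hello world")
print(ids, enc.decode(ids))
print(enc.encode("<|zh|>", allowed_special="all"))   # id of the zh language token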
15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct an Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | ''' 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | ''' 50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 51 | ''' 52 | Initialization. 53 | INPUT: 54 | - in_features: shape of the input 55 | - alpha: trainable parameter 56 | alpha is initialized to 1 by default, higher values = higher-frequency. 57 | alpha will be trained along with the rest of your model. 58 | ''' 59 | super(Snake, self).__init__() 60 | self.in_features = in_features 61 | 62 | # initialize alpha 63 | self.alpha_logscale = alpha_logscale 64 | if self.alpha_logscale: # log scale alphas initialized to zeros 65 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 66 | else: # linear scale alphas initialized to ones 67 | self.alpha = Parameter(torch.ones(in_features) * alpha) 68 | 69 | self.alpha.requires_grad = alpha_trainable 70 | 71 | self.no_div_by_zero = 0.000000001 72 | 73 | def forward(self, x): 74 | ''' 75 | Forward pass of the function. 76 | Applies the function to the input elementwise. 77 | Snake ∶= x + 1/a * sin^2 (xa) 78 | ''' 79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 80 | if self.alpha_logscale: 81 | alpha = torch.exp(alpha) 82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /cosyvoice/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
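# Usage sketch for the activations defined in activation.py above (assumes the
# repo root is on PYTHONPATH). Snake operates channel-wise on (B, C, T)
# tensors; alpha controls the frequency of its periodic term.
import torch
from cosyvoice.transformer.activation import Swish, Snake

x = torch.randn(4, 16, 100)                        # (B, C, T)
print(Swish()(x).shape)                            # torch.Size([4, 16, 100])
snake = Snake(in_features=16, alpha=1.0, alpha_logscale=False)
print(snake(x).shape)                              # torch.Size([4, 16, 100])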
15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ConvolutionModule(nn.Module): 25 | """ConvolutionModule in Conformer model.""" 26 | 27 | def __init__(self, 28 | channels: int, 29 | kernel_size: int = 15, 30 | activation: nn.Module = nn.ReLU(), 31 | norm: str = "batch_norm", 32 | causal: bool = False, 33 | bias: bool = True): 34 | """Construct an ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | """ 40 | super().__init__() 41 | 42 | self.pointwise_conv1 = nn.Conv1d( 43 | channels, 44 | 2 * channels, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=bias, 49 | ) 50 | # self.lorder is used to distinguish if it's a causal convolution, 51 | # if self.lorder > 0: it's a causal convolution, the input will be 52 | # padded with self.lorder frames on the left in forward. 53 | # else: it's a symmetrical convolution 54 | if causal: 55 | padding = 0 56 | self.lorder = kernel_size - 1 57 | else: 58 | # kernel_size should be an odd number for none causal convolution 59 | assert (kernel_size - 1) % 2 == 0 60 | padding = (kernel_size - 1) // 2 61 | self.lorder = 0 62 | self.depthwise_conv = nn.Conv1d( 63 | channels, 64 | channels, 65 | kernel_size, 66 | stride=1, 67 | padding=padding, 68 | groups=channels, 69 | bias=bias, 70 | ) 71 | 72 | assert norm in ['batch_norm', 'layer_norm'] 73 | if norm == "batch_norm": 74 | self.use_layer_norm = False 75 | self.norm = nn.BatchNorm1d(channels) 76 | else: 77 | self.use_layer_norm = True 78 | self.norm = nn.LayerNorm(channels) 79 | 80 | self.pointwise_conv2 = nn.Conv1d( 81 | channels, 82 | channels, 83 | kernel_size=1, 84 | stride=1, 85 | padding=0, 86 | bias=bias, 87 | ) 88 | self.activation = activation 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 94 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | """Compute convolution module. 97 | Args: 98 | x (torch.Tensor): Input tensor (#batch, time, channels). 99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 100 | (0, 0, 0) means fake mask. 101 | cache (torch.Tensor): left context cache, it is only 102 | used in causal convolution (#batch, channels, cache_t), 103 | (0, 0, 0) meas fake cache. 104 | Returns: 105 | torch.Tensor: Output tensor (#batch, time, channels). 106 | """ 107 | # exchange the temporal dimension and the feature dimension 108 | x = x.transpose(1, 2) # (#batch, channels, time) 109 | 110 | # mask batch padding 111 | if mask_pad.size(2) > 0: # time > 0 112 | x.masked_fill_(~mask_pad, 0.0) 113 | 114 | if self.lorder > 0: 115 | if cache.size(2) == 0: # cache_t == 0 116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 117 | else: 118 | assert cache.size(0) == x.size(0) # equal batch 119 | assert cache.size(1) == x.size(1) # equal channel 120 | x = torch.cat((cache, x), dim=2) 121 | assert (x.size(2) > self.lorder) 122 | new_cache = x[:, :, -self.lorder:] 123 | else: 124 | # It's better we just return None if no cache is required, 125 | # However, for JIT export, here we just fake one tensor instead of 126 | # None. 
127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 128 | 129 | # GLU mechanism 130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 132 | 133 | # 1D Depthwise Conv 134 | x = self.depthwise_conv(x) 135 | if self.use_layer_norm: 136 | x = x.transpose(1, 2) 137 | x = self.activation(self.norm(x)) 138 | if self.use_layer_norm: 139 | x = x.transpose(1, 2) 140 | x = self.pointwise_conv2(x) 141 | # mask batch padding 142 | if mask_pad.size(2) > 0: # time > 0 143 | x.masked_fill_(~mask_pad, 0.0) 144 | 145 | return x.transpose(1, 2), new_cache 146 | -------------------------------------------------------------------------------- /cosyvoice/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Decoder self-attention layer definition.""" 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | from torch import nn 20 | 21 | 22 | class DecoderLayer(nn.Module): 23 | """Single decoder layer module. 24 | 25 | Args: 26 | size (int): Input dimension. 27 | self_attn (torch.nn.Module): Self-attention module instance. 28 | `MultiHeadedAttention` instance can be used as the argument. 29 | src_attn (torch.nn.Module): Inter-attention module instance. 30 | `MultiHeadedAttention` instance can be used as the argument. 31 | If `None` is passed, Inter-attention is not used, such as 32 | CIF, GPT, and other decoder only model. 33 | feed_forward (torch.nn.Module): Feed-forward module instance. 34 | `PositionwiseFeedForward` instance can be used as the argument. 35 | dropout_rate (float): Dropout rate. 36 | normalize_before (bool): 37 | True: use layer_norm before each sub-block. 38 | False: to use layer_norm after each sub-block. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | size: int, 44 | self_attn: nn.Module, 45 | src_attn: Optional[nn.Module], 46 | feed_forward: nn.Module, 47 | dropout_rate: float, 48 | normalize_before: bool = True, 49 | ): 50 | """Construct an DecoderLayer object.""" 51 | super().__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.src_attn = src_attn 55 | self.feed_forward = feed_forward 56 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 57 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 58 | self.norm3 = nn.LayerNorm(size, eps=1e-5) 59 | self.dropout = nn.Dropout(dropout_rate) 60 | self.normalize_before = normalize_before 61 | 62 | def forward( 63 | self, 64 | tgt: torch.Tensor, 65 | tgt_mask: torch.Tensor, 66 | memory: torch.Tensor, 67 | memory_mask: torch.Tensor, 68 | cache: Optional[torch.Tensor] = None 69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 70 | """Compute decoded features. 71 | 72 | Args: 73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 
74 | tgt_mask (torch.Tensor): Mask for input tensor 75 | (#batch, maxlen_out). 76 | memory (torch.Tensor): Encoded memory 77 | (#batch, maxlen_in, size). 78 | memory_mask (torch.Tensor): Encoded memory mask 79 | (#batch, maxlen_in). 80 | cache (torch.Tensor): cached tensors. 81 | (#batch, maxlen_out - 1, size). 82 | 83 | Returns: 84 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 88 | 89 | """ 90 | residual = tgt 91 | if self.normalize_before: 92 | tgt = self.norm1(tgt) 93 | 94 | if cache is None: 95 | tgt_q = tgt 96 | tgt_q_mask = tgt_mask 97 | else: 98 | # compute only the last frame query keeping dim: max_time_out -> 1 99 | assert cache.shape == ( 100 | tgt.shape[0], 101 | tgt.shape[1] - 1, 102 | self.size, 103 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 104 | tgt_q = tgt[:, -1:, :] 105 | residual = residual[:, -1:, :] 106 | tgt_q_mask = tgt_mask[:, -1:, :] 107 | 108 | x = residual + self.dropout( 109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) 110 | if not self.normalize_before: 111 | x = self.norm1(x) 112 | 113 | if self.src_attn is not None: 114 | residual = x 115 | if self.normalize_before: 116 | x = self.norm2(x) 117 | x = residual + self.dropout( 118 | self.src_attn(x, memory, memory, memory_mask)[0]) 119 | if not self.normalize_before: 120 | x = self.norm2(x) 121 | 122 | residual = x 123 | if self.normalize_before: 124 | x = self.norm3(x) 125 | x = residual + self.dropout(self.feed_forward(x)) 126 | if not self.normalize_before: 127 | x = self.norm3(x) 128 | 129 | if cache is not None: 130 | x = torch.cat([cache, x], dim=1) 131 | 132 | return x, tgt_mask, memory, memory_mask 133 | -------------------------------------------------------------------------------- /cosyvoice/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothing version CE Loss,some probabilities 33 | are taken from the true label prob (1.0) and are divided 34 | among other labels. 35 | 36 | e.g. 
37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of class 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /cosyvoice/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward are appied on each position of the sequence. 24 | The output dim is same with the input dim. 25 | 26 | Args: 27 | idim (int): Input dimenstion. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 
30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of expert with Positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is same with the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of expert. 68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimenstion. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate. 72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Foward function. 
93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /cosyvoice/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/utils/__init__.py -------------------------------------------------------------------------------- /cosyvoice/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
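# Usage sketch for the feed-forward layers above (assumes the repo root is on
# PYTHONPATH). MoEFFNLayer routes each position to its top n_expert_per_token
# experts and mixes their outputs with softmax weights over the router logits;
# the layer sizes below are illustrative.
import torch
from cosyvoice.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,
    MoEFFNLayer,
)

xs = torch.randn(2, 10, 64)                                    # (B, L, D)
ffn = PositionwiseFeedForward(idim=64, hidden_units=256, dropout_rate=0.1)
print(ffn(xs).shape)                                           # torch.Size([2, 10, 64])
moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=64,
                  hidden_units=256, dropout_rate=0.1)
print(moe(xs).shape)                                           # torch.Size([2, 10, 64])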
15 | import torch 16 | 17 | from cosyvoice.transformer.activation import Swish 18 | from cosyvoice.transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from cosyvoice.transformer.embedding import (PositionalEncoding, 27 | RelPositionalEncoding, 28 | WhisperPositionalEncoding, 29 | LearnablePositionalEncoding, 30 | NoPositionalEncoding) 31 | from cosyvoice.transformer.attention import (MultiHeadedAttention, 32 | RelPositionMultiHeadedAttention) 33 | from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding 34 | from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling 35 | from cosyvoice.llm.llm import TransformerLM, Qwen2LM 36 | from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec 37 | from cosyvoice.hifigan.generator import HiFTGenerator 38 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model 39 | 40 | 41 | COSYVOICE_ACTIVATION_CLASSES = { 42 | "hardtanh": torch.nn.Hardtanh, 43 | "tanh": torch.nn.Tanh, 44 | "relu": torch.nn.ReLU, 45 | "selu": torch.nn.SELU, 46 | "swish": getattr(torch.nn, "SiLU", Swish), 47 | "gelu": torch.nn.GELU, 48 | } 49 | 50 | COSYVOICE_SUBSAMPLE_CLASSES = { 51 | "linear": LinearNoSubsampling, 52 | "linear_legacy": LegacyLinearNoSubsampling, 53 | "embed": EmbedinigNoSubsampling, 54 | "conv1d2": Conv1dSubsampling2, 55 | "conv2d": Conv2dSubsampling4, 56 | "conv2d6": Conv2dSubsampling6, 57 | "conv2d8": Conv2dSubsampling8, 58 | 'paraformer_dummy': torch.nn.Identity 59 | } 60 | 61 | COSYVOICE_EMB_CLASSES = { 62 | "embed": PositionalEncoding, 63 | "abs_pos": PositionalEncoding, 64 | "rel_pos": RelPositionalEncoding, 65 | "rel_pos_espnet": EspnetRelPositionalEncoding, 66 | "no_pos": NoPositionalEncoding, 67 | "abs_pos_whisper": WhisperPositionalEncoding, 68 | "embed_learnable_pe": LearnablePositionalEncoding, 69 | } 70 | 71 | COSYVOICE_ATTENTION_CLASSES = { 72 | "selfattn": MultiHeadedAttention, 73 | "rel_selfattn": RelPositionMultiHeadedAttention, 74 | } 75 | 76 | 77 | def get_model_type(configs): 78 | # NOTE CosyVoice2Model inherits CosyVoiceModel 79 | if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 80 | return CosyVoiceModel 81 | if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 82 | return CosyVoice2Model 83 | raise TypeError('No valid model type found!') 84 | -------------------------------------------------------------------------------- /cosyvoice/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
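# Usage sketch for the class registries above (assumes the repo root is on
# PYTHONPATH and the project requirements are installed, since class_utils
# imports the llm/flow/hifigan modules). The yaml configs refer to these
# string keys, which are resolved to concrete classes when building the model.
from cosyvoice.utils.class_utils import (
    COSYVOICE_ACTIVATION_CLASSES,
    COSYVOICE_ATTENTION_CLASSES,
)

act = COSYVOICE_ACTIVATION_CLASSES["swish"]()        # torch.nn.SiLU on recent torch
attn_cls = COSYVOICE_ATTENTION_CLASSES["rel_selfattn"]
print(type(act).__name__, attn_cls.__name__)         # SiLU RelPositionMultiHeadedAttention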
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # Modified from ESPnet(https://github.com/espnet/espnet) 17 | """Unility functions for Transformer.""" 18 | 19 | import queue 20 | import random 21 | from typing import List 22 | 23 | import numpy as np 24 | import torch 25 | 26 | IGNORE_ID = -1 27 | 28 | 29 | def pad_list(xs: List[torch.Tensor], pad_value: int): 30 | """Perform padding for the list of tensors. 31 | 32 | Args: 33 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 34 | pad_value (float): Value for padding. 35 | 36 | Returns: 37 | Tensor: Padded tensor (B, Tmax, `*`). 38 | 39 | Examples: 40 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 41 | >>> x 42 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 43 | >>> pad_list(x, 0) 44 | tensor([[1., 1., 1., 1.], 45 | [1., 1., 0., 0.], 46 | [1., 0., 0., 0.]]) 47 | 48 | """ 49 | max_len = max([len(item) for item in xs]) 50 | batchs = len(xs) 51 | ndim = xs[0].ndim 52 | if ndim == 1: 53 | pad_res = torch.zeros(batchs, 54 | max_len, 55 | dtype=xs[0].dtype, 56 | device=xs[0].device) 57 | elif ndim == 2: 58 | pad_res = torch.zeros(batchs, 59 | max_len, 60 | xs[0].shape[1], 61 | dtype=xs[0].dtype, 62 | device=xs[0].device) 63 | elif ndim == 3: 64 | pad_res = torch.zeros(batchs, 65 | max_len, 66 | xs[0].shape[1], 67 | xs[0].shape[2], 68 | dtype=xs[0].dtype, 69 | device=xs[0].device) 70 | else: 71 | raise ValueError(f"Unsupported ndim: {ndim}") 72 | pad_res.fill_(pad_value) 73 | for i in range(batchs): 74 | pad_res[i, :len(xs[i])] = xs[i] 75 | return pad_res 76 | 77 | 78 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, 79 | ignore_label: int) -> torch.Tensor: 80 | """Calculate accuracy. 81 | 82 | Args: 83 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 84 | pad_targets (LongTensor): Target label tensors (B, Lmax). 85 | ignore_label (int): Ignore label id. 86 | 87 | Returns: 88 | torch.Tensor: Accuracy value (0.0 - 1.0). 89 | 90 | """ 91 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), 92 | pad_outputs.size(1)).argmax(2) 93 | mask = pad_targets != ignore_label 94 | numerator = torch.sum( 95 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) 96 | denominator = torch.sum(mask) 97 | return (numerator / denominator).detach() 98 | 99 | 100 | def get_padding(kernel_size, dilation=1): 101 | return int((kernel_size * dilation - dilation) / 2) 102 | 103 | 104 | def init_weights(m, mean=0.0, std=0.01): 105 | classname = m.__class__.__name__ 106 | if classname.find("Conv") != -1: 107 | m.weight.data.normal_(mean, std) 108 | 109 | 110 | # Repetition Aware Sampling in VALL-E 2 111 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): 112 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) 113 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() 114 | if rep_num >= win_size * tau_r: 115 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) 116 | return top_ids 117 | 118 | 119 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): 120 | prob, indices = [], [] 121 | cum_prob = 0.0 122 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) 123 | for i in range(len(sorted_idx)): 124 | # sampling both top-p and numbers. 
125 | if cum_prob < top_p and len(prob) < top_k: 126 | cum_prob += sorted_value[i] 127 | prob.append(sorted_value[i]) 128 | indices.append(sorted_idx[i]) 129 | else: 130 | break 131 | prob = torch.tensor(prob).to(weighted_scores) 132 | indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) 133 | top_ids = indices[prob.multinomial(1, replacement=True)] 134 | return top_ids 135 | 136 | 137 | def random_sampling(weighted_scores, decoded_tokens, sampling): 138 | top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) 139 | return top_ids 140 | 141 | 142 | def fade_in_out(fade_in_mel, fade_out_mel, window): 143 | device = fade_in_mel.device 144 | fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() 145 | mel_overlap_len = int(window.shape[0] / 2) 146 | if fade_in_mel.device == torch.device('cpu'): 147 | fade_in_mel = fade_in_mel.clone() 148 | fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ 149 | fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] 150 | return fade_in_mel.to(device) 151 | 152 | 153 | def set_all_random_seed(seed): 154 | random.seed(seed) 155 | np.random.seed(seed) 156 | torch.manual_seed(seed) 157 | torch.cuda.manual_seed_all(seed) 158 | 159 | 160 | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: 161 | assert mask.dtype == torch.bool 162 | assert dtype in [torch.float32, torch.bfloat16, torch.float16] 163 | mask = mask.to(dtype) 164 | # attention mask bias 165 | # NOTE(Mddct): torch.finfo jit issues 166 | # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min 167 | mask = (1.0 - mask) * -1.0e+10 168 | return mask 169 | 170 | 171 | class TrtContextWrapper: 172 | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): 173 | self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) 174 | self.trt_engine = trt_engine 175 | for _ in range(trt_concurrent): 176 | trt_context = trt_engine.create_execution_context() 177 | trt_stream = torch.cuda.stream(torch.cuda.Stream(device)) 178 | assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent) 179 | self.trt_context_pool.put([trt_context, trt_stream]) 180 | assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context' 181 | 182 | def acquire_estimator(self): 183 | return self.trt_context_pool.get(), self.trt_engine 184 | 185 | def release_estimator(self, context, stream): 186 | self.trt_context_pool.put([context, stream]) 187 | -------------------------------------------------------------------------------- /cosyvoice/utils/executor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
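# Usage sketch for a few helpers from common.py above (assumes the repo root
# is on PYTHONPATH; the logits size is illustrative). ras_sampling falls back
# to plain random sampling when the nucleus-sampled token already repeats too
# often inside the recent decoding window.
import torch
from cosyvoice.utils.common import pad_list, ras_sampling

xs = [torch.ones(4), torch.ones(2), torch.ones(1)]
print(pad_list(xs, 0))                   # (3, 4) tensor padded with zeros

weighted_scores = torch.randn(6561)      # unnormalised scores over speech tokens
decoded_tokens = [12, 12, 12, 12]        # recent history used for the repetition check
top_id = ras_sampling(weighted_scores, decoded_tokens, sampling=25)
print(int(top_id))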
15 | 16 | import logging 17 | from contextlib import nullcontext 18 | import os 19 | 20 | import torch 21 | import torch.distributed as dist 22 | 23 | from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join 24 | 25 | 26 | class Executor: 27 | 28 | def __init__(self, gan: bool = False): 29 | self.gan = gan 30 | self.step = 0 31 | self.epoch = 0 32 | self.rank = int(os.environ.get('RANK', 0)) 33 | self.device = torch.device('cuda:{}'.format(self.rank)) 34 | 35 | def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join): 36 | ''' Train one epoch 37 | ''' 38 | 39 | lr = optimizer.param_groups[0]['lr'] 40 | logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) 41 | logging.info('using accumulate grad, new batch size is {} times' 42 | ' larger than before'.format(info_dict['accum_grad'])) 43 | # A context manager to be used in conjunction with an instance of 44 | # torch.nn.parallel.DistributedDataParallel to be able to train 45 | # with uneven inputs across participating processes. 46 | model.train() 47 | model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext 48 | with model_context(): 49 | for batch_idx, batch_dict in enumerate(train_data_loader): 50 | info_dict["tag"] = "TRAIN" 51 | info_dict["step"] = self.step 52 | info_dict["epoch"] = self.epoch 53 | info_dict["batch_idx"] = batch_idx 54 | if cosyvoice_join(group_join, info_dict): 55 | break 56 | 57 | # Disable gradient synchronizations across DDP processes. 58 | # Within this context, gradients will be accumulated on module 59 | # variables, which will later be synchronized. 60 | if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: 61 | context = model.no_sync 62 | # Used for single gpu training and DDP gradient synchronization 63 | # processes. 64 | else: 65 | context = nullcontext 66 | 67 | with context(): 68 | info_dict = batch_forward(model, batch_dict, scaler, info_dict) 69 | info_dict = batch_backward(model, scaler, info_dict) 70 | 71 | info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) 72 | log_per_step(writer, info_dict) 73 | # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save 74 | if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ 75 | (batch_idx + 1) % info_dict["accum_grad"] == 0: 76 | dist.barrier() 77 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) 78 | model.train() 79 | if (batch_idx + 1) % info_dict["accum_grad"] == 0: 80 | self.step += 1 81 | dist.barrier() 82 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) 83 | 84 | def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader, 85 | writer, info_dict, scaler, group_join): 86 | ''' Train one epoch 87 | ''' 88 | 89 | lr = optimizer.param_groups[0]['lr'] 90 | logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank)) 91 | logging.info('using accumulate grad, new batch size is {} times' 92 | ' larger than before'.format(info_dict['accum_grad'])) 93 | # A context manager to be used in conjunction with an instance of 94 | # torch.nn.parallel.DistributedDataParallel to be able to train 95 | # with uneven inputs across participating processes. 
96 | model.train() 97 | model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext 98 | with model_context(): 99 | for batch_idx, batch_dict in enumerate(train_data_loader): 100 | info_dict["tag"] = "TRAIN" 101 | info_dict["step"] = self.step 102 | info_dict["epoch"] = self.epoch 103 | info_dict["batch_idx"] = batch_idx 104 | if cosyvoice_join(group_join, info_dict): 105 | break 106 | 107 | # Disable gradient synchronizations across DDP processes. 108 | # Within this context, gradients will be accumulated on module 109 | # variables, which will later be synchronized. 110 | if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0: 111 | context = model.no_sync 112 | # Used for single gpu training and DDP gradient synchronization 113 | # processes. 114 | else: 115 | context = nullcontext 116 | 117 | with context(): 118 | batch_dict['turn'] = 'discriminator' 119 | info_dict = batch_forward(model, batch_dict, scaler, info_dict) 120 | info_dict = batch_backward(model, scaler, info_dict) 121 | info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict) 122 | optimizer.zero_grad() 123 | log_per_step(writer, info_dict) 124 | with context(): 125 | batch_dict['turn'] = 'generator' 126 | info_dict = batch_forward(model, batch_dict, scaler, info_dict) 127 | info_dict = batch_backward(model, scaler, info_dict) 128 | info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict) 129 | optimizer_d.zero_grad() 130 | log_per_step(writer, info_dict) 131 | # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save 132 | if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \ 133 | (batch_idx + 1) % info_dict["accum_grad"] == 0: 134 | dist.barrier() 135 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False) 136 | model.train() 137 | if (batch_idx + 1) % info_dict["accum_grad"] == 0: 138 | self.step += 1 139 | dist.barrier() 140 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True) 141 | 142 | @torch.inference_mode() 143 | def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True): 144 | ''' Cross validation on 145 | ''' 146 | logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank)) 147 | model.eval() 148 | total_num_utts, total_loss_dict = 0, {} # avoid division by 0 149 | for batch_idx, batch_dict in enumerate(cv_data_loader): 150 | info_dict["tag"] = "CV" 151 | info_dict["step"] = self.step 152 | info_dict["epoch"] = self.epoch 153 | info_dict["batch_idx"] = batch_idx 154 | 155 | num_utts = len(batch_dict["utts"]) 156 | total_num_utts += num_utts 157 | 158 | if self.gan is True: 159 | batch_dict['turn'] = 'generator' 160 | info_dict = batch_forward(model, batch_dict, None, info_dict) 161 | 162 | for k, v in info_dict['loss_dict'].items(): 163 | if k not in total_loss_dict: 164 | total_loss_dict[k] = [] 165 | total_loss_dict[k].append(v.item() * num_utts) 166 | log_per_step(None, info_dict) 167 | for k, v in total_loss_dict.items(): 168 | total_loss_dict[k] = sum(v) / total_num_utts 169 | info_dict['loss_dict'] = total_loss_dict 170 | log_per_save(writer, info_dict) 171 | model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1) 172 | save_model(model, model_name, info_dict) 173 | -------------------------------------------------------------------------------- 
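The gradient-accumulation logic in Executor.train_one_epoc above boils down to the pattern sketched below: gradients are accumulated over accum_grad consecutive batches (wrapped in model.no_sync() under torch_ddp so DDP skips the all-reduce), and the optimizer only steps on the final batch of each group. The toy model, data and loss are illustrative stand-ins; the real loop additionally handles schedulers, scalers, logging and checkpointing via the train_utils helpers imported in executor.py above.

import contextlib
import torch

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(6)]
accum_grad, step = 2, 0
for batch_idx, (x, y) in enumerate(loader):
    sync = (batch_idx + 1) % accum_grad == 0
    # single-process stand-in; under torch_ddp the executor picks model.no_sync here
    context = contextlib.nullcontext
    with context():
        loss = torch.nn.functional.mse_loss(model(x), y) / accum_grad
        loss.backward()
    if sync:
        optimizer.step()
        optimizer.zero_grad()
        step += 1
print('optimizer steps:', step)   # 3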
/cosyvoice/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu) 3 | # 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import json 19 | import torch 20 | import torchaudio 21 | import logging 22 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 23 | logging.basicConfig(level=logging.DEBUG, 24 | format='%(asctime)s %(levelname)s %(message)s') 25 | 26 | 27 | def read_lists(list_file): 28 | lists = [] 29 | with open(list_file, 'r', encoding='utf8') as fin: 30 | for line in fin: 31 | lists.append(line.strip()) 32 | return lists 33 | 34 | 35 | def read_json_lists(list_file): 36 | lists = read_lists(list_file) 37 | results = {} 38 | for fn in lists: 39 | with open(fn, 'r', encoding='utf8') as fin: 40 | results.update(json.load(fin)) 41 | return results 42 | 43 | 44 | def load_wav(wav, target_sr): 45 | speech, sample_rate = torchaudio.load(wav, backend='soundfile') 46 | speech = speech.mean(dim=0, keepdim=True) 47 | if sample_rate != target_sr: 48 | assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) 49 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) 50 | return speech 51 | 52 | 53 | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): 54 | import tensorrt as trt 55 | logging.info("Converting onnx to trt...") 56 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 57 | logger = trt.Logger(trt.Logger.INFO) 58 | builder = trt.Builder(logger) 59 | network = builder.create_network(network_flags) 60 | parser = trt.OnnxParser(network, logger) 61 | config = builder.create_builder_config() 62 | config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB 63 | if fp16: 64 | config.set_flag(trt.BuilderFlag.FP16) 65 | profile = builder.create_optimization_profile() 66 | # load onnx model 67 | with open(onnx_model, "rb") as f: 68 | if not parser.parse(f.read()): 69 | for error in range(parser.num_errors): 70 | print(parser.get_error(error)) 71 | raise ValueError('failed to parse {}'.format(onnx_model)) 72 | # set input shapes 73 | for i in range(len(trt_kwargs['input_names'])): 74 | profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) 75 | tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT 76 | # set input and output data type 77 | for i in range(network.num_inputs): 78 | input_tensor = network.get_input(i) 79 | input_tensor.dtype = tensor_dtype 80 | for i in range(network.num_outputs): 81 | output_tensor = network.get_output(i) 82 | output_tensor.dtype = tensor_dtype 83 | config.add_optimization_profile(profile) 84 | engine_bytes = builder.build_serialized_network(network, config) 85 | # 
save trt engine 86 | with open(trt_model, "wb") as f: 87 | f.write(engine_bytes) 88 | logging.info("Succesfully convert onnx to trt...") 89 | 90 | 91 | def export_cosyvoice2_vllm(model, model_path, device): 92 | if os.path.exists(model_path): 93 | return 94 | pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64 95 | vocab_size = model.speech_embedding.num_embeddings 96 | feature_size = model.speech_embedding.embedding_dim 97 | pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to 98 | 99 | dtype = torch.bfloat16 100 | # lm_head 101 | new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True) 102 | with torch.no_grad(): 103 | new_lm_head.weight[:vocab_size] = model.llm_decoder.weight 104 | new_lm_head.bias[:vocab_size] = model.llm_decoder.bias 105 | new_lm_head.weight[vocab_size:] = 0 106 | new_lm_head.bias[vocab_size:] = 0 107 | model.llm.model.lm_head = new_lm_head 108 | new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size) 109 | # embed_tokens 110 | embed_tokens = model.llm.model.model.embed_tokens 111 | with torch.no_grad(): 112 | new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight 113 | new_codec_embed.weight[vocab_size:] = 0 114 | model.llm.model.set_input_embeddings(new_codec_embed) 115 | model.llm.model.to(device) 116 | model.llm.model.to(dtype) 117 | tmp_vocab_size = model.llm.model.config.vocab_size 118 | tmp_tie_embedding = model.llm.model.config.tie_word_embeddings 119 | del model.llm.model.generation_config.eos_token_id 120 | del model.llm.model.config.bos_token_id 121 | del model.llm.model.config.eos_token_id 122 | model.llm.model.config.vocab_size = pad_vocab_size 123 | model.llm.model.config.tie_word_embeddings = False 124 | model.llm.model.config.use_bias = True 125 | model.llm.model.save_pretrained(model_path) 126 | os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path))) 127 | model.llm.model.config.vocab_size = tmp_vocab_size 128 | model.llm.model.config.tie_word_embeddings = tmp_tie_embedding 129 | model.llm.model.set_input_embeddings(embed_tokens) 130 | -------------------------------------------------------------------------------- /cosyvoice/utils/frontend_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
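# Usage sketch for the I/O helpers in file_utils.py above (assumes the repo
# root is on PYTHONPATH and the soundfile backend is available; both paths
# below are illustrative). load_wav downmixes to mono and resamples to
# target_sr when the source rate is higher.
from cosyvoice.utils.file_utils import load_wav, read_lists

speech = load_wav('prompt.wav', target_sr=16000)   # (1, num_samples) at 16 kHz
shards = read_lists('data/train/data.list')        # one shard path per line
print(speech.shape, len(shards))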
14 | 15 | import re 16 | import regex 17 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') 18 | 19 | 20 | # whether contain chinese character 21 | def contains_chinese(text): 22 | return bool(chinese_char_pattern.search(text)) 23 | 24 | 25 | # replace special symbol 26 | def replace_corner_mark(text): 27 | text = text.replace('²', '平方') 28 | text = text.replace('³', '立方') 29 | return text 30 | 31 | 32 | # remove meaningless symbol 33 | def remove_bracket(text): 34 | text = text.replace('(', '').replace(')', '') 35 | text = text.replace('【', '').replace('】', '') 36 | text = text.replace('`', '').replace('`', '') 37 | text = text.replace("——", " ") 38 | return text 39 | 40 | 41 | # spell Arabic numerals 42 | def spell_out_number(text: str, inflect_parser): 43 | new_text = [] 44 | st = None 45 | for i, c in enumerate(text): 46 | if not c.isdigit(): 47 | if st is not None: 48 | num_str = inflect_parser.number_to_words(text[st: i]) 49 | new_text.append(num_str) 50 | st = None 51 | new_text.append(c) 52 | else: 53 | if st is None: 54 | st = i 55 | if st is not None and st < len(text): 56 | num_str = inflect_parser.number_to_words(text[st:]) 57 | new_text.append(num_str) 58 | return ''.join(new_text) 59 | 60 | 61 | # split paragrah logic: 62 | # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len 63 | # 2. cal sentence len according to lang 64 | # 3. split sentence according to puncatation 65 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): 66 | def calc_utt_length(_text: str): 67 | if lang == "zh": 68 | return len(_text) 69 | else: 70 | return len(tokenize(_text)) 71 | 72 | def should_merge(_text: str): 73 | if lang == "zh": 74 | return len(_text) < merge_len 75 | else: 76 | return len(tokenize(_text)) < merge_len 77 | 78 | if lang == "zh": 79 | pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] 80 | else: 81 | pounc = ['.', '?', '!', ';', ':'] 82 | if comma_split: 83 | pounc.extend([',', ',']) 84 | 85 | if text[-1] not in pounc: 86 | if lang == "zh": 87 | text += "。" 88 | else: 89 | text += "." 90 | 91 | st = 0 92 | utts = [] 93 | for i, c in enumerate(text): 94 | if c in pounc: 95 | if len(text[st: i]) > 0: 96 | utts.append(text[st: i] + c) 97 | if i + 1 < len(text) and text[i + 1] in ['"', '”']: 98 | tmp = utts.pop(-1) 99 | utts.append(tmp + text[i + 1]) 100 | st = i + 2 101 | else: 102 | st = i + 1 103 | 104 | final_utts = [] 105 | cur_utt = "" 106 | for utt in utts: 107 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: 108 | final_utts.append(cur_utt) 109 | cur_utt = "" 110 | cur_utt = cur_utt + utt 111 | if len(cur_utt) > 0: 112 | if should_merge(cur_utt) and len(final_utts) != 0: 113 | final_utts[-1] = final_utts[-1] + cur_utt 114 | else: 115 | final_utts.append(cur_utt) 116 | 117 | return final_utts 118 | 119 | 120 | # remove blank between chinese character 121 | def replace_blank(text: str): 122 | out_str = [] 123 | for i, c in enumerate(text): 124 | if c == " ": 125 | if ((text[i + 1].isascii() and text[i + 1] != " ") and 126 | (text[i - 1].isascii() and text[i - 1] != " ")): 127 | out_str.append(c) 128 | else: 129 | out_str.append(c) 130 | return "".join(out_str) 131 | 132 | 133 | def is_only_punctuation(text): 134 | # Regular expression: Match strings that consist only of punctuation marks or are empty. 
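# (Note: the third-party `regex` module is used here because the standard-library `re` module
# does not support the Unicode property classes \p{P} (punctuation) and \p{S} (symbols).)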
135 | punctuation_pattern = r'^[\p{P}\p{S}]*$' 136 | return bool(regex.fullmatch(punctuation_pattern, text)) 137 | -------------------------------------------------------------------------------- /cosyvoice/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): 6 | loss = 0 7 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 8 | m_DG = torch.median((dr - dg)) 9 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) 10 | loss += tau - F.relu(tau - L_rel) 11 | return loss 12 | 13 | 14 | def mel_loss(real_speech, generated_speech, mel_transforms): 15 | loss = 0 16 | for transform in mel_transforms: 17 | mel_r = transform(real_speech) 18 | mel_g = transform(generated_speech) 19 | loss += F.l1_loss(mel_g, mel_r) 20 | return loss 21 | -------------------------------------------------------------------------------- /cosyvoice/utils/losses_dpo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Tuple 4 | 5 | 6 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): 7 | loss = 0 8 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 9 | m_DG = torch.median((dr - dg)) 10 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) 11 | loss += tau - F.relu(tau - L_rel) 12 | return loss 13 | 14 | 15 | def mel_loss(real_speech, generated_speech, mel_transforms): 16 | loss = 0 17 | for transform in mel_transforms: 18 | mel_r = transform(real_speech) 19 | mel_g = transform(generated_speech) 20 | loss += F.l1_loss(mel_g, mel_r) 21 | return loss 22 | 23 | 24 | class DPOLoss(torch.nn.Module): 25 | """ 26 | DPO Loss 27 | """ 28 | 29 | def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None: 30 | super().__init__() 31 | self.beta = beta 32 | self.label_smoothing = label_smoothing 33 | self.ipo = ipo 34 | 35 | def forward( 36 | self, 37 | policy_chosen_logps: torch.Tensor, 38 | policy_rejected_logps: torch.Tensor, 39 | reference_chosen_logps: torch.Tensor, 40 | reference_rejected_logps: torch.Tensor, 41 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 42 | pi_logratios = policy_chosen_logps - policy_rejected_logps 43 | ref_logratios = reference_chosen_logps - reference_rejected_logps 44 | logits = pi_logratios - ref_logratios 45 | if self.ipo: 46 | losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf 47 | else: 48 | # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 
7 of https://arxiv.org/pdf/2305.18290.pdf) 49 | losses = ( 50 | -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) 51 | - F.logsigmoid(-self.beta * logits) * self.label_smoothing 52 | ) 53 | loss = losses.mean() 54 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() 55 | rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() 56 | 57 | return loss, chosen_rewards, rejected_rewards 58 | -------------------------------------------------------------------------------- /cosyvoice/vllm/cosyvoice2.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # Adapted from 4 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py 5 | # Copyright 2024 The Qwen team. 6 | # Copyright 2023 The vLLM team. 7 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 8 | # 9 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 10 | # and OPT implementations in this library. It has been modified from its 11 | # original forms to accommodate minor architectural differences compared 12 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 13 | # 14 | # Licensed under the Apache License, Version 2.0 (the "License"); 15 | # you may not use this file except in compliance with the License. 16 | # You may obtain a copy of the License at 17 | # 18 | # http://www.apache.org/licenses/LICENSE-2.0 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 
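For reference, a minimal, self-contained sketch of a `DPOLoss` forward call as defined in `losses_dpo.py` above; the batch size, `beta`, and the random log-probabilities are placeholders rather than settings taken from the training recipes:

```python
import torch
from cosyvoice.utils.losses_dpo import DPOLoss

dpo_loss = DPOLoss(beta=0.1, label_smoothing=0.0, ipo=False)  # beta chosen only for illustration

batch = 4  # one sequence-level log-probability per sample, for the policy and the frozen reference model
policy_chosen, policy_rejected = torch.randn(batch), torch.randn(batch)
reference_chosen, reference_rejected = torch.randn(batch), torch.randn(batch)

loss, chosen_rewards, rejected_rewards = dpo_loss(
    policy_chosen, policy_rejected, reference_chosen, reference_rejected)
print(loss.item(), chosen_rewards.mean().item(), rejected_rewards.mean().item())
```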
25 | """Inference-only Qwen2 model compatible with HuggingFace weights.""" 26 | from vllm.model_executor.models.qwen2 import * 27 | 28 | 29 | class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): 30 | packed_modules_mapping = { 31 | "qkv_proj": [ 32 | "q_proj", 33 | "k_proj", 34 | "v_proj", 35 | ], 36 | "gate_up_proj": [ 37 | "gate_proj", 38 | "up_proj", 39 | ], 40 | } 41 | 42 | def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 43 | super().__init__() 44 | config = vllm_config.model_config.hf_config 45 | quant_config = vllm_config.quant_config 46 | lora_config = vllm_config.lora_config 47 | 48 | self.config = config 49 | self.lora_config = lora_config 50 | 51 | self.quant_config = quant_config 52 | self.model = Qwen2Model(vllm_config=vllm_config, 53 | prefix=maybe_prefix(prefix, "model")) 54 | 55 | if get_pp_group().is_last_rank: 56 | if config.tie_word_embeddings: 57 | self.lm_head = self.model.embed_tokens 58 | else: 59 | self.lm_head = ParallelLMHead(config.vocab_size, 60 | config.hidden_size, 61 | True, 62 | quant_config=quant_config, 63 | prefix=maybe_prefix( 64 | prefix, "lm_head")) 65 | else: 66 | self.lm_head = PPMissingLayer() 67 | 68 | self.logits_processor = LogitsProcessor(config.vocab_size) 69 | 70 | self.make_empty_intermediate_tensors = ( 71 | self.model.make_empty_intermediate_tensors) 72 | 73 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: 74 | return self.model.get_input_embeddings(input_ids) 75 | 76 | def forward( 77 | self, 78 | input_ids: torch.Tensor, 79 | positions: torch.Tensor, 80 | intermediate_tensors: Optional[IntermediateTensors] = None, 81 | inputs_embeds: Optional[torch.Tensor] = None, 82 | ) -> Union[torch.Tensor, IntermediateTensors]: 83 | hidden_states = self.model(input_ids, positions, intermediate_tensors, 84 | inputs_embeds) 85 | return hidden_states 86 | 87 | def compute_logits( 88 | self, 89 | hidden_states: torch.Tensor, 90 | sampling_metadata: SamplingMetadata, 91 | ) -> Optional[torch.Tensor]: 92 | logits = self.logits_processor(self.lm_head, hidden_states, 93 | sampling_metadata, self.lm_head.bias) 94 | return logits 95 | 96 | def load_weights(self, weights: Iterable[tuple[str, 97 | torch.Tensor]]) -> set[str]: 98 | loader = AutoWeightsLoader( 99 | self, 100 | skip_prefixes=(["lm_head."] 101 | if self.config.tie_word_embeddings else None), 102 | ) 103 | return loader.load_weights(weights) 104 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 2 | 3 | ARG VENV_NAME="cosyvoice" 4 | ENV VENV=$VENV_NAME 5 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 6 | 7 | ENV DEBIAN_FRONTEN=noninteractive 8 | ENV PYTHONUNBUFFERED=1 9 | SHELL ["/bin/bash", "--login", "-c"] 10 | 11 | RUN apt-get update -y --fix-missing 12 | RUN apt-get install -y git build-essential curl wget ffmpeg unzip git git-lfs sox libsox-dev && \ 13 | apt-get clean && \ 14 | git lfs install 15 | 16 | # ================================================================== 17 | # conda install and conda forge channel as default 18 | # ------------------------------------------------------------------ 19 | # Install miniforge 20 | RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \ 21 | /bin/bash ~/miniforge.sh -b -p /opt/conda && \ 22 | rm ~/miniforge.sh && \ 23 | ln -s 
/opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 24 | echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \ 25 | echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 26 | echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \ 27 | echo "conda activate ${VENV}" >> $HOME/.bashrc 28 | 29 | ENV PATH /opt/conda/bin:$PATH 30 | 31 | RUN conda config --add channels conda-forge && \ 32 | conda config --set channel_priority strict 33 | # ------------------------------------------------------------------ 34 | # ~conda 35 | # ================================================================== 36 | 37 | RUN conda create -y -n ${VENV} python=3.10 38 | ENV CONDA_DEFAULT_ENV=${VENV} 39 | ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH 40 | 41 | WORKDIR /workspace 42 | 43 | ENV PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_party/Matcha-TTS" 44 | 45 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git 46 | 47 | RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5 48 | RUN conda activate ${VENV} && cd CosyVoice && \ 49 | pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com 50 | 51 | WORKDIR /workspace/CosyVoice 52 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1986] 3 | __set_seed2: !apply:numpy.random.seed [1986] 4 | __set_seed3: !apply:torch.manual_seed [1986] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1986] 6 | 7 | # fixed params 8 | sample_rate: 24000 # 16000 for llm, 24000 for cfm 9 | llm_input_size: 896 10 | llm_output_size: 896 11 | spk_embed_dim: 192 12 | qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN' 13 | 14 | # model params 15 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 16 | # for system/third_party class/function, we do not require this. 
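The recipe below is a HyperPyYAML config: `!new:` tags build class instances, `!name:` tags bind functions with their arguments, and `!ref` re-uses values defined earlier in the file. A minimal, hypothetical loading sketch (the override path is a placeholder) might look like:

```python
from hyperpyyaml import load_hyperpyyaml

# Resolve the !new:/!name:/!ref tags and instantiate the modules defined in the recipe.
with open('conf/cosyvoice_dpo.yaml', 'r') as f:
    configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': '/path/to/CosyVoice-BlankEN'})

llm = configs['llm']    # instantiated from the !new: definition below
flow = configs['flow']
```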
17 | llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM 18 | llm_input_size: !ref 19 | llm_output_size: !ref 20 | speech_token_size: 6561 21 | length_normalized_loss: True 22 | lsm_weight: 0 23 | dpo: True 24 | llm: !new:cosyvoice.llm.llm.Qwen2Encoder 25 | pretrain_path: !ref 26 | sampling: !name:cosyvoice.utils.common.ras_sampling 27 | top_p: 0.8 28 | top_k: 25 29 | win_size: 10 30 | tau_r: 0.1 31 | flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec 32 | input_size: 512 33 | output_size: 80 34 | spk_embed_dim: !ref 35 | output_type: 'mel' 36 | vocab_size: 6561 37 | input_frame_rate: 25 38 | only_mask_loss: True 39 | token_mel_ratio: 2 40 | pre_lookahead_len: 3 41 | encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder 42 | output_size: 512 43 | attention_heads: 8 44 | linear_units: 2048 45 | num_blocks: 6 46 | dropout_rate: 0.1 47 | positional_dropout_rate: 0.1 48 | attention_dropout_rate: 0.1 49 | normalize_before: True 50 | input_layer: 'linear' 51 | pos_enc_layer_type: 'rel_pos_espnet' 52 | selfattention_layer_type: 'rel_selfattn' 53 | input_size: 512 54 | use_cnn_module: False 55 | macaron_style: False 56 | decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM 57 | in_channels: 240 58 | n_spks: 1 59 | spk_emb_dim: 80 60 | cfm_params: !new:omegaconf.DictConfig 61 | content: 62 | sigma_min: 1e-06 63 | solver: 'euler' 64 | t_scheduler: 'cosine' 65 | training_cfg_rate: 0.2 66 | inference_cfg_rate: 0.7 67 | reg_loss_type: 'l1' 68 | estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder 69 | in_channels: 320 70 | out_channels: 80 71 | causal: True 72 | channels: [256] 73 | dropout: 0.0 74 | attention_head_dim: 64 75 | n_blocks: 4 76 | num_mid_blocks: 12 77 | num_heads: 8 78 | act_fn: 'gelu' 79 | 80 | hift: !new:cosyvoice.hifigan.generator.HiFTGenerator 81 | in_channels: 80 82 | base_channels: 512 83 | nb_harmonics: 8 84 | sampling_rate: !ref 85 | nsf_alpha: 0.1 86 | nsf_sigma: 0.003 87 | nsf_voiced_threshold: 10 88 | upsample_rates: [8, 5, 3] 89 | upsample_kernel_sizes: [16, 11, 7] 90 | istft_params: 91 | n_fft: 16 92 | hop_len: 4 93 | resblock_kernel_sizes: [3, 7, 11] 94 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 95 | source_resblock_kernel_sizes: [7, 7, 11] 96 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 97 | lrelu_slope: 0.1 98 | audio_limit: 0.99 99 | f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor 100 | num_class: 1 101 | in_channels: 80 102 | cond_channels: 512 103 | 104 | # gan related module 105 | mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram 106 | n_fft: 1024 107 | num_mels: 80 108 | sampling_rate: !ref 109 | hop_size: 256 110 | win_size: 1024 111 | fmin: 0 112 | fmax: null 113 | center: False 114 | hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan 115 | generator: !ref 116 | discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator 117 | mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator 118 | mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator 119 | mel_spec_transform: [ 120 | !ref 121 | ] 122 | 123 | # processor functions 124 | parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener 125 | get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe 126 | multilingual: True 127 | num_languages: 100 128 | language: 'en' 129 | task: 'transcribe' 130 | allowed_special: 'all' 131 | tokenize: 
!name:cosyvoice.dataset.processor.tokenize 132 | get_tokenizer: !ref 133 | allowed_special: !ref 134 | filter: !name:cosyvoice.dataset.processor.filter 135 | max_length: 40960 136 | min_length: 0 137 | token_max_length: 200 138 | token_min_length: 1 139 | resample: !name:cosyvoice.dataset.processor.resample 140 | resample_rate: !ref 141 | truncate: !name:cosyvoice.dataset.processor.truncate 142 | truncate_length: 24576 # must be a multiplier of hop_size 143 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 144 | n_fft: 1024 145 | num_mels: 80 146 | sampling_rate: !ref 147 | hop_size: 256 148 | win_size: 1024 149 | fmin: 0 150 | fmax: 8000 151 | center: False 152 | compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank 153 | feat_extractor: !ref 154 | compute_f0: !name:cosyvoice.dataset.processor.compute_f0 155 | sample_rate: !ref 156 | hop_size: 256 157 | parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding 158 | normalize: True 159 | shuffle: !name:cosyvoice.dataset.processor.shuffle 160 | shuffle_size: 1000 161 | sort: !name:cosyvoice.dataset.processor.sort 162 | sort_size: 500 # sort_size should be less than shuffle_size 163 | batch: !name:cosyvoice.dataset.processor.batch 164 | batch_type: 'dynamic' 165 | max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g 166 | padding: !name:cosyvoice.dataset.processor.padding 167 | use_spk_embedding: True # change to True during sft 168 | dpo: True 169 | 170 | # dataset processor pipeline 171 | data_pipeline: [ 172 | !ref , 173 | !ref , 174 | !ref , 175 | !ref , 176 | !ref , 177 | !ref , 178 | !ref , 179 | !ref , 180 | !ref , 181 | !ref , 182 | ] 183 | data_pipeline_gan: [ 184 | !ref , 185 | !ref , 186 | !ref , 187 | !ref , 188 | !ref , 189 | !ref , 190 | !ref , 191 | !ref , 192 | !ref , 193 | !ref , 194 | !ref , 195 | !ref , 196 | ] 197 | 198 | # llm flow train conf 199 | train_conf: 200 | optim: adam 201 | optim_conf: 202 | lr: 0.00001 # change to 1e-5 during sft 203 | scheduler: warmuplr # change to constantlr during sft 204 | scheduler_conf: 205 | warmup_steps: 25000 206 | max_epoch: 200 207 | grad_clip: 5 208 | accum_grad: 2 209 | log_interval: 100 210 | save_per_step: -1 211 | 212 | # gan train conf 213 | train_conf_gan: 214 | optim: adam 215 | optim_conf: 216 | lr: 0.0002 # use small lr for gan training 217 | scheduler: constantlr 218 | optim_d: adam 219 | optim_conf_d: 220 | lr: 0.0002 # use small lr for gan training 221 | scheduler_d: constantlr 222 | max_epoch: 200 223 | grad_clip: 5 224 | accum_grad: 1 # in gan training, accum_grad must be 1 225 | log_interval: 100 226 | save_per_step: -1 -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/conf/ds_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 5, 6 | "fp16": { 7 | "enabled": false, 8 | "auto_cast": false, 9 | "loss_scale": 0, 10 | "initial_scale_power": 16, 11 | "loss_scale_window": 256, 12 | "hysteresis": 2, 13 | "consecutive_hysteresis": false, 14 | "min_loss_scale": 1 15 | }, 16 | "bf16": { 17 | "enabled": false 18 | }, 19 | "zero_force_ds_cpu_optimizer": false, 20 | "zero_optimization": { 21 | "stage": 2, 22 | "offload_optimizer": { 23 | "device": "none", 24 | "pin_memory": true 25 | }, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "overlap_comm": false, 29 | 
"reduce_scatter": true, 30 | "reduce_bucket_size": 5e8, 31 | "contiguous_gradients" : true 32 | }, 33 | "optimizer": { 34 | "type": "AdamW", 35 | "params": { 36 | "lr": 0.001, 37 | "weight_decay": 0.0001, 38 | "torch_adam": true, 39 | "adam_w_mode": true 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/cosyvoice: -------------------------------------------------------------------------------- 1 | ../../../cosyvoice -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/local/download_and_untar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey) 4 | # Apache 2.0 5 | 6 | remove_archive=false 7 | 8 | if [ "$1" == --remove-archive ]; then 9 | remove_archive=true 10 | shift 11 | fi 12 | 13 | if [ $# -ne 3 ]; then 14 | echo "Usage: $0 [--remove-archive] " 15 | echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean" 16 | echo "With --remove-archive it will remove the archive after successfully un-tarring it." 17 | echo " can be one of: dev-clean, test-clean, dev-other, test-other," 18 | echo " train-clean-100, train-clean-360, train-other-500." 19 | exit 1 20 | fi 21 | 22 | data=$1 23 | url=$2 24 | part=$3 25 | 26 | if [ ! -d "$data" ]; then 27 | echo "$0: no such directory $data" 28 | exit 1 29 | fi 30 | 31 | part_ok=false 32 | list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500" 33 | for x in $list; do 34 | if [ "$part" == $x ]; then part_ok=true; fi 35 | done 36 | if ! $part_ok; then 37 | echo "$0: expected to be one of $list, but got '$part'" 38 | exit 1 39 | fi 40 | 41 | if [ -z "$url" ]; then 42 | echo "$0: empty URL base." 43 | exit 1 44 | fi 45 | 46 | if [ -f $data/LibriTTS/$part/.complete ]; then 47 | echo "$0: data part $part was already successfully extracted, nothing to do." 48 | exit 0 49 | fi 50 | 51 | 52 | # sizes of the archive files in bytes. This is some older versions. 53 | sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128" 54 | # sizes_new is the archive file sizes of the final release. Some of these sizes are of 55 | # things we probably won't download. 56 | sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606" 57 | 58 | if [ -f $data/$part.tar.gz ]; then 59 | size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}') 60 | size_ok=false 61 | for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done 62 | if ! $size_ok; then 63 | echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size" 64 | echo "does not equal the size of one of the archives." 65 | rm $data/$part.tar.gz 66 | else 67 | echo "$data/$part.tar.gz exists and appears to be complete." 68 | fi 69 | fi 70 | 71 | if [ ! -f $data/$part.tar.gz ]; then 72 | if ! which wget >/dev/null; then 73 | echo "$0: wget is not installed." 74 | exit 1 75 | fi 76 | full_url=$url/$part.tar.gz 77 | echo "$0: downloading data from $full_url. This may take some time, please be patient." 78 | 79 | if ! wget -P $data --no-check-certificate $full_url; then 80 | echo "$0: error executing wget $full_url" 81 | exit 1 82 | fi 83 | fi 84 | 85 | if ! 
tar -C $data -xvzf $data/$part.tar.gz; then 86 | echo "$0: error un-tarring archive $data/$part.tar.gz" 87 | exit 1 88 | fi 89 | 90 | touch $data/LibriTTS/$part/.complete 91 | 92 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz" 93 | 94 | if $remove_archive; then 95 | echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied." 96 | rm $data/$part.tar.gz 97 | fi 98 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/local/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import glob 4 | import os 5 | from tqdm import tqdm 6 | 7 | 8 | logger = logging.getLogger() 9 | 10 | 11 | def main(): 12 | wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir))) 13 | 14 | utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {} 15 | for wav in tqdm(wavs): 16 | txt = wav.replace('.wav', '.normalized.txt') 17 | if not os.path.exists(txt): 18 | logger.warning('{} does not exist'.format(txt)) 19 | continue 20 | with open(txt) as f: 21 | content = f.readline().replace('\n', '')  # single-line normalized transcript, trailing newline stripped 22 | utt = os.path.basename(wav).replace('.wav', '') 23 | spk = utt.split('_')[0] 24 | utt2wav[utt] = wav 25 | utt2text[utt] = content 26 | utt2spk[utt] = spk 27 | if spk not in spk2utt: 28 | spk2utt[spk] = [] 29 | spk2utt[spk].append(utt) 30 | 31 | with open('{}/wav.scp'.format(args.des_dir), 'w') as f: 32 | for k, v in utt2wav.items(): 33 | f.write('{} {}\n'.format(k, v)) 34 | with open('{}/text'.format(args.des_dir), 'w') as f: 35 | for k, v in utt2text.items(): 36 | f.write('{} {}\n'.format(k, v)) 37 | with open('{}/utt2spk'.format(args.des_dir), 'w') as f: 38 | for k, v in utt2spk.items(): 39 | f.write('{} {}\n'.format(k, v)) 40 | with open('{}/spk2utt'.format(args.des_dir), 'w') as f: 41 | for k, v in spk2utt.items(): 42 | f.write('{} {}\n'.format(k, ' '.join(v))) 43 | return 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = argparse.ArgumentParser() 48 | parser.add_argument('--src_dir', 49 | type=str) 50 | parser.add_argument('--des_dir', 51 | type=str) 52 | args = parser.parse_args() 53 | main() 54 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/path.sh: -------------------------------------------------------------------------------- 1 | # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C 2 | export PYTHONIOENCODING=UTF-8 3 | export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH 4 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | . 
./path.sh || exit 1; 4 | 5 | stage=-1 6 | stop_stage=3 7 | 8 | data_url=www.openslr.org/resources/60 9 | data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts 10 | pretrained_model_dir=../../../pretrained_models/CosyVoice-300M 11 | 12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then 13 | echo "Data Download" 14 | for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do 15 | local/download_and_untar.sh ${data_dir} ${data_url} ${part} 16 | done 17 | fi 18 | 19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 20 | echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt" 21 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 22 | mkdir -p data/$x 23 | python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x 24 | done 25 | fi 26 | 27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 28 | echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir" 29 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 30 | tools/extract_embedding.py --dir data/$x \ 31 | --onnx_path $pretrained_model_dir/campplus.onnx 32 | done 33 | fi 34 | 35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 36 | echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir" 37 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 38 | tools/extract_speech_token.py --dir data/$x \ 39 | --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx 40 | done 41 | fi 42 | 43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 44 | echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt" 45 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 46 | mkdir -p data/$x/parquet 47 | tools/make_parquet_list.py --num_utts_per_parquet 1000 \ 48 | --num_processes 10 \ 49 | --src_dir data/$x \ 50 | --des_dir data/$x/parquet 51 | done 52 | fi 53 | 54 | # inference 55 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 56 | echo "Run inference. Please make sure utt in tts_text is in prompt_data" 57 | for mode in sft zero_shot; do 58 | python cosyvoice/bin/inference.py --mode $mode \ 59 | --gpu 0 \ 60 | --config conf/cosyvoice.yaml \ 61 | --prompt_data data/test-clean/parquet/data.list \ 62 | --prompt_utt2data data/test-clean/parquet/utt2data.list \ 63 | --tts_text `pwd`/tts_text.json \ 64 | --llm_model $pretrained_model_dir/llm.pt \ 65 | --flow_model $pretrained_model_dir/flow.pt \ 66 | --hifigan_model $pretrained_model_dir/hift.pt \ 67 | --result_dir `pwd`/exp/cosyvoice/test-clean/$mode 68 | done 69 | fi 70 | 71 | # train llm 72 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 73 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') 74 | job_id=1986 75 | dist_backend="nccl" 76 | num_workers=2 77 | prefetch=100 78 | train_engine=torch_ddp 79 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then 80 | echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml" 81 | if [ $train_engine == 'deepspeed' ]; then 82 | echo "Notice deepspeed has its own optimizer config. 
Modify conf/ds_stage2.json if necessary" 83 | fi 84 | cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list 85 | cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list 86 | for model in llm flow hifigan; do 87 | torchrun --nnodes=1 --nproc_per_node=$num_gpus \ 88 | --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \ 89 | cosyvoice/bin/train.py \ 90 | --train_engine $train_engine \ 91 | --config conf/cosyvoice.yaml \ 92 | --train_data data/train.data.list \ 93 | --cv_data data/dev.data.list \ 94 | --model $model \ 95 | --checkpoint $pretrained_model_dir/$model.pt \ 96 | --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \ 97 | --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \ 98 | --ddp.dist_backend $dist_backend \ 99 | --num_workers ${num_workers} \ 100 | --prefetch ${prefetch} \ 101 | --pin_memory \ 102 | --use_amp \ 103 | --deepspeed_config ./conf/ds_stage2.json \ 104 | --deepspeed.save_states model+optimizer 105 | done 106 | fi 107 | 108 | # average model 109 | average_num=5 110 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then 111 | for model in llm flow hifigan; do 112 | decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt 113 | echo "do model average and final checkpoint is $decode_checkpoint" 114 | python cosyvoice/bin/average_model.py \ 115 | --dst_model $decode_checkpoint \ 116 | --src_path `pwd`/exp/cosyvoice/$model/$train_engine \ 117 | --num ${average_num} \ 118 | --val_best 119 | done 120 | fi 121 | 122 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then 123 | echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir" 124 | python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir 125 | python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir 126 | fi -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/tools: -------------------------------------------------------------------------------- 1 | ../../../tools -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/tts_text.json: -------------------------------------------------------------------------------- 1 | { 2 | "1089_134686_000002_000000": [ 3 | "hello, my name is Jack. What is your name?" 4 | ] 5 | } -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/conf/cosyvoice2.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1986] 3 | __set_seed2: !apply:numpy.random.seed [1986] 4 | __set_seed3: !apply:torch.manual_seed [1986] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1986] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | llm_input_size: 896 10 | llm_output_size: 896 11 | spk_embed_dim: 192 12 | qwen_pretrain_path: '' 13 | token_frame_rate: 25 14 | token_mel_ratio: 2 15 | 16 | # stream related params 17 | chunk_size: 25 # streaming inference chunk size, in token 18 | num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks 19 | 20 | # model params 21 | # for all class/function included in this repo, we use ! or ! for intialization, so that user may find all corresponding class/function according to one single yaml. 
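# (in HyperPyYAML terms: !new: instantiates a class, !name: binds a callable with its arguments, and !ref re-uses a value defined earlier in this file)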
22 | # for system/third_party class/function, we do not require this. 23 | llm: !new:cosyvoice.llm.llm.Qwen2LM 24 | llm_input_size: !ref 25 | llm_output_size: !ref 26 | speech_token_size: 6561 27 | length_normalized_loss: True 28 | lsm_weight: 0 29 | mix_ratio: [5, 15] 30 | llm: !new:cosyvoice.llm.llm.Qwen2Encoder 31 | pretrain_path: !ref 32 | sampling: !name:cosyvoice.utils.common.ras_sampling 33 | top_p: 0.8 34 | top_k: 25 35 | win_size: 10 36 | tau_r: 0.1 37 | 38 | flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec 39 | input_size: 512 40 | output_size: 80 41 | spk_embed_dim: !ref 42 | output_type: 'mel' 43 | vocab_size: 6561 44 | input_frame_rate: !ref 45 | only_mask_loss: True 46 | token_mel_ratio: !ref 47 | pre_lookahead_len: 3 48 | encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder 49 | output_size: 512 50 | attention_heads: 8 51 | linear_units: 2048 52 | num_blocks: 6 53 | dropout_rate: 0.1 54 | positional_dropout_rate: 0.1 55 | attention_dropout_rate: 0.1 56 | normalize_before: True 57 | input_layer: 'linear' 58 | pos_enc_layer_type: 'rel_pos_espnet' 59 | selfattention_layer_type: 'rel_selfattn' 60 | input_size: 512 61 | use_cnn_module: False 62 | macaron_style: False 63 | static_chunk_size: !ref 64 | decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM 65 | in_channels: 240 66 | n_spks: 1 67 | spk_emb_dim: 80 68 | cfm_params: !new:omegaconf.DictConfig 69 | content: 70 | sigma_min: 1e-06 71 | solver: 'euler' 72 | t_scheduler: 'cosine' 73 | training_cfg_rate: 0.2 74 | inference_cfg_rate: 0.7 75 | reg_loss_type: 'l1' 76 | estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder 77 | in_channels: 320 78 | out_channels: 80 79 | channels: [256] 80 | dropout: 0.0 81 | attention_head_dim: 64 82 | n_blocks: 4 83 | num_mid_blocks: 12 84 | num_heads: 8 85 | act_fn: 'gelu' 86 | static_chunk_size: !ref * 87 | num_decoding_left_chunks: !ref 88 | 89 | hift: !new:cosyvoice.hifigan.generator.HiFTGenerator 90 | in_channels: 80 91 | base_channels: 512 92 | nb_harmonics: 8 93 | sampling_rate: !ref 94 | nsf_alpha: 0.1 95 | nsf_sigma: 0.003 96 | nsf_voiced_threshold: 10 97 | upsample_rates: [8, 5, 3] 98 | upsample_kernel_sizes: [16, 11, 7] 99 | istft_params: 100 | n_fft: 16 101 | hop_len: 4 102 | resblock_kernel_sizes: [3, 7, 11] 103 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 104 | source_resblock_kernel_sizes: [7, 7, 11] 105 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 106 | lrelu_slope: 0.1 107 | audio_limit: 0.99 108 | f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor 109 | num_class: 1 110 | in_channels: 80 111 | cond_channels: 512 112 | 113 | # gan related module 114 | mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram 115 | n_fft: 1920 116 | num_mels: 80 117 | sampling_rate: !ref 118 | hop_size: 480 119 | win_size: 1920 120 | fmin: 0 121 | fmax: null 122 | center: False 123 | hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan 124 | generator: !ref 125 | discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator 126 | mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator 127 | mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator 128 | mel_spec_transform: [ 129 | !ref 130 | ] 131 | 132 | # processor functions 133 | parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener 134 | get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer 135 | token_path: !ref 136 | skip_special_tokens: True 137 | allowed_special: 'all' 138 | 
tokenize: !name:cosyvoice.dataset.processor.tokenize 139 | get_tokenizer: !ref 140 | allowed_special: !ref 141 | filter: !name:cosyvoice.dataset.processor.filter 142 | max_length: 40960 143 | min_length: 100 144 | token_max_length: 200 145 | token_min_length: 1 146 | resample: !name:cosyvoice.dataset.processor.resample 147 | resample_rate: !ref 148 | truncate: !name:cosyvoice.dataset.processor.truncate 149 | truncate_length: 24480 # must be a multiplier of hop_size 150 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 151 | n_fft: 1920 152 | num_mels: 80 153 | sampling_rate: !ref 154 | hop_size: 480 155 | win_size: 1920 156 | fmin: 0 157 | fmax: 8000 158 | center: False 159 | compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank 160 | feat_extractor: !ref 161 | token_mel_ratio: 2 162 | compute_f0: !name:cosyvoice.dataset.processor.compute_f0 163 | sample_rate: !ref 164 | hop_size: 480 165 | parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding 166 | normalize: True 167 | shuffle: !name:cosyvoice.dataset.processor.shuffle 168 | shuffle_size: 1000 169 | sort: !name:cosyvoice.dataset.processor.sort 170 | sort_size: 500 # sort_size should be less than shuffle_size 171 | batch: !name:cosyvoice.dataset.processor.batch 172 | batch_type: 'dynamic' 173 | max_frames_in_batch: 2000 174 | padding: !name:cosyvoice.dataset.processor.padding 175 | use_spk_embedding: False # change to True during sft 176 | 177 | 178 | # dataset processor pipeline 179 | data_pipeline: [ 180 | !ref , 181 | !ref , 182 | !ref , 183 | !ref , 184 | !ref , 185 | !ref , 186 | !ref , 187 | !ref , 188 | !ref , 189 | !ref , 190 | ] 191 | data_pipeline_gan: [ 192 | !ref , 193 | !ref , 194 | !ref , 195 | !ref , 196 | !ref , 197 | !ref , 198 | !ref , 199 | !ref , 200 | !ref , 201 | !ref , 202 | !ref , 203 | !ref , 204 | ] 205 | 206 | # llm flow train conf 207 | train_conf: 208 | optim: adam 209 | optim_conf: 210 | lr: 1e-5 # change to 1e-5 during sft 211 | scheduler: constantlr # change to constantlr during sft 212 | scheduler_conf: 213 | warmup_steps: 2500 214 | max_epoch: 200 215 | grad_clip: 5 216 | accum_grad: 2 217 | log_interval: 100 218 | save_per_step: -1 219 | 220 | # gan train conf 221 | train_conf_gan: 222 | optim: adam 223 | optim_conf: 224 | lr: 0.0002 # use small lr for gan training 225 | scheduler: constantlr 226 | optim_d: adam 227 | optim_conf_d: 228 | lr: 0.0002 # use small lr for gan training 229 | scheduler_d: constantlr 230 | max_epoch: 200 231 | grad_clip: 5 232 | accum_grad: 1 # in gan training, accum_grad must be 1 233 | log_interval: 100 234 | save_per_step: -1 -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/conf/ds_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 5, 6 | "fp16": { 7 | "enabled": false, 8 | "auto_cast": false, 9 | "loss_scale": 0, 10 | "initial_scale_power": 16, 11 | "loss_scale_window": 256, 12 | "hysteresis": 2, 13 | "consecutive_hysteresis": false, 14 | "min_loss_scale": 1 15 | }, 16 | "bf16": { 17 | "enabled": false 18 | }, 19 | "zero_force_ds_cpu_optimizer": false, 20 | "zero_optimization": { 21 | "stage": 2, 22 | "offload_optimizer": { 23 | "device": "none", 24 | "pin_memory": true 25 | }, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "overlap_comm": false, 29 | "reduce_scatter": true, 30 
| "reduce_bucket_size": 5e8, 31 | "contiguous_gradients" : true 32 | }, 33 | "optimizer": { 34 | "type": "AdamW", 35 | "params": { 36 | "lr": 0.001, 37 | "weight_decay": 0.0001, 38 | "torch_adam": true, 39 | "adam_w_mode": true 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/cosyvoice: -------------------------------------------------------------------------------- 1 | ../../../cosyvoice -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/local: -------------------------------------------------------------------------------- 1 | ../cosyvoice/local -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/path.sh: -------------------------------------------------------------------------------- 1 | ../cosyvoice/path.sh -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | . ./path.sh || exit 1; 4 | 5 | stage=-1 6 | stop_stage=3 7 | 8 | data_url=www.openslr.org/resources/60 9 | data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts 10 | pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B 11 | 12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then 13 | echo "Data Download" 14 | for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do 15 | local/download_and_untar.sh ${data_dir} ${data_url} ${part} 16 | done 17 | fi 18 | 19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 20 | echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt" 21 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 22 | mkdir -p data/$x 23 | python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x 24 | done 25 | fi 26 | 27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 28 | echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir" 29 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 30 | tools/extract_embedding.py --dir data/$x \ 31 | --onnx_path $pretrained_model_dir/campplus.onnx 32 | done 33 | fi 34 | 35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 36 | echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir" 37 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 38 | tools/extract_speech_token.py --dir data/$x \ 39 | --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx 40 | done 41 | fi 42 | 43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 44 | echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt" 45 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do 46 | mkdir -p data/$x/parquet 47 | tools/make_parquet_list.py --num_utts_per_parquet 1000 \ 48 | --num_processes 10 \ 49 | --src_dir data/$x \ 50 | --des_dir data/$x/parquet 51 | done 52 | fi 53 | 54 | # inference 55 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 56 | echo "Run inference. 
Please make sure utt in tts_text is in prompt_data" 57 | # TODO consider remove bin/inference.py, or use similar initilization method as in readme 58 | for mode in sft zero_shot; do 59 | python cosyvoice/bin/inference.py --mode $mode \ 60 | --gpu 0 \ 61 | --config conf/cosyvoice2.yaml \ 62 | --prompt_data data/test-clean/parquet/data.list \ 63 | --prompt_utt2data data/test-clean/parquet/utt2data.list \ 64 | --tts_text `pwd`/tts_text.json \ 65 | --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \ 66 | --llm_model $pretrained_model_dir/llm.pt \ 67 | --flow_model $pretrained_model_dir/flow.pt \ 68 | --hifigan_model $pretrained_model_dir/hift.pt \ 69 | --result_dir `pwd`/exp/cosyvoice/test-clean/$mode 70 | done 71 | fi 72 | 73 | # train llm 74 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 75 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') 76 | job_id=1986 77 | dist_backend="nccl" 78 | num_workers=2 79 | prefetch=100 80 | train_engine=torch_ddp 81 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then 82 | echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml" 83 | if [ $train_engine == 'deepspeed' ]; then 84 | echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary" 85 | fi 86 | cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list 87 | cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list 88 | # NOTE will update llm/hift training later 89 | for model in llm flow; do 90 | torchrun --nnodes=1 --nproc_per_node=$num_gpus \ 91 | --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \ 92 | cosyvoice/bin/train.py \ 93 | --train_engine $train_engine \ 94 | --config conf/cosyvoice2.yaml \ 95 | --train_data data/train.data.list \ 96 | --cv_data data/dev.data.list \ 97 | --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \ 98 | --model $model \ 99 | --checkpoint $pretrained_model_dir/$model.pt \ 100 | --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \ 101 | --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \ 102 | --ddp.dist_backend $dist_backend \ 103 | --num_workers ${num_workers} \ 104 | --prefetch ${prefetch} \ 105 | --pin_memory \ 106 | --use_amp \ 107 | --deepspeed_config ./conf/ds_stage2.json \ 108 | --deepspeed.save_states model+optimizer 109 | done 110 | fi 111 | 112 | # average model 113 | average_num=5 114 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then 115 | for model in llm flow hifigan; do 116 | decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt 117 | echo "do model average and final checkpoint is $decode_checkpoint" 118 | python cosyvoice/bin/average_model.py \ 119 | --dst_model $decode_checkpoint \ 120 | --src_path `pwd`/exp/cosyvoice/$model/$train_engine \ 121 | --num ${average_num} \ 122 | --val_best 123 | done 124 | fi 125 | 126 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then 127 | echo "Export your model for inference speedup. 
Remember copy your llm or flow model to model_dir" 128 | python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir 129 | python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir 130 | fi -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/tools: -------------------------------------------------------------------------------- 1 | ../../../tools -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/tts_text.json: -------------------------------------------------------------------------------- 1 | ../cosyvoice/tts_text.json -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/conf: -------------------------------------------------------------------------------- 1 | ../../libritts/cosyvoice/conf -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/cosyvoice: -------------------------------------------------------------------------------- 1 | ../../../cosyvoice -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/local/download_and_untar.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey) 4 | # Apache 2.0 5 | 6 | remove_archive=false 7 | 8 | if [ "$1" == --remove-archive ]; then 9 | remove_archive=true 10 | shift 11 | fi 12 | 13 | if [ $# -ne 3 ]; then 14 | echo "Usage: $0 [--remove-archive] " 15 | echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean" 16 | echo "With --remove-archive it will remove the archive after successfully un-tarring it." 17 | echo " can be one of: dev-clean, test-clean, dev-other, test-other," 18 | echo " train-clean-100, train-clean-360, train-other-500." 19 | exit 1 20 | fi 21 | 22 | data=$1 23 | url=$2 24 | part=$3 25 | 26 | if [ ! -d "$data" ]; then 27 | echo "$0: no such directory $data" 28 | exit 1 29 | fi 30 | 31 | part_ok=false 32 | list="dev_set test_set train_set" 33 | for x in $list; do 34 | if [ "$part" == $x ]; then part_ok=true; fi 35 | done 36 | if ! $part_ok; then 37 | echo "$0: expected to be one of $list, but got '$part'" 38 | exit 1 39 | fi 40 | 41 | if [ -z "$url" ]; then 42 | echo "$0: empty URL base." 43 | exit 1 44 | fi 45 | 46 | if [ -f $data/.$part.complete ]; then 47 | echo "$0: data part $part was already successfully extracted, nothing to do." 48 | exit 0 49 | fi 50 | 51 | 52 | # sizes of the archive files in bytes. This is some older versions. 53 | sizes_old="1035537823 2201936013 52627842921" 54 | # sizes_new is the archive file sizes of the final release. Some of these sizes are of 55 | # things we probably won't download. 56 | sizes_new="3886385" 57 | 58 | if [ -f $data/$part.tar.gz ]; then 59 | size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}') 60 | size_ok=false 61 | for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done 62 | if ! $size_ok; then 63 | echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size" 64 | echo "does not equal the size of one of the archives." 65 | rm $data/$part.tar.gz 66 | else 67 | echo "$data/$part.tar.gz exists and appears to be complete." 68 | fi 69 | fi 70 | 71 | if [ ! -f $data/$part.tar.gz ]; then 72 | if ! which wget >/dev/null; then 73 | echo "$0: wget is not installed." 
74 | exit 1 75 | fi 76 | full_url=$url/$part.tar.gz 77 | echo "$0: downloading data from $full_url. This may take some time, please be patient." 78 | 79 | if ! wget -P $data --no-check-certificate $full_url; then 80 | echo "$0: error executing wget $full_url" 81 | exit 1 82 | fi 83 | fi 84 | 85 | if ! tar -C $data -xvzf $data/$part.tar.gz; then 86 | echo "$0: error un-tarring archive $data/$part.tar.gz" 87 | exit 1 88 | fi 89 | 90 | touch $data/.$part.complete 91 | 92 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz" 93 | 94 | if $remove_archive; then 95 | echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied." 96 | rm $data/$part.tar.gz 97 | fi 98 | -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/local/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from tqdm import tqdm 5 | 6 | 7 | logger = logging.getLogger() 8 | 9 | 10 | def main(): 11 | utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {} 12 | with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f: 13 | lines = f.readlines()[1:] 14 | lines = [l.split('\t') for l in lines] 15 | for wav, spk, content in tqdm(lines): 16 | wav, spk, content = wav.strip(), spk.strip(), content.strip() 17 | content = content.replace('[FIL]', '') 18 | content = content.replace('[SPK]', '') 19 | wav = os.path.join(args.src_dir, spk, wav) 20 | if not os.path.exists(wav): 21 | continue 22 | utt = os.path.basename(wav).replace('.wav', '') 23 | utt2wav[utt] = wav 24 | utt2text[utt] = content 25 | utt2spk[utt] = spk 26 | if spk not in spk2utt: 27 | spk2utt[spk] = [] 28 | spk2utt[spk].append(utt) 29 | 30 | with open('{}/wav.scp'.format(args.des_dir), 'w') as f: 31 | for k, v in utt2wav.items(): 32 | f.write('{} {}\n'.format(k, v)) 33 | with open('{}/text'.format(args.des_dir), 'w') as f: 34 | for k, v in utt2text.items(): 35 | f.write('{} {}\n'.format(k, v)) 36 | with open('{}/utt2spk'.format(args.des_dir), 'w') as f: 37 | for k, v in utt2spk.items(): 38 | f.write('{} {}\n'.format(k, v)) 39 | with open('{}/spk2utt'.format(args.des_dir), 'w') as f: 40 | for k, v in spk2utt.items(): 41 | f.write('{} {}\n'.format(k, ' '.join(v))) 42 | return 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--src_dir', 48 | type=str) 49 | parser.add_argument('--des_dir', 50 | type=str) 51 | args = parser.parse_args() 52 | main() 53 | -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/path.sh: -------------------------------------------------------------------------------- 1 | ../../libritts/cosyvoice/path.sh -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | . 
./path.sh || exit 1; 4 | 5 | stage=-1 6 | stop_stage=3 7 | 8 | data_url=www.openslr.org/resources/68 9 | data_dir=/mnt/hengwu.zty/data/tts/openslr/magicdata-read 10 | pretrained_model_dir=../../../pretrained_models/CosyVoice-300M 11 | 12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then 13 | echo "Data Download" 14 | for part in dev_set test_set train_set; do 15 | local/download_and_untar.sh ${data_dir} ${data_url} ${part} 16 | done 17 | fi 18 | 19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 20 | echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt" 21 | for x in dev test train; do 22 | mkdir -p data/$x 23 | python local/prepare_data.py --src_dir $data_dir/$x --des_dir data/$x 24 | done 25 | fi 26 | 27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 28 | echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir" 29 | for x in dev test train; do 30 | tools/extract_embedding.py --dir data/$x \ 31 | --onnx_path $pretrained_model_dir/campplus.onnx 32 | done 33 | fi 34 | 35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 36 | echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir" 37 | for x in dev test train; do 38 | tools/extract_speech_token.py --dir data/$x \ 39 | --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx 40 | done 41 | fi 42 | 43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 44 | echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt" 45 | for x in dev test train; do 46 | mkdir -p data/$x/parquet 47 | tools/make_parquet_list.py --num_utts_per_parquet 1000 \ 48 | --num_processes 10 \ 49 | --src_dir data/$x \ 50 | --des_dir data/$x/parquet 51 | done 52 | fi 53 | 54 | # inference 55 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then 56 | echo "Run inference. Please make sure utt in tts_text is in prompt_data" 57 | for mode in sft zero_shot; do 58 | python cosyvoice/bin/inference.py --mode $mode \ 59 | --gpu 0 \ 60 | --config conf/cosyvoice.yaml \ 61 | --prompt_data data/test/parquet/data.list \ 62 | --prompt_utt2data data/test/parquet/utt2data.list \ 63 | --tts_text `pwd`/tts_text.json \ 64 | --llm_model $pretrained_model_dir/llm.pt \ 65 | --flow_model $pretrained_model_dir/flow.pt \ 66 | --hifigan_model $pretrained_model_dir/hift.pt \ 67 | --result_dir `pwd`/exp/cosyvoice/test/$mode 68 | done 69 | fi 70 | 71 | # train llm 72 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 73 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') 74 | job_id=1986 75 | dist_backend="nccl" 76 | num_workers=2 77 | prefetch=100 78 | train_engine=torch_ddp 79 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then 80 | echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml" 81 | if [ $train_engine == 'deepspeed' ]; then 82 | echo "Notice deepspeed has its own optimizer config. 
Modify conf/ds_stage2.json if necessary" 83 | fi 84 | cp data/train/parquet/data.list data/train.data.list 85 | cp data/dev/parquet/data.list data/dev.data.list 86 | for model in llm flow hifigan; do 87 | torchrun --nnodes=1 --nproc_per_node=$num_gpus \ 88 | --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ 89 | cosyvoice/bin/train.py \ 90 | --train_engine $train_engine \ 91 | --config conf/cosyvoice.yaml \ 92 | --train_data data/train.data.list \ 93 | --cv_data data/dev.data.list \ 94 | --model $model \ 95 | --checkpoint $pretrained_model_dir/$model.pt \ 96 | --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \ 97 | --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \ 98 | --ddp.dist_backend $dist_backend \ 99 | --num_workers ${num_workers} \ 100 | --prefetch ${prefetch} \ 101 | --pin_memory \ 102 | --use_amp \ 103 | --deepspeed_config ./conf/ds_stage2.json \ 104 | --deepspeed.save_states model+optimizer 105 | done 106 | fi 107 | 108 | # average model 109 | average_num=5 110 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then 111 | for model in llm flow hifigan; do 112 | decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt 113 | echo "do model average and final checkpoint is $decode_checkpoint" 114 | python cosyvoice/bin/average_model.py \ 115 | --dst_model $decode_checkpoint \ 116 | --src_path `pwd`/exp/cosyvoice/$model/$train_engine \ 117 | --num ${average_num} \ 118 | --val_best 119 | done 120 | fi 121 | 122 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then 123 | echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir" 124 | python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir 125 | python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir 126 | fi -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/tools: -------------------------------------------------------------------------------- 1 | ../../../tools -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/tts_text.json: -------------------------------------------------------------------------------- 1 | { 2 | "38_5718_20170915093303": [ 3 | "我想这出最好歌曲把歌词发到网上请别人帮我作曲急急", 4 | "叫他明天早上差五分儿九点去机场" 5 | ], 6 | "38_5721_20170915091235": [ 7 | "变温室调到零下两度档", 8 | "交谈中请勿轻信汇款信息陌生电话请勿使用外挂软件" 9 | ], 10 | "38_5733_20170915130323": [ 11 | "这是老鹰乐队的一首经典歌曲", 12 | "我急用这段音乐我自己找到一段但是有现场杂音" 13 | ], 14 | "38_5836_20170916221414": [ 15 | "给我播一个陶喆的专辑", 16 | "这套餐好贵呀我发这么多短信贵死了" 17 | ] 18 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu121 2 | --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684 3 | conformer==0.3.2 4 | deepspeed==0.15.1; sys_platform == 'linux' 5 | diffusers==0.29.0 6 | fastapi==0.115.6 7 | fastapi-cli==0.0.4 8 | gdown==5.1.0 9 | gradio==5.4.0 10 | grpcio==1.57.0 11 | grpcio-tools==1.57.0 12 | hydra-core==1.3.2 13 | HyperPyYAML==1.2.2 14 | inflect==7.3.1 15 | librosa==0.10.2 16 | lightning==2.2.4 17 | matplotlib==3.7.5 18 | modelscope==1.20.0 19 | networkx==3.1 20 | omegaconf==2.3.0 21 | onnx==1.16.0 22 | onnxruntime-gpu==1.18.0; sys_platform == 'linux' 23 | onnxruntime==1.18.0; 
sys_platform == 'darwin' or sys_platform == 'win32' 24 | openai-whisper==20231117 25 | protobuf==4.25 26 | pyarrow==18.1.0 27 | pydantic==2.7.0 28 | pyworld==0.3.4 29 | rich==13.7.1 30 | soundfile==0.12.1 31 | tensorboard==2.14.0 32 | tensorrt-cu12==10.0.1; sys_platform == 'linux' 33 | tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux' 34 | tensorrt-cu12-libs==10.0.1; sys_platform == 'linux' 35 | torch==2.3.1 36 | torchaudio==2.3.1 37 | transformers==4.40.1 38 | uvicorn==0.30.0 39 | WeTextProcessing==1.0.3 40 | wget==3.2 41 | -------------------------------------------------------------------------------- /runtime/python/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | 4 | WORKDIR /opt/CosyVoice 5 | 6 | RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list 7 | RUN apt-get update -y 8 | RUN apt-get -y install git unzip git-lfs g++ 9 | RUN git lfs install 10 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git 11 | # here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed 12 | RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com 13 | RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto -------------------------------------------------------------------------------- /runtime/python/fastapi/client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import argparse 15 | import logging 16 | import requests 17 | import torch 18 | import torchaudio 19 | import numpy as np 20 | 21 | 22 | def main(): 23 | url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode) 24 | if args.mode == 'sft': 25 | payload = { 26 | 'tts_text': args.tts_text, 27 | 'spk_id': args.spk_id 28 | } 29 | response = requests.request("GET", url, data=payload, stream=True) 30 | elif args.mode == 'zero_shot': 31 | payload = { 32 | 'tts_text': args.tts_text, 33 | 'prompt_text': args.prompt_text 34 | } 35 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] 36 | response = requests.request("GET", url, data=payload, files=files, stream=True) 37 | elif args.mode == 'cross_lingual': 38 | payload = { 39 | 'tts_text': args.tts_text, 40 | } 41 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] 42 | response = requests.request("GET", url, data=payload, files=files, stream=True) 43 | else: 44 | payload = { 45 | 'tts_text': args.tts_text, 46 | 'spk_id': args.spk_id, 47 | 'instruct_text': args.instruct_text 48 | } 49 | response = requests.request("GET", url, data=payload, stream=True) 50 | tts_audio = b'' 51 | for r in response.iter_content(chunk_size=16000): 52 | tts_audio += r 53 | tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) 54 | logging.info('save response to {}'.format(args.tts_wav)) 55 | torchaudio.save(args.tts_wav, tts_speech, target_sr) 56 | logging.info('get response') 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('--host', 62 | type=str, 63 | default='0.0.0.0') 64 | parser.add_argument('--port', 65 | type=int, 66 | default='50000') 67 | parser.add_argument('--mode', 68 | default='sft', 69 | choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], 70 | help='request mode') 71 | parser.add_argument('--tts_text', 72 | type=str, 73 | default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?') 74 | parser.add_argument('--spk_id', 75 | type=str, 76 | default='中文女') 77 | parser.add_argument('--prompt_text', 78 | type=str, 79 | default='希望你以后能够做的比我还好呦。') 80 | parser.add_argument('--prompt_wav', 81 | type=str, 82 | default='../../../asset/zero_shot_prompt.wav') 83 | parser.add_argument('--instruct_text', 84 | type=str, 85 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \ 86 | Fights with fervor for justice, but struggles with impulsiveness.') 87 | parser.add_argument('--tts_wav', 88 | type=str, 89 | default='demo.wav') 90 | args = parser.parse_args() 91 | prompt_sr, target_sr = 16000, 22050 92 | main() 93 | -------------------------------------------------------------------------------- /runtime/python/fastapi/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import sys 16 | import argparse 17 | import logging 18 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 19 | from fastapi import FastAPI, UploadFile, Form, File 20 | from fastapi.responses import StreamingResponse 21 | from fastapi.middleware.cors import CORSMiddleware 22 | import uvicorn 23 | import numpy as np 24 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 25 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 26 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 27 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 28 | from cosyvoice.utils.file_utils import load_wav 29 | 30 | app = FastAPI() 31 | # set cross region allowance 32 | app.add_middleware( 33 | CORSMiddleware, 34 | allow_origins=["*"], 35 | allow_credentials=True, 36 | allow_methods=["*"], 37 | allow_headers=["*"]) 38 | 39 | 40 | def generate_data(model_output): 41 | for i in model_output: 42 | tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() 43 | yield tts_audio 44 | 45 | 46 | @app.get("/inference_sft") 47 | @app.post("/inference_sft") 48 | async def inference_sft(tts_text: str = Form(), spk_id: str = Form()): 49 | model_output = cosyvoice.inference_sft(tts_text, spk_id) 50 | return StreamingResponse(generate_data(model_output)) 51 | 52 | 53 | @app.get("/inference_zero_shot") 54 | @app.post("/inference_zero_shot") 55 | async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()): 56 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 57 | model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k) 58 | return StreamingResponse(generate_data(model_output)) 59 | 60 | 61 | @app.get("/inference_cross_lingual") 62 | @app.post("/inference_cross_lingual") 63 | async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()): 64 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 65 | model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) 66 | return StreamingResponse(generate_data(model_output)) 67 | 68 | 69 | @app.get("/inference_instruct") 70 | @app.post("/inference_instruct") 71 | async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()): 72 | model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) 73 | return StreamingResponse(generate_data(model_output)) 74 | 75 | 76 | @app.get("/inference_instruct2") 77 | @app.post("/inference_instruct2") 78 | async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()): 79 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 80 | model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k) 81 | return StreamingResponse(generate_data(model_output)) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--port', 87 | type=int, 88 | default=50000) 89 | parser.add_argument('--model_dir', 90 | type=str, 91 | default='iic/CosyVoice-300M', 92 | help='local path or modelscope repo id') 93 | args = parser.parse_args() 94 | try: 95 | cosyvoice = CosyVoice(args.model_dir) 96 | except Exception: 97 | try: 98 | cosyvoice = CosyVoice2(args.model_dir) 99 | except Exception: 100 | raise TypeError('no valid model_type!') 101 | uvicorn.run(app, host="0.0.0.0", port=args.port) 102 | -------------------------------------------------------------------------------- 
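For quick manual testing of the FastAPI service above, a request along the following lines can be used. This is a minimal sketch, not a file from the repository: it assumes the server is started with its defaults (port 50000, CosyVoice-300M), that it is run from runtime/python/fastapi so the bundled prompt wav path resolves, and it reuses the demo texts found elsewhere in this repo. The 22050 Hz rate mirrors the fastapi client above and must match whatever model the server actually loaded.

# Minimal sketch (not part of the repo): stream /inference_zero_shot output into a wav file.
# Assumes the FastAPI server above is running on localhost:50000 with CosyVoice-300M,
# whose output rate matches the 22050 Hz used by runtime/python/fastapi/client.py.
import numpy as np
import requests
import torch
import torchaudio

payload = {
    'tts_text': '收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。',
    'prompt_text': '希望你以后能够做的比我还好呦。',
}
with open('../../../asset/zero_shot_prompt.wav', 'rb') as f:
    files = [('prompt_wav', ('prompt_wav', f, 'application/octet-stream'))]
    # server.py registers both GET and POST for each endpoint; POST is used here
    response = requests.post('http://127.0.0.1:50000/inference_zero_shot',
                             data=payload, files=files, stream=True)

# The server yields raw little-endian int16 PCM chunks; concatenate them and wrap for torchaudio.
tts_audio = b''.join(response.iter_content(chunk_size=16000))
tts_speech = torch.from_numpy(np.frombuffer(tts_audio, dtype=np.int16).copy()).unsqueeze(dim=0)
torchaudio.save('zero_shot_demo.wav', tts_speech, 22050)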
/runtime/python/grpc/client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 18 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 19 | import logging 20 | import argparse 21 | import torchaudio 22 | import cosyvoice_pb2 23 | import cosyvoice_pb2_grpc 24 | import grpc 25 | import torch 26 | import numpy as np 27 | from cosyvoice.utils.file_utils import load_wav 28 | 29 | 30 | def main(): 31 | with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel: 32 | stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel) 33 | request = cosyvoice_pb2.Request() 34 | if args.mode == 'sft': 35 | logging.info('send sft request') 36 | sft_request = cosyvoice_pb2.sftRequest() 37 | sft_request.spk_id = args.spk_id 38 | sft_request.tts_text = args.tts_text 39 | request.sft_request.CopyFrom(sft_request) 40 | elif args.mode == 'zero_shot': 41 | logging.info('send zero_shot request') 42 | zero_shot_request = cosyvoice_pb2.zeroshotRequest() 43 | zero_shot_request.tts_text = args.tts_text 44 | zero_shot_request.prompt_text = args.prompt_text 45 | prompt_speech = load_wav(args.prompt_wav, 16000) 46 | zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes() 47 | request.zero_shot_request.CopyFrom(zero_shot_request) 48 | elif args.mode == 'cross_lingual': 49 | logging.info('send cross_lingual request') 50 | cross_lingual_request = cosyvoice_pb2.crosslingualRequest() 51 | cross_lingual_request.tts_text = args.tts_text 52 | prompt_speech = load_wav(args.prompt_wav, 16000) 53 | cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes() 54 | request.cross_lingual_request.CopyFrom(cross_lingual_request) 55 | else: 56 | logging.info('send instruct request') 57 | instruct_request = cosyvoice_pb2.instructRequest() 58 | instruct_request.tts_text = args.tts_text 59 | instruct_request.spk_id = args.spk_id 60 | instruct_request.instruct_text = args.instruct_text 61 | request.instruct_request.CopyFrom(instruct_request) 62 | 63 | response = stub.Inference(request) 64 | tts_audio = b'' 65 | for r in response: 66 | tts_audio += r.tts_audio 67 | tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) 68 | logging.info('save response to {}'.format(args.tts_wav)) 69 | torchaudio.save(args.tts_wav, tts_speech, target_sr) 70 | logging.info('get response') 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--host', 76 | type=str, 77 | default='0.0.0.0') 78 | parser.add_argument('--port', 79 | type=int, 80 | default='50000') 81 | parser.add_argument('--mode', 82 | default='sft', 83 | choices=['sft', 'zero_shot', 
'cross_lingual', 'instruct'], 84 | help='request mode') 85 | parser.add_argument('--tts_text', 86 | type=str, 87 | default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?') 88 | parser.add_argument('--spk_id', 89 | type=str, 90 | default='中文女') 91 | parser.add_argument('--prompt_text', 92 | type=str, 93 | default='希望你以后能够做的比我还好呦。') 94 | parser.add_argument('--prompt_wav', 95 | type=str, 96 | default='../../../asset/zero_shot_prompt.wav') 97 | parser.add_argument('--instruct_text', 98 | type=str, 99 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \ 100 | Fights with fervor for justice, but struggles with impulsiveness.') 101 | parser.add_argument('--tts_wav', 102 | type=str, 103 | default='demo.wav') 104 | args = parser.parse_args() 105 | prompt_sr, target_sr = 16000, 22050 106 | main() 107 | -------------------------------------------------------------------------------- /runtime/python/grpc/cosyvoice.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package cosyvoice; 4 | option go_package = "protos/"; 5 | 6 | service CosyVoice{ 7 | rpc Inference(Request) returns (stream Response) {} 8 | } 9 | 10 | message Request{ 11 | oneof RequestPayload { 12 | sftRequest sft_request = 1; 13 | zeroshotRequest zero_shot_request = 2; 14 | crosslingualRequest cross_lingual_request = 3; 15 | instructRequest instruct_request = 4; 16 | } 17 | } 18 | 19 | message sftRequest{ 20 | string spk_id = 1; 21 | string tts_text = 2; 22 | } 23 | 24 | message zeroshotRequest{ 25 | string tts_text = 1; 26 | string prompt_text = 2; 27 | bytes prompt_audio = 3; 28 | } 29 | 30 | message crosslingualRequest{ 31 | string tts_text = 1; 32 | bytes prompt_audio = 2; 33 | } 34 | 35 | message instructRequest{ 36 | string tts_text = 1; 37 | string spk_id = 2; 38 | string instruct_text = 3; 39 | } 40 | 41 | message Response{ 42 | bytes tts_audio = 1; 43 | } -------------------------------------------------------------------------------- /runtime/python/grpc/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | import sys 16 | from concurrent import futures 17 | import argparse 18 | import cosyvoice_pb2 19 | import cosyvoice_pb2_grpc 20 | import logging 21 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 22 | import grpc 23 | import torch 24 | import numpy as np 25 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 26 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 27 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 28 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 29 | 30 | logging.basicConfig(level=logging.DEBUG, 31 | format='%(asctime)s %(levelname)s %(message)s') 32 | 33 | 34 | class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer): 35 | def __init__(self, args): 36 | try: 37 | self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc) 38 | except Exception: 39 | try: 40 | self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc) 41 | except Exception: 42 | raise TypeError('no valid model_type!') 43 | logging.info('grpc service initialized') 44 | 45 | def Inference(self, request, context): 46 | if request.HasField('sft_request'): 47 | logging.info('get sft inference request') 48 | model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id) 49 | elif request.HasField('zero_shot_request'): 50 | logging.info('get zero_shot inference request') 51 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) 52 | prompt_speech_16k = prompt_speech_16k.float() / (2**15) 53 | model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, 54 | request.zero_shot_request.prompt_text, 55 | prompt_speech_16k) 56 | elif request.HasField('cross_lingual_request'): 57 | logging.info('get cross_lingual inference request') 58 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) 59 | prompt_speech_16k = prompt_speech_16k.float() / (2**15) 60 | model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k) 61 | else: 62 | logging.info('get instruct inference request') 63 | model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, 64 | request.instruct_request.spk_id, 65 | request.instruct_request.instruct_text) 66 | 67 | logging.info('send inference response') 68 | for i in model_output: 69 | response = cosyvoice_pb2.Response() 70 | response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() 71 | yield response 72 | 73 | 74 | def main(): 75 | grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc) 76 | cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer) 77 | grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port)) 78 | grpcServer.start() 79 | logging.info("server listening on 0.0.0.0:{}".format(args.port)) 80 | grpcServer.wait_for_termination() 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--port', 86 | type=int, 87 | default=50000) 88 | parser.add_argument('--max_conc', 89 | type=int, 90 | default=4) 91 | parser.add_argument('--model_dir', 92 | type=str, 93 | default='iic/CosyVoice-300M', 94 | help='local path or modelscope repo id') 95 | args = parser.parse_args() 96 | main() 97 | 
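Note that the cosyvoice_pb2 and cosyvoice_pb2_grpc modules imported by the gRPC client and server above are not checked in; they are generated from cosyvoice.proto, which the runtime/python/Dockerfile does via python3 -m grpc_tools.protoc. As a hedged convenience sketch (not part of the repo), the same generation can be driven from Python, assuming grpcio-tools (pinned in requirements.txt) is installed:

# Minimal sketch: regenerate the gRPC stubs that client.py and server.py import.
# Equivalent to the protoc invocation in runtime/python/Dockerfile; assumes this is
# run from runtime/python/grpc, where cosyvoice.proto lives.
from grpc_tools import protoc

ret = protoc.main((
    'grpc_tools.protoc',    # argv[0] placeholder, ignored by protoc
    '-I.',                  # search the current directory for imports
    '--python_out=.',       # emit cosyvoice_pb2.py (message classes)
    '--grpc_python_out=.',  # emit cosyvoice_pb2_grpc.py (service stub and servicer)
    'cosyvoice.proto',
))
assert ret == 0, 'protoc failed; check that cosyvoice.proto is in the current directory'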
-------------------------------------------------------------------------------- /tools/extract_embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | from concurrent.futures import ThreadPoolExecutor, as_completed 17 | import onnxruntime 18 | import torch 19 | import torchaudio 20 | import torchaudio.compliance.kaldi as kaldi 21 | from tqdm import tqdm 22 | 23 | 24 | def single_job(utt): 25 | audio, sample_rate = torchaudio.load(utt2wav[utt]) 26 | if sample_rate != 16000: 27 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 28 | feat = kaldi.fbank(audio, 29 | num_mel_bins=80, 30 | dither=0, 31 | sample_frequency=16000) 32 | feat = feat - feat.mean(dim=0, keepdim=True) 33 | embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() 34 | return utt, embedding 35 | 36 | 37 | def main(args): 38 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] 39 | utt2embedding, spk2embedding = {}, {} 40 | for future in tqdm(as_completed(all_task)): 41 | utt, embedding = future.result() 42 | utt2embedding[utt] = embedding 43 | spk = utt2spk[utt] 44 | if spk not in spk2embedding: 45 | spk2embedding[spk] = [] 46 | spk2embedding[spk].append(embedding) 47 | for k, v in spk2embedding.items(): 48 | spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() 49 | torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir)) 50 | torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir)) 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--dir", type=str) 56 | parser.add_argument("--onnx_path", type=str) 57 | parser.add_argument("--num_thread", type=int, default=8) 58 | args = parser.parse_args() 59 | 60 | utt2wav, utt2spk = {}, {} 61 | with open('{}/wav.scp'.format(args.dir)) as f: 62 | for l in f: 63 | l = l.replace('\n', '').split() 64 | utt2wav[l[0]] = l[1] 65 | with open('{}/utt2spk'.format(args.dir)) as f: 66 | for l in f: 67 | l = l.replace('\n', '').split() 68 | utt2spk[l[0]] = l[1] 69 | 70 | option = onnxruntime.SessionOptions() 71 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 72 | option.intra_op_num_threads = 1 73 | providers = ["CPUExecutionProvider"] 74 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 75 | executor = ThreadPoolExecutor(max_workers=args.num_thread) 76 | 77 | main(args) 78 | -------------------------------------------------------------------------------- /tools/extract_speech_token.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed 
under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | from concurrent.futures import ThreadPoolExecutor, as_completed 17 | import logging 18 | import torch 19 | from tqdm import tqdm 20 | import onnxruntime 21 | import numpy as np 22 | import torchaudio 23 | import whisper 24 | 25 | 26 | def single_job(utt): 27 | audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile') 28 | if sample_rate != 16000: 29 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 30 | # Convert audio to mono 31 | if audio.shape[0] > 1: 32 | audio = audio.mean(dim=0, keepdim=True) 33 | if audio.shape[1] / 16000 > 30: 34 | logging.warning('do not support extract speech token for audio longer than 30s') 35 | speech_token = [] 36 | else: 37 | feat = whisper.log_mel_spectrogram(audio, n_mels=128) 38 | speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), 39 | ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() 40 | return utt, speech_token 41 | 42 | 43 | def main(args): 44 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] 45 | utt2speech_token = {} 46 | for future in tqdm(as_completed(all_task)): 47 | utt, speech_token = future.result() 48 | utt2speech_token[utt] = speech_token 49 | torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--dir", type=str) 55 | parser.add_argument("--onnx_path", type=str) 56 | parser.add_argument("--num_thread", type=int, default=8) 57 | args = parser.parse_args() 58 | 59 | utt2wav = {} 60 | with open('{}/wav.scp'.format(args.dir)) as f: 61 | for l in f: 62 | l = l.replace('\n', '').split() 63 | utt2wav[l[0]] = l[1] 64 | 65 | option = onnxruntime.SessionOptions() 66 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 67 | option.intra_op_num_threads = 1 68 | providers = ["CUDAExecutionProvider"] 69 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 70 | executor = ThreadPoolExecutor(max_workers=args.num_thread) 71 | 72 | main(args) 73 | -------------------------------------------------------------------------------- /tools/make_parquet_list.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | import logging 17 | import os 18 | import json 19 | from tqdm import tqdm 20 | import pandas as pd 21 | import multiprocessing 22 | import time 23 | import torch 24 | 25 | 26 | def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file): 27 | start_time = time.time() 28 | data_list = [] 29 | for utt in tqdm(utt_list): 30 | data = open(utt2wav[utt], 'rb').read() 31 | data_list.append(data) 32 | wav_list = [utt2wav[utt] for utt in utt_list] 33 | text_list = [utt2text[utt] for utt in utt_list] 34 | spk_list = [utt2spk[utt] for utt in utt_list] 35 | uttembedding_list = [utt2embedding[utt] for utt in utt_list] 36 | spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list] 37 | speech_token_list = [utt2speech_token[utt] for utt in utt_list] 38 | 39 | # 保存到parquet,utt2parquet_file,spk2parquet_file 40 | df = pd.DataFrame() 41 | df['utt'] = utt_list 42 | df['wav'] = wav_list 43 | df['audio_data'] = data_list 44 | df['text'] = text_list 45 | df['spk'] = spk_list 46 | df['utt_embedding'] = uttembedding_list 47 | df['spk_embedding'] = spkembedding_list 48 | df['speech_token'] = speech_token_list 49 | df.to_parquet(parquet_file) 50 | with open(utt2parquet_file, 'w') as f: 51 | json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2) 52 | with open(spk2parquet_file, 'w') as f: 53 | json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2) 54 | logging.info('spend time {}'.format(time.time() - start_time)) 55 | 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument('--num_utts_per_parquet', 60 | type=int, 61 | default=1000, 62 | help='num utts per parquet') 63 | parser.add_argument('--num_processes', 64 | type=int, 65 | default=1, 66 | help='num processes for make parquets') 67 | parser.add_argument('--src_dir', 68 | type=str) 69 | parser.add_argument('--des_dir', 70 | type=str) 71 | args = parser.parse_args() 72 | 73 | utt2wav, utt2text, utt2spk = {}, {}, {} 74 | with open('{}/wav.scp'.format(args.src_dir)) as f: 75 | for l in f: 76 | l = l.replace('\n', '').split() 77 | utt2wav[l[0]] = l[1] 78 | with open('{}/text'.format(args.src_dir)) as f: 79 | for l in f: 80 | l = l.replace('\n', '').split() 81 | utt2text[l[0]] = ' '.join(l[1:]) 82 | with open('{}/utt2spk'.format(args.src_dir)) as f: 83 | for l in f: 84 | l = l.replace('\n', '').split() 85 | utt2spk[l[0]] = l[1] 86 | utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) 87 | spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) 88 | utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) 89 | utts = list(utt2wav.keys()) 90 | 91 | # Using process pool to speedup 92 | pool = multiprocessing.Pool(processes=args.num_processes) 93 | parquet_list, utt2parquet_list, spk2parquet_list = [], [], [] 94 | for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)): 95 | parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i)) 96 | utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i)) 97 | spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i)) 98 | parquet_list.append(parquet_file) 99 | utt2parquet_list.append(utt2parquet_file) 100 | spk2parquet_list.append(spk2parquet_file) 101 | pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, 
utt2parquet_file, spk2parquet_file)) 102 | pool.close() 103 | pool.join() 104 | 105 | with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \ 106 | open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \ 107 | open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3: 108 | for name in parquet_list: 109 | f1.write(name + '\n') 110 | for name in utt2parquet_list: 111 | f2.write(name + '\n') 112 | for name in spk2parquet_list: 113 | f3.write(name + '\n') 114 | -------------------------------------------------------------------------------- /tools/make_parquet_list_dpo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | import logging 17 | import os 18 | import json 19 | from tqdm import tqdm 20 | import pandas as pd 21 | import multiprocessing 22 | import time 23 | import torch 24 | 25 | 26 | def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file): 27 | start_time = time.time() 28 | data_list = [] 29 | for utt in tqdm(utt_list): 30 | data = open(utt2wav[utt], 'rb').read() 31 | data_list.append(data) 32 | wav_list = [utt2wav[utt] for utt in utt_list] 33 | text_list = [utt2text[utt] for utt in utt_list] 34 | spk_list = [utt2spk[utt] for utt in utt_list] 35 | uttembedding_list = [utt2embedding[utt] for utt in utt_list] 36 | spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list] 37 | speech_token_list = [utt2speech_token[utt] for utt in utt_list] 38 | if utt2reject_speech_token: 39 | reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list] 40 | 41 | # 保存到parquet,utt2parquet_file,spk2parquet_file 42 | df = pd.DataFrame() 43 | df['utt'] = utt_list 44 | df['wav'] = wav_list 45 | df['audio_data'] = data_list 46 | df['text'] = text_list 47 | df['spk'] = spk_list 48 | df['utt_embedding'] = uttembedding_list 49 | df['spk_embedding'] = spkembedding_list 50 | df['speech_token'] = speech_token_list 51 | if utt2reject_speech_token: 52 | df['reject_speech_token'] = reject_speech_token_list 53 | df.to_parquet(parquet_file) 54 | with open(utt2parquet_file, 'w') as f: 55 | json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2) 56 | with open(spk2parquet_file, 'w') as f: 57 | json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2) 58 | logging.info('spend time {}'.format(time.time() - start_time)) 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--num_utts_per_parquet', 64 | type=int, 65 | default=1000, 66 | help='num utts per parquet') 67 | parser.add_argument('--num_processes', 68 | type=int, 69 | default=1, 70 | help='num processes for make parquets') 71 | parser.add_argument('--src_dir', 72 | type=str) 73 | parser.add_argument('--des_dir', 74 | type=str) 75 | 
parser.add_argument('--dpo', 76 | action='store_true', 77 | default=False, 78 | help='Use Direct Preference Optimization') 79 | args = parser.parse_args() 80 | 81 | utt2wav, utt2text, utt2spk = {}, {}, {} 82 | with open('{}/wav.scp'.format(args.src_dir)) as f: 83 | for l in f: 84 | l = l.replace('\n', '').split() 85 | utt2wav[l[0]] = l[1] 86 | with open('{}/text'.format(args.src_dir)) as f: 87 | for l in f: 88 | l = l.replace('\n', '').split() 89 | utt2text[l[0]] = ' '.join(l[1:]) 90 | with open('{}/utt2spk'.format(args.src_dir)) as f: 91 | for l in f: 92 | l = l.replace('\n', '').split() 93 | utt2spk[l[0]] = l[1] 94 | utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) 95 | spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) 96 | utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) 97 | if args.dpo: 98 | utt2reject_speech_token = torch.load('{}/utt2reject_speech_token.pt'.format(args.src_dir)) 99 | else: 100 | utt2reject_speech_token = None 101 | utts = list(utt2wav.keys()) 102 | 103 | # Using process pool to speedup 104 | pool = multiprocessing.Pool(processes=args.num_processes) 105 | parquet_list, utt2parquet_list, spk2parquet_list = [], [], [] 106 | for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)): 107 | parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i)) 108 | utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i)) 109 | spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i)) 110 | parquet_list.append(parquet_file) 111 | utt2parquet_list.append(utt2parquet_file) 112 | spk2parquet_list.append(spk2parquet_file) 113 | pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file)) 114 | pool.close() 115 | pool.join() 116 | 117 | with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \ 118 | open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \ 119 | open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3: 120 | for name in parquet_list: 121 | f1.write(name + '\n') 122 | for name in utt2parquet_list: 123 | f2.write(name + '\n') 124 | for name in spk2parquet_list: 125 | f3.write(name + '\n') 126 | -------------------------------------------------------------------------------- /vllm_example.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('third_party/Matcha-TTS') 3 | from vllm import ModelRegistry 4 | from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM 5 | ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM) 6 | 7 | from cosyvoice.cli.cosyvoice import CosyVoice2 8 | from cosyvoice.utils.file_utils import load_wav 9 | from cosyvoice.utils.common import set_all_random_seed 10 | from tqdm import tqdm 11 | 12 | def main(): 13 | cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True) 14 | prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) 15 | for i in tqdm(range(100)): 16 | set_all_random_seed(i) 17 | for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): 18 | continue 19 | 20 | if __name__=='__main__': 21 | main() 22 | --------------------------------------------------------------------------------
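As a final hedged sketch (not a repository file), the parquet shards listed in data.list by tools/make_parquet_list.py and tools/make_parquet_list_dpo.py above can be spot-checked with pandas before launching training. The column names below are the ones written by job() in those scripts; the data/train/parquet/data.list path is an assumption taken from the example run.sh recipes.

# Minimal sketch: sanity-check one shard produced by the parquet tools above.
import pandas as pd

with open('data/train/parquet/data.list') as f:
    first_shard = f.readline().strip()  # each line in data.list is the path of one parquet file

df = pd.read_parquet(first_shard)
# Columns written by job(): utt, wav, audio_data, text, spk, utt_embedding,
# spk_embedding, speech_token (plus reject_speech_token when --dpo was used).
print(df.columns.tolist())
print(df[['utt', 'spk', 'text']].head())
print('utterances in shard:', len(df))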