├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       ├── lint.yml
│       └── stale-issues.yml
├── .gitignore
├── .gitmodules
├── CODE_OF_CONDUCT.md
├── FAQ.md
├── LICENSE
├── README.md
├── asset
│   ├── cross_lingual_prompt.wav
│   ├── dingding.png
│   └── zero_shot_prompt.wav
├── cosyvoice
│   ├── __init__.py
│   ├── bin
│   │   ├── average_model.py
│   │   ├── export_jit.py
│   │   ├── export_onnx.py
│   │   ├── inference.py
│   │   ├── train.py
│   │   └── train_dpo.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── cosyvoice.py
│   │   ├── frontend.py
│   │   └── model.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dataset.py
│   │   ├── processor.py
│   │   └── processor_dpo.py
│   ├── flow
│   │   ├── decoder.py
│   │   ├── flow.py
│   │   ├── flow_matching.py
│   │   └── length_regulator.py
│   ├── hifigan
│   │   ├── discriminator.py
│   │   ├── f0_predictor.py
│   │   ├── generator.py
│   │   └── hifigan.py
│   ├── llm
│   │   ├── llm.py
│   │   └── llm_dpo.py
│   ├── tokenizer
│   │   ├── assets
│   │   │   └── multilingual_zh_ja_yue_char_del.tiktoken
│   │   └── tokenizer.py
│   ├── transformer
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── attention.py
│   │   ├── convolution.py
│   │   ├── decoder.py
│   │   ├── decoder_layer.py
│   │   ├── embedding.py
│   │   ├── encoder.py
│   │   ├── encoder_layer.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── subsampling.py
│   │   └── upsample_encoder.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── class_utils.py
│   │   ├── common.py
│   │   ├── executor.py
│   │   ├── executor_dpo.py
│   │   ├── file_utils.py
│   │   ├── frontend_utils.py
│   │   ├── losses.py
│   │   ├── losses_dpo.py
│   │   ├── mask.py
│   │   ├── scheduler.py
│   │   ├── train_utils.py
│   │   └── train_utils_dpo.py
│   └── vllm
│       └── cosyvoice2.py
├── docker
│   └── Dockerfile
├── examples
│   ├── libritts
│   │   ├── cosyvoice
│   │   │   ├── conf
│   │   │   │   ├── cosyvoice.fromscratch.yaml
│   │   │   │   ├── cosyvoice.yaml
│   │   │   │   ├── cosyvoice_dpo.yaml
│   │   │   │   └── ds_stage2.json
│   │   │   ├── cosyvoice
│   │   │   ├── local
│   │   │   │   ├── download_and_untar.sh
│   │   │   │   └── prepare_data.py
│   │   │   ├── path.sh
│   │   │   ├── run.sh
│   │   │   ├── tools
│   │   │   └── tts_text.json
│   │   └── cosyvoice2
│   │       ├── conf
│   │       │   ├── cosyvoice2.yaml
│   │       │   └── ds_stage2.json
│   │       ├── cosyvoice
│   │       ├── local
│   │       ├── path.sh
│   │       ├── run.sh
│   │       ├── tools
│   │       └── tts_text.json
│   └── magicdata-read
│       └── cosyvoice
│           ├── conf
│           ├── cosyvoice
│           ├── local
│           │   ├── download_and_untar.sh
│           │   └── prepare_data.py
│           ├── path.sh
│           ├── run.sh
│           ├── tools
│           └── tts_text.json
├── requirements.txt
├── runtime
│   └── python
│       ├── Dockerfile
│       ├── fastapi
│       │   ├── client.py
│       │   └── server.py
│       └── grpc
│           ├── client.py
│           ├── cosyvoice.proto
│           └── server.py
├── tools
│   ├── extract_embedding.py
│   ├── extract_speech_token.py
│   ├── make_parquet_list.py
│   └── make_parquet_list_dpo.py
├── vllm_example.py
└── webui.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: Lint
2 |
3 | on:
4 | pull_request:
5 | push:
6 |
7 | jobs:
8 | quick-checks:
9 | runs-on: ubuntu-latest
10 | steps:
11 | - name: Fetch CosyVoice
12 | uses: actions/checkout@v1
13 | - name: Checkout PR tip
14 | run: |
15 | set -eux
16 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then
17 | # We are on a PR, so actions/checkout leaves us on a merge commit.
18 | # Check out the actual tip of the branch.
19 | git checkout ${{ github.event.pull_request.head.sha }}
20 | fi
21 | echo ::set-output name=commit_sha::$(git rev-parse HEAD)
22 | id: get_pr_tip
23 | - name: Ensure no tabs
24 | run: |
25 | (! git grep -I -l $'\t' -- . ':(exclude)*.txt' ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false))
26 | - name: Ensure no trailing whitespace
27 | run: |
28 | (! git grep -I -n $' $' -- . ':(exclude)*.txt' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false))
29 |
30 | flake8-py3:
31 | runs-on: ubuntu-latest
32 | steps:
33 | - name: Setup Python
34 | uses: actions/setup-python@v1
35 | with:
36 | python-version: 3.9
37 | architecture: x64
38 | - name: Fetch CosyVoice
39 | uses: actions/checkout@v1
40 | - name: Checkout PR tip
41 | run: |
42 | set -eux
43 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then
44 | # We are on a PR, so actions/checkout leaves us on a merge commit.
45 | # Check out the actual tip of the branch.
46 | git checkout ${{ github.event.pull_request.head.sha }}
47 | fi
48 | echo ::set-output name=commit_sha::$(git rev-parse HEAD)
49 | id: get_pr_tip
50 | - name: Run flake8
51 | run: |
52 | set -eux
53 | pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
54 | flake8 --version
55 | flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
56 | if [ $? != 0 ]; then exit 1; fi
--------------------------------------------------------------------------------
/.github/workflows/stale-issues.yml:
--------------------------------------------------------------------------------
1 | name: Close inactive issues
2 | on:
3 | schedule:
4 | - cron: "30 1 * * *"
5 |
6 | jobs:
7 | close-issues:
8 | runs-on: ubuntu-latest
9 | permissions:
10 | issues: write
11 | pull-requests: write
12 | steps:
13 | - uses: actions/stale@v5
14 | with:
15 | days-before-issue-stale: 30
16 | days-before-issue-close: 14
17 | stale-issue-label: "stale"
18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity."
19 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale."
20 | days-before-pr-stale: -1
21 | days-before-pr-close: -1
22 | repo-token: ${{ secrets.GITHUB_TOKEN }}
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # Visual Studio Code files
7 | .vscode
8 | .vs
9 |
10 | # PyCharm files
11 | .idea
12 |
13 | # Eclipse Project settings
14 | *.*project
15 | .settings
16 |
17 | # Sublime Text settings
18 | *.sublime-workspace
19 | *.sublime-project
20 |
21 | # Editor temporaries
22 | *.swn
23 | *.swo
24 | *.swp
25 | *.swm
26 | *~
27 |
28 | # IPython notebook checkpoints
29 | .ipynb_checkpoints
30 |
31 | # macOS dir files
32 | .DS_Store
33 |
34 | exp
35 | data
36 | raw_wav
37 | tensorboard
38 | **/*build*
39 |
40 | # Clangd files
41 | .cache
42 | compile_commands.json
43 |
44 | # train/inference files
45 | *.wav
46 | *.m4a
47 | *.aac
48 | *.pt
49 | pretrained_models/*
50 | *_pb2_grpc.py
51 | *_pb2.py
52 | *.tar
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/Matcha-TTS"]
2 | path = third_party/Matcha-TTS
3 | url = https://github.com/shivammehta25/Matcha-TTS.git
4 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at mikelei@mobvoi.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | ## ModuleNotFoundError: No module named 'matcha'
2 |
3 | Matcha-TTS is a third-party submodule. Check the `third_party` directory; if `Matcha-TTS` is missing, execute `git submodule update --init --recursive`.
4 |
5 | Run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in a Python script.
6 |
7 | ## cannot find resource.zip or cannot unzip resource.zip
8 |
9 | Please make sure you have git-lfs installed. Execute
10 |
11 | ```sh
12 | git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
13 | cd pretrained_models/CosyVoice-ttsfrd/
14 | unzip resource.zip -d .
15 | pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
16 | ```
17 |
--------------------------------------------------------------------------------
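
Note: a minimal Python-side sketch of the same PYTHONPATH fix, assuming the submodule has already been fetched and the script runs from the repository root; appending to `sys.path` is equivalent to exporting `PYTHONPATH`. The model directory below is the default used by the export scripts and is assumed to be downloaded already.

```python
# Sketch: make the Matcha-TTS submodule importable without exporting PYTHONPATH.
# Assumes the script runs from the repository root and the submodule was fetched
# with `git submodule update --init --recursive`.
import sys

sys.path.append('third_party/Matcha-TTS')

from cosyvoice.cli.cosyvoice import CosyVoice  # noqa: E402

# Assumes the pretrained model has been downloaded to this default location.
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
```
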
/asset/cross_lingual_prompt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/asset/cross_lingual_prompt.wav
--------------------------------------------------------------------------------
/asset/dingding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/asset/dingding.png
--------------------------------------------------------------------------------
/asset/zero_shot_prompt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/asset/zero_shot_prompt.wav
--------------------------------------------------------------------------------
/cosyvoice/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/bin/average_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Mobvoi Inc (Di Wu)
2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import argparse
18 | import glob
19 |
20 | import yaml
21 | import torch
22 |
23 |
24 | def get_args():
25 | parser = argparse.ArgumentParser(description='average model')
26 | parser.add_argument('--dst_model', required=True, help='averaged model')
27 | parser.add_argument('--src_path',
28 | required=True,
29 | help='src model path for average')
30 | parser.add_argument('--val_best',
31 | action="store_true",
32 | help='averaged model')
33 | parser.add_argument('--num',
34 | default=5,
35 | type=int,
36 | help='nums for averaged model')
37 |
38 | args = parser.parse_args()
39 | print(args)
40 | return args
41 |
42 |
43 | def main():
44 | args = get_args()
45 | val_scores = []
46 | if args.val_best:
47 | yamls = glob.glob('{}/*.yaml'.format(args.src_path))
48 | yamls = [
49 | f for f in yamls
50 | if not (os.path.basename(f).startswith('train')
51 | or os.path.basename(f).startswith('init'))
52 | ]
53 | for y in yamls:
54 | with open(y, 'r') as f:
55 | dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
56 | loss = float(dic_yaml['loss_dict']['loss'])
57 | epoch = int(dic_yaml['epoch'])
58 | step = int(dic_yaml['step'])
59 | tag = dic_yaml['tag']
60 | val_scores += [[epoch, step, loss, tag]]
61 | sorted_val_scores = sorted(val_scores,
62 | key=lambda x: x[2],
63 | reverse=False)
64 | print("best val (epoch, step, loss, tag) = " +
65 | str(sorted_val_scores[:args.num]))
66 | path_list = [
67 | args.src_path + '/epoch_{}_whole.pt'.format(score[0])
68 | for score in sorted_val_scores[:args.num]
69 | ]
70 | print(path_list)
71 | avg = {}
72 | num = args.num
73 | assert num == len(path_list)
74 | for path in path_list:
75 | print('Processing {}'.format(path))
76 | states = torch.load(path, map_location=torch.device('cpu'))
77 | for k in states.keys():
78 | if k not in ['step', 'epoch']:
79 | if k not in avg.keys():
80 | avg[k] = states[k].clone()
81 | else:
82 | avg[k] += states[k]
83 | # average
84 | for k in avg.keys():
85 | if avg[k] is not None:
86 | # pytorch 1.6 use true_divide instead of /=
87 | avg[k] = torch.true_divide(avg[k], num)
88 | print('Saving to {}'.format(args.dst_model))
89 | torch.save(avg, args.dst_model)
90 |
91 |
92 | if __name__ == '__main__':
93 | main()
94 |
--------------------------------------------------------------------------------
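
Note: the core of `average_model.py` is the sum-then-divide loop over checkpoint state dicts; below is a self-contained toy illustration of that arithmetic on made-up state dicts rather than real `epoch_{n}_whole.pt` files.

```python
# Toy illustration of the checkpoint-averaging arithmetic used above,
# applied to two hand-made state dicts instead of real checkpoints.
import torch

states_list = [
    {'weight': torch.tensor([1.0, 2.0]), 'step': 100, 'epoch': 1},
    {'weight': torch.tensor([3.0, 4.0]), 'step': 200, 'epoch': 2},
]

avg = {}
for states in states_list:
    for k, v in states.items():
        if k in ('step', 'epoch'):          # bookkeeping keys are skipped
            continue
        avg[k] = v.clone() if k not in avg else avg[k] + v

for k in avg:
    avg[k] = torch.true_divide(avg[k], len(states_list))

print(avg['weight'])                        # tensor([2., 3.])
```
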
/cosyvoice/bin/export_jit.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import print_function
16 |
17 | import argparse
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | import os
21 | import sys
22 | import torch
23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24 | sys.path.append('{}/../..'.format(ROOT_DIR))
25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
26 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
27 | from cosyvoice.utils.file_utils import logging
28 |
29 |
30 | def get_args():
31 | parser = argparse.ArgumentParser(description='export your model for deployment')
32 | parser.add_argument('--model_dir',
33 | type=str,
34 | default='pretrained_models/CosyVoice-300M',
35 | help='local path')
36 | args = parser.parse_args()
37 | print(args)
38 | return args
39 |
40 |
41 | def get_optimized_script(model, preserved_attrs=[]):
42 | script = torch.jit.script(model)
43 | if preserved_attrs != []:
44 | script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
45 | else:
46 | script = torch.jit.freeze(script)
47 | script = torch.jit.optimize_for_inference(script)
48 | return script
49 |
50 |
51 | def main():
52 | args = get_args()
53 | logging.basicConfig(level=logging.DEBUG,
54 | format='%(asctime)s %(levelname)s %(message)s')
55 |
56 | torch._C._jit_set_fusion_strategy([('STATIC', 1)])
57 | torch._C._jit_set_profiling_mode(False)
58 | torch._C._jit_set_profiling_executor(False)
59 |
60 | try:
61 | model = CosyVoice(args.model_dir)
62 | except Exception:
63 | try:
64 | model = CosyVoice2(args.model_dir)
65 | except Exception:
66 | raise TypeError('no valid model_type!')
67 |
68 | if not isinstance(model, CosyVoice2):
69 | # 1. export llm text_encoder
70 | llm_text_encoder = model.model.llm.text_encoder
71 | script = get_optimized_script(llm_text_encoder)
72 | script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
73 | script = get_optimized_script(llm_text_encoder.half())
74 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
75 | logging.info('successfully export llm_text_encoder')
76 |
77 | # 2. export llm llm
78 | llm_llm = model.model.llm.llm
79 | script = get_optimized_script(llm_llm, ['forward_chunk'])
80 | script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
81 | script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
82 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
83 | logging.info('successfully export llm_llm')
84 |
85 | # 3. export flow encoder
86 | flow_encoder = model.model.flow.encoder
87 | script = get_optimized_script(flow_encoder)
88 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
89 | script = get_optimized_script(flow_encoder.half())
90 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
91 | logging.info('successfully export flow_encoder')
92 | else:
93 | # 3. export flow encoder
94 | flow_encoder = model.model.flow.encoder
95 | script = get_optimized_script(flow_encoder)
96 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
97 | script = get_optimized_script(flow_encoder.half())
98 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
99 | logging.info('successfully export flow_encoder')
100 |
101 |
102 | if __name__ == '__main__':
103 | main()
104 |
--------------------------------------------------------------------------------
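
Note: `get_optimized_script()` is the standard TorchScript pipeline (script, then freeze, then optimize_for_inference). A minimal sketch applying the same steps to a stand-in module, since the real llm/flow encoders are defined elsewhere in the repository:

```python
# Minimal sketch of the script -> freeze -> optimize_for_inference pipeline
# used by get_optimized_script(), applied to a stand-in module.
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(80, 80)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))


model = TinyEncoder().eval()                # freeze requires eval mode
script = torch.jit.script(model)
script = torch.jit.freeze(script)           # no preserved_attrs needed here
script = torch.jit.optimize_for_inference(script)
script.save('tiny_encoder.fp32.zip')
```
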
/cosyvoice/bin/export_onnx.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from __future__ import print_function
17 |
18 | import argparse
19 | import logging
20 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
21 | import os
22 | import sys
23 | import onnxruntime
24 | import random
25 | import torch
26 | from tqdm import tqdm
27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
28 | sys.path.append('{}/../..'.format(ROOT_DIR))
29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
30 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
31 | from cosyvoice.utils.file_utils import logging
32 |
33 |
34 | def get_dummy_input(batch_size, seq_len, out_channels, device):
35 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
36 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
37 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
38 | t = torch.rand((batch_size), dtype=torch.float32, device=device)
39 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
40 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
41 | return x, mask, mu, t, spks, cond
42 |
43 |
44 | def get_args():
45 | parser = argparse.ArgumentParser(description='export your model for deployment')
46 | parser.add_argument('--model_dir',
47 | type=str,
48 | default='pretrained_models/CosyVoice-300M',
49 | help='local path')
50 | args = parser.parse_args()
51 | print(args)
52 | return args
53 |
54 |
55 | @torch.no_grad()
56 | def main():
57 | args = get_args()
58 | logging.basicConfig(level=logging.DEBUG,
59 | format='%(asctime)s %(levelname)s %(message)s')
60 |
61 | try:
62 | model = CosyVoice(args.model_dir)
63 | except Exception:
64 | try:
65 | model = CosyVoice2(args.model_dir)
66 | except Exception:
67 | raise TypeError('no valid model_type!')
68 |
69 | # 1. export flow decoder estimator
70 | estimator = model.model.flow.decoder.estimator
71 | estimator.eval()
72 |
73 | device = model.model.device
74 | batch_size, seq_len = 2, 256
75 | out_channels = model.model.flow.decoder.estimator.out_channels
76 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
77 | torch.onnx.export(
78 | estimator,
79 | (x, mask, mu, t, spks, cond),
80 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
81 | export_params=True,
82 | opset_version=18,
83 | do_constant_folding=True,
84 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
85 | output_names=['estimator_out'],
86 | dynamic_axes={
87 | 'x': {2: 'seq_len'},
88 | 'mask': {2: 'seq_len'},
89 | 'mu': {2: 'seq_len'},
90 | 'cond': {2: 'seq_len'},
91 | 'estimator_out': {2: 'seq_len'},
92 | }
93 | )
94 |
95 | # 2. test computation consistency
96 | option = onnxruntime.SessionOptions()
97 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
98 | option.intra_op_num_threads = 1
99 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
100 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
101 | sess_options=option, providers=providers)
102 |
103 | for _ in tqdm(range(10)):
104 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
105 | output_pytorch = estimator(x, mask, mu, t, spks, cond)
106 | ort_inputs = {
107 | 'x': x.cpu().numpy(),
108 | 'mask': mask.cpu().numpy(),
109 | 'mu': mu.cpu().numpy(),
110 | 't': t.cpu().numpy(),
111 | 'spks': spks.cpu().numpy(),
112 | 'cond': cond.cpu().numpy()
113 | }
114 | output_onnx = estimator_onnx.run(None, ort_inputs)[0]
115 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
116 | logging.info('successfully export estimator')
117 |
118 |
119 | if __name__ == "__main__":
120 | main()
121 |
--------------------------------------------------------------------------------
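
Note: a hedged sketch of loading the exported `flow.decoder.estimator.fp32.onnx` on its own, mirroring the consistency test above. The model path and `out_channels=80` are assumptions; the input names come from the export call.

```python
# Hedged sketch: run the exported estimator ONNX with dummy inputs.
import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession(
    'pretrained_models/CosyVoice-300M/flow.decoder.estimator.fp32.onnx',  # assumed path
    providers=['CPUExecutionProvider'])

b, c, t = 2, 80, 256                        # out_channels=80 is an assumption
ort_inputs = {
    'x': np.random.rand(b, c, t).astype(np.float32),
    'mask': np.ones((b, 1, t), dtype=np.float32),
    'mu': np.random.rand(b, c, t).astype(np.float32),
    't': np.random.rand(b).astype(np.float32),
    'spks': np.random.rand(b, c).astype(np.float32),
    'cond': np.random.rand(b, c, t).astype(np.float32),
}
estimator_out = sess.run(None, ort_inputs)[0]
print(estimator_out.shape)                  # (2, 80, 256)
```
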
/cosyvoice/bin/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import print_function
16 |
17 | import argparse
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | import os
21 | import torch
22 | from torch.utils.data import DataLoader
23 | import torchaudio
24 | from hyperpyyaml import load_hyperpyyaml
25 | from tqdm import tqdm
26 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
27 | from cosyvoice.dataset.dataset import Dataset
28 |
29 |
30 | def get_args():
31 | parser = argparse.ArgumentParser(description='inference with your model')
32 | parser.add_argument('--config', required=True, help='config file')
33 | parser.add_argument('--prompt_data', required=True, help='prompt data file')
34 | parser.add_argument('--prompt_utt2data', required=True, help='prompt data file')
35 | parser.add_argument('--tts_text', required=True, help='tts input file')
36 | parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
37 | parser.add_argument('--llm_model', required=True, help='llm model file')
38 | parser.add_argument('--flow_model', required=True, help='flow model file')
39 | parser.add_argument('--hifigan_model', required=True, help='hifigan model file')
40 | parser.add_argument('--gpu',
41 | type=int,
42 | default=-1,
43 | help='gpu id for this rank, -1 for cpu')
44 | parser.add_argument('--mode',
45 | default='sft',
46 | choices=['sft', 'zero_shot'],
47 | help='inference mode')
48 | parser.add_argument('--result_dir', required=True, help='asr result file')
49 | args = parser.parse_args()
50 | print(args)
51 | return args
52 |
53 |
54 | def main():
55 | args = get_args()
56 | logging.basicConfig(level=logging.DEBUG,
57 | format='%(asctime)s %(levelname)s %(message)s')
58 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
59 |
60 | # Init cosyvoice models from configs
61 | use_cuda = args.gpu >= 0 and torch.cuda.is_available()
62 | device = torch.device('cuda' if use_cuda else 'cpu')
63 | try:
64 | with open(args.config, 'r') as f:
65 | configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': args.qwen_pretrain_path})
66 | model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'])
67 | except Exception:
68 | try:
69 | with open(args.config, 'r') as f:
70 | configs = load_hyperpyyaml(f)
71 | model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
72 | except Exception:
73 | raise TypeError('no valid model_type!')
74 |
75 | model.load(args.llm_model, args.flow_model, args.hifigan_model)
76 |
77 | test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
78 | tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
79 | test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
80 |
81 | sample_rate = configs['sample_rate']
82 | del configs
83 | os.makedirs(args.result_dir, exist_ok=True)
84 | fn = os.path.join(args.result_dir, 'wav.scp')
85 | f = open(fn, 'w')
86 | with torch.no_grad():
87 | for _, batch in tqdm(enumerate(test_data_loader)):
88 | utts = batch["utts"]
89 | assert len(utts) == 1, "inference mode only supports batch size 1"
90 | text_token = batch["text_token"].to(device)
91 | text_token_len = batch["text_token_len"].to(device)
92 | tts_index = batch["tts_index"]
93 | tts_text_token = batch["tts_text_token"].to(device)
94 | tts_text_token_len = batch["tts_text_token_len"].to(device)
95 | speech_token = batch["speech_token"].to(device)
96 | speech_token_len = batch["speech_token_len"].to(device)
97 | speech_feat = batch["speech_feat"].to(device)
98 | speech_feat_len = batch["speech_feat_len"].to(device)
99 | utt_embedding = batch["utt_embedding"].to(device)
100 | spk_embedding = batch["spk_embedding"].to(device)
101 | if args.mode == 'sft':
102 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
103 | 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding}
104 | else:
105 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len,
106 | 'prompt_text': text_token, 'prompt_text_len': text_token_len,
107 | 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
108 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
109 | 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
110 | 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding}
111 | tts_speeches = []
112 | for model_output in model.tts(**model_input):
113 | tts_speeches.append(model_output['tts_speech'])
114 | tts_speeches = torch.concat(tts_speeches, dim=1)
115 | tts_key = '{}_{}'.format(utts[0], tts_index[0])
116 | tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key))
117 | torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile')
118 | f.write('{} {}\n'.format(tts_key, tts_fn))
119 | f.flush()
120 | f.close()
121 | logging.info('Result wav.scp saved in {}'.format(fn))
122 |
123 |
124 | if __name__ == '__main__':
125 | main()
126 |
--------------------------------------------------------------------------------
/cosyvoice/bin/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import print_function
16 | import argparse
17 | import datetime
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | from copy import deepcopy
21 | import os
22 | import torch
23 | import torch.distributed as dist
24 | import deepspeed
25 |
26 | from hyperpyyaml import load_hyperpyyaml
27 |
28 | from torch.distributed.elastic.multiprocessing.errors import record
29 |
30 | from cosyvoice.utils.executor import Executor
31 | from cosyvoice.utils.train_utils import (
32 | init_distributed,
33 | init_dataset_and_dataloader,
34 | init_optimizer_and_scheduler,
35 | init_summarywriter, save_model,
36 | wrap_cuda_model, check_modify_and_save_config)
37 |
38 |
39 | def get_args():
40 | parser = argparse.ArgumentParser(description='training your network')
41 | parser.add_argument('--train_engine',
42 | default='torch_ddp',
43 | choices=['torch_ddp', 'deepspeed'],
44 | help='Engine for parallel training')
45 | parser.add_argument('--model', required=True, help='model which will be trained')
46 | parser.add_argument('--config', required=True, help='config file')
47 | parser.add_argument('--train_data', required=True, help='train data file')
48 | parser.add_argument('--cv_data', required=True, help='cv data file')
49 | parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
50 | parser.add_argument('--checkpoint', help='checkpoint model')
51 | parser.add_argument('--model_dir', required=True, help='save model dir')
52 | parser.add_argument('--tensorboard_dir',
53 | default='tensorboard',
54 | help='tensorboard log dir')
55 | parser.add_argument('--ddp.dist_backend',
56 | dest='dist_backend',
57 | default='nccl',
58 | choices=['nccl', 'gloo'],
59 | help='distributed backend')
60 | parser.add_argument('--num_workers',
61 | default=0,
62 | type=int,
63 | help='num of subprocess workers for reading')
64 | parser.add_argument('--prefetch',
65 | default=100,
66 | type=int,
67 | help='prefetch number')
68 | parser.add_argument('--pin_memory',
69 | action='store_true',
70 | default=False,
71 | help='Use pinned memory buffers for reading')
72 | parser.add_argument('--use_amp',
73 | action='store_true',
74 | default=False,
75 | help='Use automatic mixed precision training')
76 | parser.add_argument('--deepspeed.save_states',
77 | dest='save_states',
78 | default='model_only',
79 | choices=['model_only', 'model+optimizer'],
80 | help='save model/optimizer states')
81 | parser.add_argument('--timeout',
82 | default=60,
83 | type=int,
84 | help='timeout (in seconds) of cosyvoice_join.')
85 | parser = deepspeed.add_config_arguments(parser)
86 | args = parser.parse_args()
87 | return args
88 |
89 |
90 | @record
91 | def main():
92 | args = get_args()
93 | logging.basicConfig(level=logging.DEBUG,
94 | format='%(asctime)s %(levelname)s %(message)s')
95 | # gan train has some special initialization logic
96 | gan = True if args.model == 'hifigan' else False
97 |
98 | override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
99 | if gan is True:
100 | override_dict.pop('hift')
101 | try:
102 | with open(args.config, 'r') as f:
103 | configs = load_hyperpyyaml(f, overrides={**override_dict, 'qwen_pretrain_path': args.qwen_pretrain_path})
104 | except Exception:
105 | with open(args.config, 'r') as f:
106 | configs = load_hyperpyyaml(f, overrides=override_dict)
107 | if gan is True:
108 | configs['train_conf'] = configs['train_conf_gan']
109 | configs['train_conf'].update(vars(args))
110 |
111 | # Init env for ddp
112 | init_distributed(args)
113 |
114 | # Get dataset & dataloader
115 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
116 | init_dataset_and_dataloader(args, configs, gan)
117 |
118 | # Do some sanity checks and save config to args.model_dir
119 | configs = check_modify_and_save_config(args, configs)
120 |
121 | # Tensorboard summary
122 | writer = init_summarywriter(args)
123 |
124 | # load checkpoint
125 | model = configs[args.model]
126 | start_step, start_epoch = 0, -1
127 | if args.checkpoint is not None:
128 | if os.path.exists(args.checkpoint):
129 | state_dict = torch.load(args.checkpoint, map_location='cpu')
130 | model.load_state_dict(state_dict, strict=False)
131 | if 'step' in state_dict:
132 | start_step = state_dict['step']
133 | if 'epoch' in state_dict:
134 | start_epoch = state_dict['epoch']
135 | else:
136 | logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
137 |
138 | # Dispatch model from cpu to gpu
139 | model = wrap_cuda_model(args, model)
140 |
141 | # Get optimizer & scheduler
142 | model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
143 | scheduler.set_step(start_step)
144 | if scheduler_d is not None:
145 | scheduler_d.set_step(start_step)
146 |
147 | # Save init checkpoints
148 | info_dict = deepcopy(configs['train_conf'])
149 | info_dict['step'] = start_step
150 | info_dict['epoch'] = start_epoch
151 | save_model(model, 'init', info_dict)
152 |
153 | # Get executor
154 | executor = Executor(gan=gan)
155 | executor.step = start_step
156 |
157 | # Init scaler, used for pytorch amp mixed precision training
158 | scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
159 | print('start step {} start epoch {}'.format(start_step, start_epoch))
160 | # Start training loop
161 | for epoch in range(start_epoch + 1, info_dict['max_epoch']):
162 | executor.epoch = epoch
163 | train_dataset.set_epoch(epoch)
164 | dist.barrier()
165 | group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
166 | if gan is True:
167 | executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
168 | writer, info_dict, scaler, group_join)
169 | else:
170 | executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join)
171 | dist.destroy_process_group(group_join)
172 |
173 |
174 | if __name__ == '__main__':
175 | main()
176 |
--------------------------------------------------------------------------------
/cosyvoice/bin/train_dpo.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import print_function
16 | import argparse
17 | import datetime
18 | import logging
19 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
20 | from copy import deepcopy
21 | import os
22 | import torch
23 | import torch.distributed as dist
24 | import deepspeed
25 |
26 | from hyperpyyaml import load_hyperpyyaml
27 |
28 | from torch.distributed.elastic.multiprocessing.errors import record
29 |
30 | from cosyvoice.utils.executor_dpo import Executor
31 | from cosyvoice.utils.train_utils_dpo import (
32 | init_distributed,
33 | init_dataset_and_dataloader,
34 | init_optimizer_and_scheduler,
35 | init_summarywriter, save_model,
36 | wrap_cuda_model, check_modify_and_save_config)
37 |
38 |
39 | def get_args():
40 | parser = argparse.ArgumentParser(description='training your network')
41 | parser.add_argument('--train_engine',
42 | default='torch_ddp',
43 | choices=['torch_ddp', 'deepspeed'],
44 | help='Engine for parallel training')
45 | parser.add_argument('--model', required=True, help='model which will be trained')
46 | parser.add_argument('--config', required=True, help='config file')
47 | parser.add_argument('--train_data', required=True, help='train data file')
48 | parser.add_argument('--cv_data', required=True, help='cv data file')
49 | parser.add_argument('--checkpoint', help='checkpoint model')
50 | parser.add_argument('--model_dir', required=True, help='save model dir')
51 | parser.add_argument('--tensorboard_dir',
52 | default='tensorboard',
53 | help='tensorboard log dir')
54 | parser.add_argument('--ddp.dist_backend',
55 | dest='dist_backend',
56 | default='nccl',
57 | choices=['nccl', 'gloo'],
58 | help='distributed backend')
59 | parser.add_argument('--num_workers',
60 | default=0,
61 | type=int,
62 | help='num of subprocess workers for reading')
63 | parser.add_argument('--prefetch',
64 | default=100,
65 | type=int,
66 | help='prefetch number')
67 | parser.add_argument('--pin_memory',
68 | action='store_true',
69 | default=False,
70 | help='Use pinned memory buffers for reading')
71 | parser.add_argument('--use_amp',
72 | action='store_true',
73 | default=False,
74 | help='Use automatic mixed precision training')
75 | parser.add_argument('--deepspeed.save_states',
76 | dest='save_states',
77 | default='model_only',
78 | choices=['model_only', 'model+optimizer'],
79 | help='save model/optimizer states')
80 | parser.add_argument('--timeout',
81 | default=60,
82 | type=int,
83 | help='timeout (in seconds) of cosyvoice_join.')
84 | parser.add_argument('--dpo',
85 | action='store_true',
86 | default=False,
87 | help='Use Direct Preference Optimization')
88 | parser.add_argument('--beta',
89 | default=0.01,
90 | type=float,
91 | help='beta of dpo training')
92 | parser = deepspeed.add_config_arguments(parser)
93 | args = parser.parse_args()
94 | return args
95 |
96 |
97 | @record
98 | def main():
99 | args = get_args()
100 | logging.basicConfig(level=logging.DEBUG,
101 | format='%(asctime)s %(levelname)s %(message)s')
102 | # gan train has some special initialization logic
103 | gan = True if args.model == 'hifigan' else False
104 |
105 | override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
106 | if gan is True:
107 | override_dict.pop('hift')
108 | with open(args.config, 'r') as f:
109 | configs = load_hyperpyyaml(f, overrides=override_dict)
110 | if gan is True:
111 | configs['train_conf'] = configs['train_conf_gan']
112 | configs['train_conf'].update(vars(args))
113 |
114 | # Init env for ddp
115 | init_distributed(args)
116 |
117 | # Get dataset & dataloader
118 | train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
119 | init_dataset_and_dataloader(args, configs, gan)
120 |
121 | # Do some sanity checks and save config to args.model_dir
122 | configs = check_modify_and_save_config(args, configs)
123 |
124 | # Tensorboard summary
125 | writer = init_summarywriter(args)
126 |
127 | # load checkpoint
128 | model = configs[args.model]
129 | ref_model = None
130 | if args.dpo:
131 | ref_model = deepcopy(model)
132 | start_step, start_epoch = 0, -1
133 | if args.checkpoint is not None:
134 | if os.path.exists(args.checkpoint):
135 | state_dict = torch.load(args.checkpoint, map_location='cpu')
136 | model.load_state_dict(state_dict, strict=False)
137 | if args.dpo:
138 | ref_model.load_state_dict(state_dict, strict=False)
139 | if 'step' in state_dict:
140 | start_step = state_dict['step']
141 | if 'epoch' in state_dict:
142 | start_epoch = state_dict['epoch']
143 | else:
144 | logging.warning('checkpoint {} does not exist!'.format(args.checkpoint))
145 |
146 | # Dispatch model from cpu to gpu
147 | model = wrap_cuda_model(args, model)
148 | if args.dpo:
149 | ref_model = wrap_cuda_model(args, ref_model)
150 |
151 | # Get optimizer & scheduler
152 | model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
153 | if args.dpo:
154 | ref_model, _, _, _, _ = init_optimizer_and_scheduler(args, configs, ref_model, gan)
155 | scheduler.set_step(start_step)
156 | if scheduler_d is not None:
157 | scheduler_d.set_step(start_step)
158 |
159 | # Save init checkpoints
160 | info_dict = deepcopy(configs['train_conf'])
161 | info_dict['step'] = start_step
162 | info_dict['epoch'] = start_epoch
163 | save_model(model, 'init', info_dict)
164 |
165 | # Get executor
166 | executor = Executor(gan=gan, dpo=args.dpo, beta=args.beta)
167 | executor.step = start_step
168 |
169 | # Init scaler, used for pytorch amp mixed precision training
170 | scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
171 | print('start step {} start epoch {}'.format(start_step, start_epoch))
172 | # Start training loop
173 | for epoch in range(start_epoch + 1, info_dict['max_epoch']):
174 | executor.epoch = epoch
175 | train_dataset.set_epoch(epoch)
176 | dist.barrier()
177 | group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
178 | if gan is True:
179 | executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
180 | writer, info_dict, scaler, group_join)
181 | else:
182 | executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model)
183 | dist.destroy_process_group(group_join)
184 |
185 |
186 | if __name__ == '__main__':
187 | main()
188 |
--------------------------------------------------------------------------------
/cosyvoice/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/cli/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/dataset/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import random
17 | import json
18 | import math
19 | from functools import partial
20 |
21 | import torch
22 | import torch.distributed as dist
23 | from torch.utils.data import IterableDataset
24 | from cosyvoice.utils.file_utils import read_lists, read_json_lists
25 |
26 |
27 | class Processor(IterableDataset):
28 |
29 | def __init__(self, source, f, *args, **kw):
30 | assert callable(f)
31 | self.source = source
32 | self.f = f
33 | self.args = args
34 | self.kw = kw
35 |
36 | def set_epoch(self, epoch):
37 | self.source.set_epoch(epoch)
38 |
39 | def __iter__(self):
40 | """ Return an iterator over the source dataset processed by the
41 | given processor.
42 | """
43 | assert self.source is not None
44 | assert callable(self.f)
45 | return self.f(iter(self.source), *self.args, **self.kw)
46 |
47 | def apply(self, f):
48 | assert callable(f)
49 | return Processor(self, f, *self.args, **self.kw)
50 |
51 |
52 | class DistributedSampler:
53 |
54 | def __init__(self, shuffle=True, partition=True):
55 | self.epoch = -1
56 | self.update()
57 | self.shuffle = shuffle
58 | self.partition = partition
59 |
60 | def update(self):
61 | assert dist.is_available()
62 | if dist.is_initialized():
63 | self.rank = dist.get_rank()
64 | self.world_size = dist.get_world_size()
65 | else:
66 | self.rank = 0
67 | self.world_size = 1
68 | worker_info = torch.utils.data.get_worker_info()
69 | if worker_info is None:
70 | self.worker_id = 0
71 | self.num_workers = 1
72 | else:
73 | self.worker_id = worker_info.id
74 | self.num_workers = worker_info.num_workers
75 | return dict(rank=self.rank,
76 | world_size=self.world_size,
77 | worker_id=self.worker_id,
78 | num_workers=self.num_workers)
79 |
80 | def set_epoch(self, epoch):
81 | self.epoch = epoch
82 |
83 | def sample(self, data):
84 | """ Sample data according to rank/world_size/num_workers
85 |
86 | Args:
87 | data(List): input data list
88 |
89 | Returns:
90 | List: data list after sample
91 | """
92 | data = list(range(len(data)))
93 | # force datalist even
94 | if self.partition:
95 | if self.shuffle:
96 | random.Random(self.epoch).shuffle(data)
97 | if len(data) < self.world_size:
98 | data = data * math.ceil(self.world_size / len(data))
99 | data = data[:self.world_size]
100 | data = data[self.rank::self.world_size]
101 | if len(data) < self.num_workers:
102 | data = data * math.ceil(self.num_workers / len(data))
103 | data = data[:self.num_workers]
104 | data = data[self.worker_id::self.num_workers]
105 | return data
106 |
107 |
108 | class DataList(IterableDataset):
109 |
110 | def __init__(self, lists, shuffle=True, partition=True):
111 | self.lists = lists
112 | self.sampler = DistributedSampler(shuffle, partition)
113 |
114 | def set_epoch(self, epoch):
115 | self.sampler.set_epoch(epoch)
116 |
117 | def __iter__(self):
118 | sampler_info = self.sampler.update()
119 | indexes = self.sampler.sample(self.lists)
120 | for index in indexes:
121 | data = dict(src=self.lists[index])
122 | data.update(sampler_info)
123 | yield data
124 |
125 |
126 | def Dataset(data_list_file,
127 | data_pipeline,
128 | mode='train',
129 | gan=False,
130 | shuffle=True,
131 | partition=True,
132 | tts_file='',
133 | prompt_utt2data=''):
134 | """ Construct dataset from arguments
135 |
136 | We have two shuffle stages in the Dataset. The first is a global
137 | shuffle at the shards tar/raw file level. The second is a global shuffle
138 | at the training samples level.
139 |
140 | Args:
141 | data_type(str): raw/shard
142 | tokenizer (BaseTokenizer): tokenizer to tokenize
143 | partition(bool): whether to do data partition in terms of rank
144 | """
145 | assert mode in ['train', 'inference']
146 | lists = read_lists(data_list_file)
147 | if mode == 'inference':
148 | with open(tts_file) as f:
149 | tts_data = json.load(f)
150 | utt2lists = read_json_lists(prompt_utt2data)
151 | # filter unnecessary file in inference mode
152 | lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
153 | dataset = DataList(lists,
154 | shuffle=shuffle,
155 | partition=partition)
156 | if mode == 'inference':
157 | # map partial arg to parquet_opener func in inference mode
158 | data_pipeline[0] = partial(data_pipeline[0], tts_data=tts_data)
159 | if gan is True:
160 | # map partial arg to padding func in gan mode
161 | data_pipeline[-1] = partial(data_pipeline[-1], gan=gan)
162 | for func in data_pipeline:
163 | dataset = Processor(dataset, func, mode=mode)
164 | return dataset
165 |
--------------------------------------------------------------------------------
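
Note: `Dataset()` chains each pipeline function through `Processor`, so every stage receives the previous stage's iterator plus `mode=...`. A hedged sketch with toy stages follows; the real pipeline functions live in `cosyvoice/dataset/processor.py`, which is not shown here, and `data.list` is a placeholder path.

```python
# Hedged sketch of how Dataset() chains pipeline stages through Processor.
# The toy stages below are illustrative only.
from cosyvoice.dataset.dataset import Dataset


def toy_opener(data, mode='train'):
    # each stage receives the previous stage's iterator plus mode=...
    for sample in data:
        sample['opened'] = True
        yield sample


def toy_filter(data, mode='train'):
    for sample in data:
        if sample.get('opened'):
            yield sample


# 'data.list' is a placeholder; read_lists() expects one entry per line
# (in the real setup, paths produced by tools/make_parquet_list.py).
dataset = Dataset('data.list',
                  data_pipeline=[toy_opener, toy_filter],
                  mode='train', shuffle=True, partition=True)
dataset.set_epoch(0)
for sample in dataset:
    print(sample['src'])
```
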
/cosyvoice/flow/length_regulator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Tuple
15 | import torch.nn as nn
16 | import torch
17 | from torch.nn import functional as F
18 | from cosyvoice.utils.mask import make_pad_mask
19 |
20 |
21 | class InterpolateRegulator(nn.Module):
22 | def __init__(
23 | self,
24 | channels: int,
25 | sampling_ratios: Tuple,
26 | out_channels: int = None,
27 | groups: int = 1,
28 | ):
29 | super().__init__()
30 | self.sampling_ratios = sampling_ratios
31 | out_channels = out_channels or channels
32 | model = nn.ModuleList([])
33 | if len(sampling_ratios) > 0:
34 | for _ in sampling_ratios:
35 | module = nn.Conv1d(channels, channels, 3, 1, 1)
36 | norm = nn.GroupNorm(groups, channels)
37 | act = nn.Mish()
38 | model.extend([module, norm, act])
39 | model.append(
40 | nn.Conv1d(channels, out_channels, 1, 1)
41 | )
42 | self.model = nn.Sequential(*model)
43 |
44 | def forward(self, x, ylens=None):
45 | # x in (B, T, D)
46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
48 | out = self.model(x).transpose(1, 2).contiguous()
49 | olens = ylens
50 | return out * mask, olens
51 |
52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
53 | # in inference mode, interpolate the prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
54 | # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
55 | # x in (B, T, D)
56 | if x2.shape[1] > 40:
57 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
58 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
59 | mode='linear')
60 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
61 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
62 | else:
63 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
64 | if x1.shape[1] != 0:
65 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
66 | x = torch.concat([x1, x2], dim=2)
67 | else:
68 | x = x2
69 | out = self.model(x).transpose(1, 2).contiguous()
70 | return out, mel_len1 + mel_len2
71 |
--------------------------------------------------------------------------------
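
Note: a hedged smoke test of `InterpolateRegulator.forward`. `channels=80` and `sampling_ratios=[1, 1, 1, 1]` are placeholder values (the real ones come from the YAML config); the input is (B, T, D) token-level features, `ylens` gives the target mel lengths, and the output is interpolated to max(ylens) frames.

```python
# Hedged smoke test of InterpolateRegulator.forward with dummy tensors.
import torch
from cosyvoice.flow.length_regulator import InterpolateRegulator

# placeholder hyperparameters; the real values come from the YAML config
regulator = InterpolateRegulator(channels=80, sampling_ratios=[1, 1, 1, 1])

x = torch.randn(2, 50, 80)                 # (B, T, D) token-level features
ylens = torch.tensor([120, 100])           # target mel lengths per utterance
out, olens = regulator(x, ylens)
print(out.shape, olens)                    # torch.Size([2, 120, 80]) tensor([120, 100])
```
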
/cosyvoice/hifigan/f0_predictor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import torch.nn as nn
16 | try:
17 | from torch.nn.utils.parametrizations import weight_norm
18 | except ImportError:
19 | from torch.nn.utils import weight_norm
20 |
21 |
22 | class ConvRNNF0Predictor(nn.Module):
23 | def __init__(self,
24 | num_class: int = 1,
25 | in_channels: int = 80,
26 | cond_channels: int = 512
27 | ):
28 | super().__init__()
29 |
30 | self.num_class = num_class
31 | self.condnet = nn.Sequential(
32 | weight_norm(
33 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
34 | ),
35 | nn.ELU(),
36 | weight_norm(
37 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
38 | ),
39 | nn.ELU(),
40 | weight_norm(
41 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
42 | ),
43 | nn.ELU(),
44 | weight_norm(
45 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
46 | ),
47 | nn.ELU(),
48 | weight_norm(
49 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
50 | ),
51 | nn.ELU(),
52 | )
53 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
54 |
55 | def forward(self, x: torch.Tensor) -> torch.Tensor:
56 | x = self.condnet(x)
57 | x = x.transpose(1, 2)
58 | return torch.abs(self.classifier(x).squeeze(-1))
59 |
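A minimal sketch with arbitrary shapes: the predictor maps an 80-bin mel spectrogram to one non-negative F0 value per frame.

    import torch
    from cosyvoice.hifigan.f0_predictor import ConvRNNF0Predictor

    predictor = ConvRNNF0Predictor(in_channels=80, cond_channels=512)
    mel = torch.randn(2, 80, 200)    # (B, in_channels, T)
    f0 = predictor(mel)              # (B, T), absolute value of the per-frame prediction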
--------------------------------------------------------------------------------
/cosyvoice/hifigan/hifigan.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
6 | from cosyvoice.utils.losses import tpr_loss, mel_loss
7 |
8 |
9 | class HiFiGan(nn.Module):
10 | def __init__(self, generator, discriminator, mel_spec_transform,
11 | multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
12 | tpr_loss_weight=1.0, tpr_loss_tau=0.04):
13 | super(HiFiGan, self).__init__()
14 | self.generator = generator
15 | self.discriminator = discriminator
16 | self.mel_spec_transform = mel_spec_transform
17 | self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
18 | self.feat_match_loss_weight = feat_match_loss_weight
19 | self.tpr_loss_weight = tpr_loss_weight
20 | self.tpr_loss_tau = tpr_loss_tau
21 |
22 | def forward(
23 | self,
24 | batch: dict,
25 | device: torch.device,
26 | ) -> Dict[str, Optional[torch.Tensor]]:
27 | if batch['turn'] == 'generator':
28 | return self.forward_generator(batch, device)
29 | else:
30 | return self.forward_discriminator(batch, device)
31 |
32 | def forward_generator(self, batch, device):
33 | real_speech = batch['speech'].to(device)
34 | pitch_feat = batch['pitch_feat'].to(device)
35 | # 1. calculate generator outputs
36 | generated_speech, generated_f0 = self.generator(batch, device)
37 | # 2. calculate discriminator outputs
38 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
39 | # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
40 | loss_gen, _ = generator_loss(y_d_gs)
41 | loss_fm = feature_loss(fmap_rs, fmap_gs)
42 | loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
43 | if self.tpr_loss_weight != 0:
44 | loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
45 | else:
46 | loss_tpr = torch.zeros(1).to(device)
47 | loss_f0 = F.l1_loss(generated_f0, pitch_feat)
48 | loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
49 | self.multi_mel_spectral_recon_loss_weight * loss_mel + \
50 | self.tpr_loss_weight * loss_tpr + loss_f0
51 | return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
52 |
53 | def forward_discriminator(self, batch, device):
54 | real_speech = batch['speech'].to(device)
55 | # 1. calculate generator outputs
56 | with torch.no_grad():
57 | generated_speech, generated_f0 = self.generator(batch, device)
58 | # 2. calculate discriminator outputs
59 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
60 | # 3. calculate discriminator losses, tpr losses [Optional]
61 | loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
62 | if self.tpr_loss_weight != 0:
63 | loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
64 | else:
65 | loss_tpr = torch.zeros(1).to(device)
66 | loss = loss_disc + self.tpr_loss_weight * loss_tpr
67 | return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
68 |
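A sketch of the intended two-turn training step (discriminator turn, then generator turn), matching how the GAN executor drives this wrapper; `hifigan`, `batch`, `device`, `optimizer_g` and `optimizer_d` are assumed to be prepared elsewhere and are not defined here.

    batch['turn'] = 'discriminator'
    d_out = hifigan(batch, device)          # {'loss', 'loss_disc', 'loss_tpr'}
    d_out['loss'].backward()
    optimizer_d.step()
    optimizer_d.zero_grad()

    batch['turn'] = 'generator'
    g_out = hifigan(batch, device)          # {'loss', 'loss_gen', 'loss_fm', 'loss_mel', 'loss_tpr', 'loss_f0'}
    g_out['loss'].backward()
    optimizer_g.step()
    optimizer_g.zero_grad()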
--------------------------------------------------------------------------------
/cosyvoice/tokenizer/tokenizer.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import os
3 | from functools import lru_cache
4 | from typing import Optional
5 | import torch
6 | from transformers import AutoTokenizer
7 | from whisper.tokenizer import Tokenizer
8 |
9 | import tiktoken
10 |
11 | LANGUAGES = {
12 | "en": "english",
13 | "zh": "chinese",
14 | "de": "german",
15 | "es": "spanish",
16 | "ru": "russian",
17 | "ko": "korean",
18 | "fr": "french",
19 | "ja": "japanese",
20 | "pt": "portuguese",
21 | "tr": "turkish",
22 | "pl": "polish",
23 | "ca": "catalan",
24 | "nl": "dutch",
25 | "ar": "arabic",
26 | "sv": "swedish",
27 | "it": "italian",
28 | "id": "indonesian",
29 | "hi": "hindi",
30 | "fi": "finnish",
31 | "vi": "vietnamese",
32 | "he": "hebrew",
33 | "uk": "ukrainian",
34 | "el": "greek",
35 | "ms": "malay",
36 | "cs": "czech",
37 | "ro": "romanian",
38 | "da": "danish",
39 | "hu": "hungarian",
40 | "ta": "tamil",
41 | "no": "norwegian",
42 | "th": "thai",
43 | "ur": "urdu",
44 | "hr": "croatian",
45 | "bg": "bulgarian",
46 | "lt": "lithuanian",
47 | "la": "latin",
48 | "mi": "maori",
49 | "ml": "malayalam",
50 | "cy": "welsh",
51 | "sk": "slovak",
52 | "te": "telugu",
53 | "fa": "persian",
54 | "lv": "latvian",
55 | "bn": "bengali",
56 | "sr": "serbian",
57 | "az": "azerbaijani",
58 | "sl": "slovenian",
59 | "kn": "kannada",
60 | "et": "estonian",
61 | "mk": "macedonian",
62 | "br": "breton",
63 | "eu": "basque",
64 | "is": "icelandic",
65 | "hy": "armenian",
66 | "ne": "nepali",
67 | "mn": "mongolian",
68 | "bs": "bosnian",
69 | "kk": "kazakh",
70 | "sq": "albanian",
71 | "sw": "swahili",
72 | "gl": "galician",
73 | "mr": "marathi",
74 | "pa": "punjabi",
75 | "si": "sinhala",
76 | "km": "khmer",
77 | "sn": "shona",
78 | "yo": "yoruba",
79 | "so": "somali",
80 | "af": "afrikaans",
81 | "oc": "occitan",
82 | "ka": "georgian",
83 | "be": "belarusian",
84 | "tg": "tajik",
85 | "sd": "sindhi",
86 | "gu": "gujarati",
87 | "am": "amharic",
88 | "yi": "yiddish",
89 | "lo": "lao",
90 | "uz": "uzbek",
91 | "fo": "faroese",
92 | "ht": "haitian creole",
93 | "ps": "pashto",
94 | "tk": "turkmen",
95 | "nn": "nynorsk",
96 | "mt": "maltese",
97 | "sa": "sanskrit",
98 | "lb": "luxembourgish",
99 | "my": "myanmar",
100 | "bo": "tibetan",
101 | "tl": "tagalog",
102 | "mg": "malagasy",
103 | "as": "assamese",
104 | "tt": "tatar",
105 | "haw": "hawaiian",
106 | "ln": "lingala",
107 | "ha": "hausa",
108 | "ba": "bashkir",
109 | "jw": "javanese",
110 | "su": "sundanese",
111 | "yue": "cantonese",
112 | "minnan": "minnan",
113 | "wuyu": "wuyu",
114 | "dialect": "dialect",
115 | "zh/en": "zh/en",
116 | "en/zh": "en/zh",
117 | }
118 |
119 | # language code lookup by name, with a few language aliases
120 | TO_LANGUAGE_CODE = {
121 | **{language: code for code, language in LANGUAGES.items()},
122 | "burmese": "my",
123 | "valencian": "ca",
124 | "flemish": "nl",
125 | "haitian": "ht",
126 | "letzeburgesch": "lb",
127 | "pushto": "ps",
128 | "panjabi": "pa",
129 | "moldavian": "ro",
130 | "moldovan": "ro",
131 | "sinhalese": "si",
132 | "castilian": "es",
133 | "mandarin": "zh",
134 | }
135 |
136 | AUDIO_EVENT = {
137 | "ASR": "ASR",
138 | "AED": "AED",
139 | "SER": "SER",
140 | "Speech": "Speech",
141 | "/Speech": "/Speech",
142 | "BGM": "BGM",
143 | "/BGM": "/BGM",
144 | "Laughter": "Laughter",
145 | "/Laughter": "/Laughter",
146 | "Applause": "Applause",
147 | "/Applause": "/Applause",
148 | }
149 |
150 | EMOTION = {
151 | "HAPPY": "HAPPY",
152 | "SAD": "SAD",
153 | "ANGRY": "ANGRY",
154 | "NEUTRAL": "NEUTRAL",
155 | }
156 |
157 | TTS_Vocal_Token = {
158 | "TTS/B": "TTS/B",
159 | "TTS/O": "TTS/O",
160 | "TTS/Q": "TTS/Q",
161 | "TTS/A": "TTS/A",
162 | "TTS/CO": "TTS/CO",
163 | "TTS/CL": "TTS/CL",
164 | "TTS/H": "TTS/H",
165 | **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
166 | }
167 |
168 |
169 | @lru_cache(maxsize=None)
170 | def get_encoding(name: str = "gpt2", num_languages: int = 99):
171 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
172 | ranks = {
173 | base64.b64decode(token): int(rank)
174 | for token, rank in (line.split() for line in open(vocab_path) if line)
175 | }
176 | n_vocab = len(ranks)
177 | special_tokens = {}
178 |
179 | specials = [
180 | "<|endoftext|>",
181 | "<|startoftranscript|>",
182 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
183 | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
184 | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
185 | "<|translate|>",
186 | "<|transcribe|>",
187 | "<|startoflm|>",
188 | "<|startofprev|>",
189 | "<|nospeech|>",
190 | "<|notimestamps|>",
191 | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
192 | *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
193 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
194 | ]
195 |
196 | for token in specials:
197 | special_tokens[token] = n_vocab
198 | n_vocab += 1
199 |
200 | return tiktoken.Encoding(
201 | name=os.path.basename(vocab_path),
202 | explicit_n_vocab=n_vocab,
203 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
204 | mergeable_ranks=ranks,
205 | special_tokens=special_tokens,
206 | )
207 |
208 |
209 | @lru_cache(maxsize=None)
210 | def get_tokenizer(
211 | multilingual: bool,
212 | *,
213 | num_languages: int = 99,
214 | language: Optional[str] = None,
215 | task: Optional[str] = None, # Literal["transcribe", "translate", None]
216 | ) -> Tokenizer:
217 | if language is not None:
218 | language = language.lower()
219 | if language not in LANGUAGES:
220 | if language in TO_LANGUAGE_CODE:
221 | language = TO_LANGUAGE_CODE[language]
222 | else:
223 | raise ValueError(f"Unsupported language: {language}")
224 |
225 | if multilingual:
226 | encoding_name = "multilingual_zh_ja_yue_char_del"
227 | language = language or "en"
228 | task = task or "transcribe"
229 | else:
230 | encoding_name = "gpt2"
231 | language = None
232 | task = None
233 |
234 | encoding = get_encoding(name=encoding_name, num_languages=num_languages)
235 |
236 | return Tokenizer(
237 | encoding=encoding, num_languages=num_languages, language=language, task=task
238 | )
239 |
240 |
241 | class QwenTokenizer():
242 | def __init__(self, token_path, skip_special_tokens=True):
243 | super().__init__()
244 |         # NOTE: non-chat model, all these special tokens remain randomly initialized.
245 | special_tokens = {
246 | 'eos_token': '<|endoftext|>',
247 | 'pad_token': '<|endoftext|>',
248 | 'additional_special_tokens': [
249 | '<|im_start|>', '<|im_end|>', '<|endofprompt|>',
250 |                 '[breath]', '<strong>', '</strong>', '[noise]',
251 | '[laughter]', '[cough]', '[clucking]', '[accent]',
252 | '[quick_breath]',
253 |                 "<laughter>", "</laughter>",
254 | "[hissing]", "[sigh]", "[vocalized-noise]",
255 | "[lipsmack]", "[mn]"
256 | ]
257 | }
258 | self.special_tokens = special_tokens
259 | self.tokenizer = AutoTokenizer.from_pretrained(token_path)
260 | self.tokenizer.add_special_tokens(special_tokens)
261 | self.skip_special_tokens = skip_special_tokens
262 |
263 | def encode(self, text, **kwargs):
264 | tokens = self.tokenizer([text], return_tensors="pt")
265 | tokens = tokens["input_ids"][0].cpu().tolist()
266 | return tokens
267 |
268 | def decode(self, tokens):
269 | tokens = torch.tensor(tokens, dtype=torch.int64)
270 | text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
271 | return text
272 |
273 |
274 | @lru_cache(maxsize=None)
275 | def get_qwen_tokenizer(
276 | token_path: str,
277 | skip_special_tokens: bool
278 | ) -> QwenTokenizer:
279 | return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
280 |
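A minimal sketch: build the multilingual whisper-style tokenizer backed by the tiktoken asset shipped in this repo and round-trip a short string (assumes the openai-whisper and tiktoken packages are installed).

    from cosyvoice.tokenizer.tokenizer import get_tokenizer

    tok = get_tokenizer(multilingual=True, language='zh', task='transcribe')
    ids = tok.encode('你好，世界')      # list[int] token ids
    print(tok.decode(ids))             # back to the original text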
--------------------------------------------------------------------------------
/cosyvoice/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/transformer/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/transformer/activation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo)
3 | # 2020 Mobvoi Inc (Binbin Zhang)
4 | # 2024 Alibaba Inc (Xiang Lyu)
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | """Swish() activation function for Conformer."""
18 |
19 | import torch
20 | from torch import nn, sin, pow
21 | from torch.nn import Parameter
22 |
23 |
24 | class Swish(torch.nn.Module):
25 | """Construct an Swish object."""
26 |
27 | def forward(self, x: torch.Tensor) -> torch.Tensor:
28 | """Return Swish activation function."""
29 | return x * torch.sigmoid(x)
30 |
31 |
32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
33 | # LICENSE is in incl_licenses directory.
34 | class Snake(nn.Module):
35 | '''
36 | Implementation of a sine-based periodic activation function
37 | Shape:
38 | - Input: (B, C, T)
39 | - Output: (B, C, T), same shape as the input
40 | Parameters:
41 | - alpha - trainable parameter
42 | References:
43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
44 | https://arxiv.org/abs/2006.08195
45 | Examples:
46 |     >>> a1 = Snake(256)
47 |     >>> x = torch.randn(4, 256, 100)  # (B, C, T)
48 |     >>> y = a1(x)
49 | '''
50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
51 | '''
52 | Initialization.
53 | INPUT:
54 |             - in_features: number of input channels
55 | - alpha: trainable parameter
56 | alpha is initialized to 1 by default, higher values = higher-frequency.
57 | alpha will be trained along with the rest of your model.
58 | '''
59 | super(Snake, self).__init__()
60 | self.in_features = in_features
61 |
62 | # initialize alpha
63 | self.alpha_logscale = alpha_logscale
64 | if self.alpha_logscale: # log scale alphas initialized to zeros
65 | self.alpha = Parameter(torch.zeros(in_features) * alpha)
66 | else: # linear scale alphas initialized to ones
67 | self.alpha = Parameter(torch.ones(in_features) * alpha)
68 |
69 | self.alpha.requires_grad = alpha_trainable
70 |
71 | self.no_div_by_zero = 0.000000001
72 |
73 | def forward(self, x):
74 | '''
75 | Forward pass of the function.
76 | Applies the function to the input elementwise.
77 | Snake ∶= x + 1/a * sin^2 (xa)
78 | '''
79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
80 | if self.alpha_logscale:
81 | alpha = torch.exp(alpha)
82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
83 |
84 | return x
85 |
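A minimal sketch: Snake is applied channel-wise on (B, C, T) input and preserves the shape.

    import torch
    from cosyvoice.transformer.activation import Snake

    act = Snake(in_features=64, alpha_logscale=True)
    x = torch.randn(4, 64, 100)    # (B, C, T)
    y = act(x)                     # same shape; one learnable alpha per channel
    assert y.shape == x.shape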
--------------------------------------------------------------------------------
/cosyvoice/transformer/convolution.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
2 | # 2024 Alibaba Inc (Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # Modified from ESPnet(https://github.com/espnet/espnet)
16 | """ConvolutionModule definition."""
17 |
18 | from typing import Tuple
19 |
20 | import torch
21 | from torch import nn
22 |
23 |
24 | class ConvolutionModule(nn.Module):
25 | """ConvolutionModule in Conformer model."""
26 |
27 | def __init__(self,
28 | channels: int,
29 | kernel_size: int = 15,
30 | activation: nn.Module = nn.ReLU(),
31 | norm: str = "batch_norm",
32 | causal: bool = False,
33 | bias: bool = True):
34 | """Construct an ConvolutionModule object.
35 | Args:
36 | channels (int): The number of channels of conv layers.
37 | kernel_size (int): Kernel size of conv layers.
38 |             causal (bool): Whether to use causal convolution or not
39 | """
40 | super().__init__()
41 |
42 | self.pointwise_conv1 = nn.Conv1d(
43 | channels,
44 | 2 * channels,
45 | kernel_size=1,
46 | stride=1,
47 | padding=0,
48 | bias=bias,
49 | )
50 | # self.lorder is used to distinguish if it's a causal convolution,
51 | # if self.lorder > 0: it's a causal convolution, the input will be
52 | # padded with self.lorder frames on the left in forward.
53 | # else: it's a symmetrical convolution
54 | if causal:
55 | padding = 0
56 | self.lorder = kernel_size - 1
57 | else:
58 |             # kernel_size should be an odd number for non-causal convolution
59 | assert (kernel_size - 1) % 2 == 0
60 | padding = (kernel_size - 1) // 2
61 | self.lorder = 0
62 | self.depthwise_conv = nn.Conv1d(
63 | channels,
64 | channels,
65 | kernel_size,
66 | stride=1,
67 | padding=padding,
68 | groups=channels,
69 | bias=bias,
70 | )
71 |
72 | assert norm in ['batch_norm', 'layer_norm']
73 | if norm == "batch_norm":
74 | self.use_layer_norm = False
75 | self.norm = nn.BatchNorm1d(channels)
76 | else:
77 | self.use_layer_norm = True
78 | self.norm = nn.LayerNorm(channels)
79 |
80 | self.pointwise_conv2 = nn.Conv1d(
81 | channels,
82 | channels,
83 | kernel_size=1,
84 | stride=1,
85 | padding=0,
86 | bias=bias,
87 | )
88 | self.activation = activation
89 |
90 | def forward(
91 | self,
92 | x: torch.Tensor,
93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
94 | cache: torch.Tensor = torch.zeros((0, 0, 0)),
95 | ) -> Tuple[torch.Tensor, torch.Tensor]:
96 | """Compute convolution module.
97 | Args:
98 | x (torch.Tensor): Input tensor (#batch, time, channels).
99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
100 | (0, 0, 0) means fake mask.
101 | cache (torch.Tensor): left context cache, it is only
102 | used in causal convolution (#batch, channels, cache_t),
103 |                 (0, 0, 0) means fake cache.
104 | Returns:
105 | torch.Tensor: Output tensor (#batch, time, channels).
106 | """
107 | # exchange the temporal dimension and the feature dimension
108 | x = x.transpose(1, 2) # (#batch, channels, time)
109 |
110 | # mask batch padding
111 | if mask_pad.size(2) > 0: # time > 0
112 | x.masked_fill_(~mask_pad, 0.0)
113 |
114 | if self.lorder > 0:
115 | if cache.size(2) == 0: # cache_t == 0
116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
117 | else:
118 | assert cache.size(0) == x.size(0) # equal batch
119 | assert cache.size(1) == x.size(1) # equal channel
120 | x = torch.cat((cache, x), dim=2)
121 | assert (x.size(2) > self.lorder)
122 | new_cache = x[:, :, -self.lorder:]
123 | else:
124 | # It's better we just return None if no cache is required,
125 | # However, for JIT export, here we just fake one tensor instead of
126 | # None.
127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
128 |
129 | # GLU mechanism
130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
132 |
133 | # 1D Depthwise Conv
134 | x = self.depthwise_conv(x)
135 | if self.use_layer_norm:
136 | x = x.transpose(1, 2)
137 | x = self.activation(self.norm(x))
138 | if self.use_layer_norm:
139 | x = x.transpose(1, 2)
140 | x = self.pointwise_conv2(x)
141 | # mask batch padding
142 | if mask_pad.size(2) > 0: # time > 0
143 | x.masked_fill_(~mask_pad, 0.0)
144 |
145 | return x.transpose(1, 2), new_cache
146 |
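A minimal streaming sketch: with causal=True the module returns a left-context cache of kernel_size - 1 frames that can be fed back in for the next chunk (shapes below are arbitrary).

    import torch
    from cosyvoice.transformer.convolution import ConvolutionModule

    conv = ConvolutionModule(channels=256, kernel_size=15, causal=True)
    chunk1 = torch.randn(1, 40, 256)            # (B, time, channels)
    y1, cache = conv(chunk1)                    # cache: (B, channels, kernel_size - 1)
    chunk2 = torch.randn(1, 40, 256)
    y2, cache = conv(chunk2, cache=cache)       # reuse chunk1's left context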
--------------------------------------------------------------------------------
/cosyvoice/transformer/decoder_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Shigeki Karita
2 | # 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Decoder self-attention layer definition."""
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 | from torch import nn
20 |
21 |
22 | class DecoderLayer(nn.Module):
23 | """Single decoder layer module.
24 |
25 | Args:
26 | size (int): Input dimension.
27 | self_attn (torch.nn.Module): Self-attention module instance.
28 | `MultiHeadedAttention` instance can be used as the argument.
29 | src_attn (torch.nn.Module): Inter-attention module instance.
30 | `MultiHeadedAttention` instance can be used as the argument.
31 |             If `None` is passed, inter-attention is not used, as in
32 |             CIF, GPT, and other decoder-only models.
33 | feed_forward (torch.nn.Module): Feed-forward module instance.
34 | `PositionwiseFeedForward` instance can be used as the argument.
35 | dropout_rate (float): Dropout rate.
36 | normalize_before (bool):
37 | True: use layer_norm before each sub-block.
38 |             False: use layer_norm after each sub-block.
39 | """
40 |
41 | def __init__(
42 | self,
43 | size: int,
44 | self_attn: nn.Module,
45 | src_attn: Optional[nn.Module],
46 | feed_forward: nn.Module,
47 | dropout_rate: float,
48 | normalize_before: bool = True,
49 | ):
50 | """Construct an DecoderLayer object."""
51 | super().__init__()
52 | self.size = size
53 | self.self_attn = self_attn
54 | self.src_attn = src_attn
55 | self.feed_forward = feed_forward
56 | self.norm1 = nn.LayerNorm(size, eps=1e-5)
57 | self.norm2 = nn.LayerNorm(size, eps=1e-5)
58 | self.norm3 = nn.LayerNorm(size, eps=1e-5)
59 | self.dropout = nn.Dropout(dropout_rate)
60 | self.normalize_before = normalize_before
61 |
62 | def forward(
63 | self,
64 | tgt: torch.Tensor,
65 | tgt_mask: torch.Tensor,
66 | memory: torch.Tensor,
67 | memory_mask: torch.Tensor,
68 | cache: Optional[torch.Tensor] = None
69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70 | """Compute decoded features.
71 |
72 | Args:
73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74 | tgt_mask (torch.Tensor): Mask for input tensor
75 | (#batch, maxlen_out).
76 | memory (torch.Tensor): Encoded memory
77 | (#batch, maxlen_in, size).
78 | memory_mask (torch.Tensor): Encoded memory mask
79 | (#batch, maxlen_in).
80 | cache (torch.Tensor): cached tensors.
81 | (#batch, maxlen_out - 1, size).
82 |
83 | Returns:
84 | torch.Tensor: Output tensor (#batch, maxlen_out, size).
85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in).
88 |
89 | """
90 | residual = tgt
91 | if self.normalize_before:
92 | tgt = self.norm1(tgt)
93 |
94 | if cache is None:
95 | tgt_q = tgt
96 | tgt_q_mask = tgt_mask
97 | else:
98 | # compute only the last frame query keeping dim: max_time_out -> 1
99 | assert cache.shape == (
100 | tgt.shape[0],
101 | tgt.shape[1] - 1,
102 | self.size,
103 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104 | tgt_q = tgt[:, -1:, :]
105 | residual = residual[:, -1:, :]
106 | tgt_q_mask = tgt_mask[:, -1:, :]
107 |
108 | x = residual + self.dropout(
109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
110 | if not self.normalize_before:
111 | x = self.norm1(x)
112 |
113 | if self.src_attn is not None:
114 | residual = x
115 | if self.normalize_before:
116 | x = self.norm2(x)
117 | x = residual + self.dropout(
118 | self.src_attn(x, memory, memory, memory_mask)[0])
119 | if not self.normalize_before:
120 | x = self.norm2(x)
121 |
122 | residual = x
123 | if self.normalize_before:
124 | x = self.norm3(x)
125 | x = residual + self.dropout(self.feed_forward(x))
126 | if not self.normalize_before:
127 | x = self.norm3(x)
128 |
129 | if cache is not None:
130 | x = torch.cat([cache, x], dim=1)
131 |
132 | return x, tgt_mask, memory, memory_mask
133 |
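A minimal wiring sketch using the attention and feed-forward modules from this repo; the sizes are arbitrary and the all-ones masks simply mean "attend everywhere".

    import torch
    from cosyvoice.transformer.attention import MultiHeadedAttention
    from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
    from cosyvoice.transformer.decoder_layer import DecoderLayer

    size, heads = 256, 4
    layer = DecoderLayer(
        size=size,
        self_attn=MultiHeadedAttention(heads, size, 0.1),
        src_attn=MultiHeadedAttention(heads, size, 0.1),
        feed_forward=PositionwiseFeedForward(size, 1024, 0.1),
        dropout_rate=0.1,
    )
    tgt = torch.randn(2, 10, size)
    memory = torch.randn(2, 20, size)
    tgt_mask = torch.ones(2, 10, 10, dtype=torch.bool)
    memory_mask = torch.ones(2, 1, 20, dtype=torch.bool)
    x, *_ = layer(tgt, tgt_mask, memory, memory_mask)    # x: (2, 10, size)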
--------------------------------------------------------------------------------
/cosyvoice/transformer/label_smoothing_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Shigeki Karita
2 | # 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Label smoothing module."""
16 |
17 | import torch
18 | from torch import nn
19 |
20 |
21 | class LabelSmoothingLoss(nn.Module):
22 | """Label-smoothing loss.
23 |
24 | In a standard CE loss, the label's data distribution is:
25 | [0,1,2] ->
26 | [
27 | [1.0, 0.0, 0.0],
28 | [0.0, 1.0, 0.0],
29 | [0.0, 0.0, 1.0],
30 | ]
31 |
32 |     In the smoothed version of CE loss, some probability mass
33 |     is taken from the true label prob (1.0) and divided
34 |     among the other labels.
35 |
36 | e.g.
37 | smoothing=0.1
38 | [0,1,2] ->
39 | [
40 | [0.9, 0.05, 0.05],
41 | [0.05, 0.9, 0.05],
42 | [0.05, 0.05, 0.9],
43 | ]
44 |
45 | Args:
46 |         size (int): the number of classes
47 | padding_idx (int): padding class id which will be ignored for loss
48 | smoothing (float): smoothing rate (0.0 means the conventional CE)
49 | normalize_length (bool):
50 | normalize loss by sequence length if True
51 | normalize loss by batch size if False
52 | """
53 |
54 | def __init__(self,
55 | size: int,
56 | padding_idx: int,
57 | smoothing: float,
58 | normalize_length: bool = False):
59 | """Construct an LabelSmoothingLoss object."""
60 | super(LabelSmoothingLoss, self).__init__()
61 | self.criterion = nn.KLDivLoss(reduction="none")
62 | self.padding_idx = padding_idx
63 | self.confidence = 1.0 - smoothing
64 | self.smoothing = smoothing
65 | self.size = size
66 | self.normalize_length = normalize_length
67 |
68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
69 | """Compute loss between x and target.
70 |
71 |         The model output and data label tensors are flattened to
72 |         (batch*seqlen, class) shape, and a mask is applied to the
73 |         padding positions, which are excluded from the loss.
74 |
75 | Args:
76 | x (torch.Tensor): prediction (batch, seqlen, class)
77 | target (torch.Tensor):
78 | target signal masked with self.padding_id (batch, seqlen)
79 | Returns:
80 | loss (torch.Tensor) : The KL loss, scalar float value
81 | """
82 | assert x.size(2) == self.size
83 | batch_size = x.size(0)
84 | x = x.view(-1, self.size)
85 | target = target.view(-1)
86 | # use zeros_like instead of torch.no_grad() for true_dist,
87 | # since no_grad() can not be exported by JIT
88 | true_dist = torch.zeros_like(x)
89 | true_dist.fill_(self.smoothing / (self.size - 1))
90 | ignore = target == self.padding_idx # (B,)
91 | total = len(target) - ignore.sum().item()
92 | target = target.masked_fill(ignore, 0) # avoid -1 index
93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
95 | denom = total if self.normalize_length else batch_size
96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
97 |
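A minimal sketch on dummy logits: positions labelled with the padding id are excluded from the loss.

    import torch
    from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss

    criterion = LabelSmoothingLoss(size=10, padding_idx=-1, smoothing=0.1)
    logits = torch.randn(2, 5, 10)      # (batch, seqlen, class)
    target = torch.randint(0, 10, (2, 5))
    target[1, 3:] = -1                  # padded tail of the second sequence is ignored
    loss = criterion(logits, target)    # scalar KL-based loss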
--------------------------------------------------------------------------------
/cosyvoice/transformer/positionwise_feed_forward.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019 Shigeki Karita
2 | # 2020 Mobvoi Inc (Binbin Zhang)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Positionwise feed forward layer definition."""
16 |
17 | import torch
18 |
19 |
20 | class PositionwiseFeedForward(torch.nn.Module):
21 | """Positionwise feed forward layer.
22 |
23 |     FeedForward is applied at each position of the sequence.
24 |     The output dim is the same as the input dim.
25 |
26 | Args:
27 |         idim (int): Input dimension.
28 | hidden_units (int): The number of hidden units.
29 | dropout_rate (float): Dropout rate.
30 | activation (torch.nn.Module): Activation function
31 | """
32 |
33 | def __init__(
34 | self,
35 | idim: int,
36 | hidden_units: int,
37 | dropout_rate: float,
38 | activation: torch.nn.Module = torch.nn.ReLU(),
39 | ):
40 | """Construct a PositionwiseFeedForward object."""
41 | super(PositionwiseFeedForward, self).__init__()
42 | self.w_1 = torch.nn.Linear(idim, hidden_units)
43 | self.activation = activation
44 | self.dropout = torch.nn.Dropout(dropout_rate)
45 | self.w_2 = torch.nn.Linear(hidden_units, idim)
46 |
47 | def forward(self, xs: torch.Tensor) -> torch.Tensor:
48 | """Forward function.
49 |
50 | Args:
51 | xs: input tensor (B, L, D)
52 | Returns:
53 | output tensor, (B, L, D)
54 | """
55 | return self.w_2(self.dropout(self.activation(self.w_1(xs))))
56 |
57 |
58 | class MoEFFNLayer(torch.nn.Module):
59 | """
60 | Mixture of expert with Positionwise feed forward layer
61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
62 |     The output dim is the same as the input dim.
63 |
64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
66 | Args:
67 | n_expert: number of expert.
68 | n_expert_per_token: The actual number of experts used for each frame
69 |         idim (int): Input dimension.
70 | hidden_units (int): The number of hidden units.
71 | dropout_rate (float): Dropout rate.
72 | activation (torch.nn.Module): Activation function
73 | """
74 |
75 | def __init__(
76 | self,
77 | n_expert: int,
78 | n_expert_per_token: int,
79 | idim: int,
80 | hidden_units: int,
81 | dropout_rate: float,
82 | activation: torch.nn.Module = torch.nn.ReLU(),
83 | ):
84 | super(MoEFFNLayer, self).__init__()
85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False)
86 | self.experts = torch.nn.ModuleList(
87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate,
88 | activation) for _ in range(n_expert))
89 | self.n_expert_per_token = n_expert_per_token
90 |
91 | def forward(self, xs: torch.Tensor) -> torch.Tensor:
92 | """Foward function.
93 | Args:
94 | xs: input tensor (B, L, D)
95 | Returns:
96 | output tensor, (B, L, D)
97 |
98 | """
99 | B, L, D = xs.size(
100 | ) # batch size, sequence length, embedding dimension (idim)
101 | xs = xs.view(-1, D) # (B*L, D)
102 | router = self.gate(xs) # (B*L, n_expert)
103 | logits, indices = torch.topk(
104 | router, self.n_expert_per_token
105 |         )  # logits, indices: (B*L, n_expert_per_token)
106 | weights = torch.nn.functional.softmax(
107 | logits, dim=1,
108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
109 | output = torch.zeros_like(xs) # (B*L, D)
110 | for i, expert in enumerate(self.experts):
111 | mask = indices == i
112 | batch_idx, ith_expert = torch.where(mask)
113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
114 | xs[batch_idx])
115 | return output.view(B, L, D)
116 |
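A minimal sketch: MoEFFNLayer routes each frame to its top-k experts and returns the same (B, L, D) shape; the sizes below are arbitrary.

    import torch
    from cosyvoice.transformer.positionwise_feed_forward import MoEFFNLayer

    moe = MoEFFNLayer(n_expert=4, n_expert_per_token=2, idim=256,
                      hidden_units=1024, dropout_rate=0.1)
    xs = torch.randn(2, 50, 256)
    ys = moe(xs)    # (2, 50, 256): per frame, a softmax-weighted sum of 2 expert outputs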
--------------------------------------------------------------------------------
/cosyvoice/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FunAudioLLM/CosyVoice/fca1df3ea734fd393b8fc0d793f0461e3757d858/cosyvoice/utils/__init__.py
--------------------------------------------------------------------------------
/cosyvoice/utils/class_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright [2023-11-28]
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import torch
16 |
17 | from cosyvoice.transformer.activation import Swish
18 | from cosyvoice.transformer.subsampling import (
19 | LinearNoSubsampling,
20 | EmbedinigNoSubsampling,
21 | Conv1dSubsampling2,
22 | Conv2dSubsampling4,
23 | Conv2dSubsampling6,
24 | Conv2dSubsampling8,
25 | )
26 | from cosyvoice.transformer.embedding import (PositionalEncoding,
27 | RelPositionalEncoding,
28 | WhisperPositionalEncoding,
29 | LearnablePositionalEncoding,
30 | NoPositionalEncoding)
31 | from cosyvoice.transformer.attention import (MultiHeadedAttention,
32 | RelPositionMultiHeadedAttention)
33 | from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
34 | from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
35 | from cosyvoice.llm.llm import TransformerLM, Qwen2LM
36 | from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec
37 | from cosyvoice.hifigan.generator import HiFTGenerator
38 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model
39 |
40 |
41 | COSYVOICE_ACTIVATION_CLASSES = {
42 | "hardtanh": torch.nn.Hardtanh,
43 | "tanh": torch.nn.Tanh,
44 | "relu": torch.nn.ReLU,
45 | "selu": torch.nn.SELU,
46 | "swish": getattr(torch.nn, "SiLU", Swish),
47 | "gelu": torch.nn.GELU,
48 | }
49 |
50 | COSYVOICE_SUBSAMPLE_CLASSES = {
51 | "linear": LinearNoSubsampling,
52 | "linear_legacy": LegacyLinearNoSubsampling,
53 | "embed": EmbedinigNoSubsampling,
54 | "conv1d2": Conv1dSubsampling2,
55 | "conv2d": Conv2dSubsampling4,
56 | "conv2d6": Conv2dSubsampling6,
57 | "conv2d8": Conv2dSubsampling8,
58 | 'paraformer_dummy': torch.nn.Identity
59 | }
60 |
61 | COSYVOICE_EMB_CLASSES = {
62 | "embed": PositionalEncoding,
63 | "abs_pos": PositionalEncoding,
64 | "rel_pos": RelPositionalEncoding,
65 | "rel_pos_espnet": EspnetRelPositionalEncoding,
66 | "no_pos": NoPositionalEncoding,
67 | "abs_pos_whisper": WhisperPositionalEncoding,
68 | "embed_learnable_pe": LearnablePositionalEncoding,
69 | }
70 |
71 | COSYVOICE_ATTENTION_CLASSES = {
72 | "selfattn": MultiHeadedAttention,
73 | "rel_selfattn": RelPositionMultiHeadedAttention,
74 | }
75 |
76 |
77 | def get_model_type(configs):
78 | # NOTE CosyVoice2Model inherits CosyVoiceModel
79 | if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
80 | return CosyVoiceModel
81 | if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
82 | return CosyVoice2Model
83 | raise TypeError('No valid model type found!')
84 |
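A minimal sketch: these registry dicts map the string names used in the yaml configs to classes, e.g. resolving the activation and attention types by name.

    from cosyvoice.utils.class_utils import (COSYVOICE_ACTIVATION_CLASSES,
                                             COSYVOICE_ATTENTION_CLASSES)

    act = COSYVOICE_ACTIVATION_CLASSES['swish']()            # nn.SiLU (or Swish fallback)
    attn_cls = COSYVOICE_ATTENTION_CLASSES['rel_selfattn']   # RelPositionMultiHeadedAttention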
--------------------------------------------------------------------------------
/cosyvoice/utils/common.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # Modified from ESPnet(https://github.com/espnet/espnet)
17 | """Unility functions for Transformer."""
18 |
19 | import queue
20 | import random
21 | from typing import List
22 |
23 | import numpy as np
24 | import torch
25 |
26 | IGNORE_ID = -1
27 |
28 |
29 | def pad_list(xs: List[torch.Tensor], pad_value: int):
30 | """Perform padding for the list of tensors.
31 |
32 | Args:
33 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
34 | pad_value (float): Value for padding.
35 |
36 | Returns:
37 | Tensor: Padded tensor (B, Tmax, `*`).
38 |
39 | Examples:
40 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
41 | >>> x
42 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
43 | >>> pad_list(x, 0)
44 | tensor([[1., 1., 1., 1.],
45 | [1., 1., 0., 0.],
46 | [1., 0., 0., 0.]])
47 |
48 | """
49 | max_len = max([len(item) for item in xs])
50 | batchs = len(xs)
51 | ndim = xs[0].ndim
52 | if ndim == 1:
53 | pad_res = torch.zeros(batchs,
54 | max_len,
55 | dtype=xs[0].dtype,
56 | device=xs[0].device)
57 | elif ndim == 2:
58 | pad_res = torch.zeros(batchs,
59 | max_len,
60 | xs[0].shape[1],
61 | dtype=xs[0].dtype,
62 | device=xs[0].device)
63 | elif ndim == 3:
64 | pad_res = torch.zeros(batchs,
65 | max_len,
66 | xs[0].shape[1],
67 | xs[0].shape[2],
68 | dtype=xs[0].dtype,
69 | device=xs[0].device)
70 | else:
71 | raise ValueError(f"Unsupported ndim: {ndim}")
72 | pad_res.fill_(pad_value)
73 | for i in range(batchs):
74 | pad_res[i, :len(xs[i])] = xs[i]
75 | return pad_res
76 |
77 |
78 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
79 | ignore_label: int) -> torch.Tensor:
80 | """Calculate accuracy.
81 |
82 | Args:
83 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
84 | pad_targets (LongTensor): Target label tensors (B, Lmax).
85 | ignore_label (int): Ignore label id.
86 |
87 | Returns:
88 | torch.Tensor: Accuracy value (0.0 - 1.0).
89 |
90 | """
91 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
92 | pad_outputs.size(1)).argmax(2)
93 | mask = pad_targets != ignore_label
94 | numerator = torch.sum(
95 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
96 | denominator = torch.sum(mask)
97 | return (numerator / denominator).detach()
98 |
99 |
100 | def get_padding(kernel_size, dilation=1):
101 | return int((kernel_size * dilation - dilation) / 2)
102 |
103 |
104 | def init_weights(m, mean=0.0, std=0.01):
105 | classname = m.__class__.__name__
106 | if classname.find("Conv") != -1:
107 | m.weight.data.normal_(mean, std)
108 |
109 |
110 | # Repetition Aware Sampling in VALL-E 2
111 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
112 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
113 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
114 | if rep_num >= win_size * tau_r:
115 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
116 | return top_ids
117 |
118 |
119 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
120 | prob, indices = [], []
121 | cum_prob = 0.0
122 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
123 | for i in range(len(sorted_idx)):
124 |         # keep candidates until the top_p cumulative probability or the top_k count is reached.
125 | if cum_prob < top_p and len(prob) < top_k:
126 | cum_prob += sorted_value[i]
127 | prob.append(sorted_value[i])
128 | indices.append(sorted_idx[i])
129 | else:
130 | break
131 | prob = torch.tensor(prob).to(weighted_scores)
132 | indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
133 | top_ids = indices[prob.multinomial(1, replacement=True)]
134 | return top_ids
135 |
136 |
137 | def random_sampling(weighted_scores, decoded_tokens, sampling):
138 | top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
139 | return top_ids
140 |
141 |
142 | def fade_in_out(fade_in_mel, fade_out_mel, window):
143 | device = fade_in_mel.device
144 | fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
145 | mel_overlap_len = int(window.shape[0] / 2)
146 | if fade_in_mel.device == torch.device('cpu'):
147 | fade_in_mel = fade_in_mel.clone()
148 | fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
149 | fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
150 | return fade_in_mel.to(device)
151 |
152 |
153 | def set_all_random_seed(seed):
154 | random.seed(seed)
155 | np.random.seed(seed)
156 | torch.manual_seed(seed)
157 | torch.cuda.manual_seed_all(seed)
158 |
159 |
160 | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
161 | assert mask.dtype == torch.bool
162 | assert dtype in [torch.float32, torch.bfloat16, torch.float16]
163 | mask = mask.to(dtype)
164 | # attention mask bias
165 | # NOTE(Mddct): torch.finfo jit issues
166 | # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
167 | mask = (1.0 - mask) * -1.0e+10
168 | return mask
169 |
170 |
171 | class TrtContextWrapper:
172 | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
173 | self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
174 | self.trt_engine = trt_engine
175 | for _ in range(trt_concurrent):
176 | trt_context = trt_engine.create_execution_context()
177 | trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
178 |             assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reducing trt_concurrent (currently {})'.format(trt_concurrent)
179 | self.trt_context_pool.put([trt_context, trt_stream])
180 |         assert self.trt_context_pool.empty() is False, 'no available estimator context'
181 |
182 | def acquire_estimator(self):
183 | return self.trt_context_pool.get(), self.trt_engine
184 |
185 | def release_estimator(self, context, stream):
186 | self.trt_context_pool.put([context, stream])
187 |
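A minimal sketch of the padding and accuracy helpers on toy tensors.

    import torch
    from cosyvoice.utils.common import IGNORE_ID, pad_list, th_accuracy

    xs = [torch.ones(4, dtype=torch.long), torch.ones(2, dtype=torch.long)]
    padded = pad_list(xs, 0)              # (2, 4); the short sequence is zero-padded
    logits = torch.randn(2 * 4, 10)       # (B * Lmax, D)
    targets = torch.randint(0, 10, (2, 4))
    targets[1, 2:] = IGNORE_ID            # padded positions are excluded
    acc = th_accuracy(logits, targets, ignore_label=IGNORE_ID)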
--------------------------------------------------------------------------------
/cosyvoice/utils/executor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | from contextlib import nullcontext
18 | import os
19 |
20 | import torch
21 | import torch.distributed as dist
22 |
23 | from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
24 |
25 |
26 | class Executor:
27 |
28 | def __init__(self, gan: bool = False):
29 | self.gan = gan
30 | self.step = 0
31 | self.epoch = 0
32 | self.rank = int(os.environ.get('RANK', 0))
33 | self.device = torch.device('cuda:{}'.format(self.rank))
34 |
35 | def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join):
36 | ''' Train one epoch
37 | '''
38 |
39 | lr = optimizer.param_groups[0]['lr']
40 | logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
41 |         logging.info('using gradient accumulation, effective batch size is {} times'
42 | ' larger than before'.format(info_dict['accum_grad']))
43 | # A context manager to be used in conjunction with an instance of
44 | # torch.nn.parallel.DistributedDataParallel to be able to train
45 | # with uneven inputs across participating processes.
46 | model.train()
47 | model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
48 | with model_context():
49 | for batch_idx, batch_dict in enumerate(train_data_loader):
50 | info_dict["tag"] = "TRAIN"
51 | info_dict["step"] = self.step
52 | info_dict["epoch"] = self.epoch
53 | info_dict["batch_idx"] = batch_idx
54 | if cosyvoice_join(group_join, info_dict):
55 | break
56 |
57 | # Disable gradient synchronizations across DDP processes.
58 | # Within this context, gradients will be accumulated on module
59 | # variables, which will later be synchronized.
60 | if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
61 | context = model.no_sync
62 | # Used for single gpu training and DDP gradient synchronization
63 | # processes.
64 | else:
65 | context = nullcontext
66 |
67 | with context():
68 | info_dict = batch_forward(model, batch_dict, scaler, info_dict)
69 | info_dict = batch_backward(model, scaler, info_dict)
70 |
71 | info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
72 | log_per_step(writer, info_dict)
73 | # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
74 | if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
75 | (batch_idx + 1) % info_dict["accum_grad"] == 0:
76 | dist.barrier()
77 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
78 | model.train()
79 | if (batch_idx + 1) % info_dict["accum_grad"] == 0:
80 | self.step += 1
81 | dist.barrier()
82 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
83 |
84 | def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
85 | writer, info_dict, scaler, group_join):
86 | ''' Train one epoch
87 | '''
88 |
89 | lr = optimizer.param_groups[0]['lr']
90 | logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
91 |         logging.info('using gradient accumulation, effective batch size is {} times'
92 | ' larger than before'.format(info_dict['accum_grad']))
93 | # A context manager to be used in conjunction with an instance of
94 | # torch.nn.parallel.DistributedDataParallel to be able to train
95 | # with uneven inputs across participating processes.
96 | model.train()
97 | model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
98 | with model_context():
99 | for batch_idx, batch_dict in enumerate(train_data_loader):
100 | info_dict["tag"] = "TRAIN"
101 | info_dict["step"] = self.step
102 | info_dict["epoch"] = self.epoch
103 | info_dict["batch_idx"] = batch_idx
104 | if cosyvoice_join(group_join, info_dict):
105 | break
106 |
107 | # Disable gradient synchronizations across DDP processes.
108 | # Within this context, gradients will be accumulated on module
109 | # variables, which will later be synchronized.
110 | if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
111 | context = model.no_sync
112 | # Used for single gpu training and DDP gradient synchronization
113 | # processes.
114 | else:
115 | context = nullcontext
116 |
117 | with context():
118 | batch_dict['turn'] = 'discriminator'
119 | info_dict = batch_forward(model, batch_dict, scaler, info_dict)
120 | info_dict = batch_backward(model, scaler, info_dict)
121 | info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
122 | optimizer.zero_grad()
123 | log_per_step(writer, info_dict)
124 | with context():
125 | batch_dict['turn'] = 'generator'
126 | info_dict = batch_forward(model, batch_dict, scaler, info_dict)
127 | info_dict = batch_backward(model, scaler, info_dict)
128 | info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
129 | optimizer_d.zero_grad()
130 | log_per_step(writer, info_dict)
131 | # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
132 | if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
133 | (batch_idx + 1) % info_dict["accum_grad"] == 0:
134 | dist.barrier()
135 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
136 | model.train()
137 | if (batch_idx + 1) % info_dict["accum_grad"] == 0:
138 | self.step += 1
139 | dist.barrier()
140 | self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
141 |
142 | @torch.inference_mode()
143 | def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
144 |         ''' Cross validation
145 | '''
146 | logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
147 | model.eval()
148 | total_num_utts, total_loss_dict = 0, {} # avoid division by 0
149 | for batch_idx, batch_dict in enumerate(cv_data_loader):
150 | info_dict["tag"] = "CV"
151 | info_dict["step"] = self.step
152 | info_dict["epoch"] = self.epoch
153 | info_dict["batch_idx"] = batch_idx
154 |
155 | num_utts = len(batch_dict["utts"])
156 | total_num_utts += num_utts
157 |
158 | if self.gan is True:
159 | batch_dict['turn'] = 'generator'
160 | info_dict = batch_forward(model, batch_dict, None, info_dict)
161 |
162 | for k, v in info_dict['loss_dict'].items():
163 | if k not in total_loss_dict:
164 | total_loss_dict[k] = []
165 | total_loss_dict[k].append(v.item() * num_utts)
166 | log_per_step(None, info_dict)
167 | for k, v in total_loss_dict.items():
168 | total_loss_dict[k] = sum(v) / total_num_utts
169 | info_dict['loss_dict'] = total_loss_dict
170 | log_per_save(writer, info_dict)
171 | model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
172 | save_model(model, model_name, info_dict)
173 |
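A sketch of how a training script typically drives the Executor per epoch; `model`, `optimizer`, `scheduler`, the data loaders, `writer`, `info_dict`, `scaler`, `group_join` and `max_epoch` are assumed to be prepared by the launcher (see cosyvoice/bin/train.py) and are not defined here.

    executor = Executor(gan=False)
    for epoch in range(max_epoch):
        executor.epoch = epoch
        executor.train_one_epoc(model, optimizer, scheduler, train_data_loader,
                                cv_data_loader, writer, info_dict, scaler, group_join)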
--------------------------------------------------------------------------------
/cosyvoice/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
2 | # 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
3 | # 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import json
19 | import torch
20 | import torchaudio
21 | import logging
22 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
23 | logging.basicConfig(level=logging.DEBUG,
24 | format='%(asctime)s %(levelname)s %(message)s')
25 |
26 |
27 | def read_lists(list_file):
28 | lists = []
29 | with open(list_file, 'r', encoding='utf8') as fin:
30 | for line in fin:
31 | lists.append(line.strip())
32 | return lists
33 |
34 |
35 | def read_json_lists(list_file):
36 | lists = read_lists(list_file)
37 | results = {}
38 | for fn in lists:
39 | with open(fn, 'r', encoding='utf8') as fin:
40 | results.update(json.load(fin))
41 | return results
42 |
43 |
44 | def load_wav(wav, target_sr):
45 | speech, sample_rate = torchaudio.load(wav, backend='soundfile')
46 | speech = speech.mean(dim=0, keepdim=True)
47 | if sample_rate != target_sr:
48 | assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
49 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
50 | return speech
51 |
52 |
53 | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
54 | import tensorrt as trt
55 | logging.info("Converting onnx to trt...")
56 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
57 | logger = trt.Logger(trt.Logger.INFO)
58 | builder = trt.Builder(logger)
59 | network = builder.create_network(network_flags)
60 | parser = trt.OnnxParser(network, logger)
61 | config = builder.create_builder_config()
62 | config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
63 | if fp16:
64 | config.set_flag(trt.BuilderFlag.FP16)
65 | profile = builder.create_optimization_profile()
66 | # load onnx model
67 | with open(onnx_model, "rb") as f:
68 | if not parser.parse(f.read()):
69 | for error in range(parser.num_errors):
70 | print(parser.get_error(error))
71 | raise ValueError('failed to parse {}'.format(onnx_model))
72 | # set input shapes
73 | for i in range(len(trt_kwargs['input_names'])):
74 | profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
75 | tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
76 | # set input and output data type
77 | for i in range(network.num_inputs):
78 | input_tensor = network.get_input(i)
79 | input_tensor.dtype = tensor_dtype
80 | for i in range(network.num_outputs):
81 | output_tensor = network.get_output(i)
82 | output_tensor.dtype = tensor_dtype
83 | config.add_optimization_profile(profile)
84 | engine_bytes = builder.build_serialized_network(network, config)
85 | # save trt engine
86 | with open(trt_model, "wb") as f:
87 | f.write(engine_bytes)
88 | logging.info("Succesfully convert onnx to trt...")
89 |
90 |
91 | def export_cosyvoice2_vllm(model, model_path, device):
92 | if os.path.exists(model_path):
93 | return
94 | pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64
95 | vocab_size = model.speech_embedding.num_embeddings
96 | feature_size = model.speech_embedding.embedding_dim
97 | pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to
98 |
99 | dtype = torch.bfloat16
100 | # lm_head
101 | new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True)
102 | with torch.no_grad():
103 | new_lm_head.weight[:vocab_size] = model.llm_decoder.weight
104 | new_lm_head.bias[:vocab_size] = model.llm_decoder.bias
105 | new_lm_head.weight[vocab_size:] = 0
106 | new_lm_head.bias[vocab_size:] = 0
107 | model.llm.model.lm_head = new_lm_head
108 | new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size)
109 | # embed_tokens
110 | embed_tokens = model.llm.model.model.embed_tokens
111 | with torch.no_grad():
112 | new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight
113 | new_codec_embed.weight[vocab_size:] = 0
114 | model.llm.model.set_input_embeddings(new_codec_embed)
115 | model.llm.model.to(device)
116 | model.llm.model.to(dtype)
117 | tmp_vocab_size = model.llm.model.config.vocab_size
118 | tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
119 | del model.llm.model.generation_config.eos_token_id
120 | del model.llm.model.config.bos_token_id
121 | del model.llm.model.config.eos_token_id
122 | model.llm.model.config.vocab_size = pad_vocab_size
123 | model.llm.model.config.tie_word_embeddings = False
124 | model.llm.model.config.use_bias = True
125 | model.llm.model.save_pretrained(model_path)
126 | os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
127 | model.llm.model.config.vocab_size = tmp_vocab_size
128 | model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
129 | model.llm.model.set_input_embeddings(embed_tokens)
130 |
--------------------------------------------------------------------------------
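A brief usage sketch of the helpers above (not part of the repo): load_wav resamples a prompt recording, and convert_onnx_to_trt expects a trt_kwargs dict with parallel input_names/min_shape/opt_shape/max_shape lists, which is how the loop over input_names indexes it. The engine/ONNX paths and the single dynamic input below are hypothetical placeholders.

import torch
from cosyvoice.utils.file_utils import load_wav, convert_onnx_to_trt

# mono prompt speech at the 16 kHz rate expected by the speech tokenizer
speech = load_wav('./asset/zero_shot_prompt.wav', 16000)
print(speech.shape)  # torch.Size([1, num_samples])

# hypothetical kwargs for one ONNX input named 'x' whose last axis is dynamic
trt_kwargs = {
    'input_names': ['x'],
    'min_shape': [(2, 80, 4)],
    'opt_shape': [(2, 80, 200)],
    'max_shape': [(2, 80, 3000)],
}
if torch.cuda.is_available():
    # builds and serializes a TensorRT engine next to the ONNX export
    convert_onnx_to_trt('estimator.fp16.plan', trt_kwargs, 'estimator.fp32.onnx', fp16=True)
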
/cosyvoice/utils/frontend_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 | import regex
17 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
18 |
19 |
20 | # whether the text contains Chinese characters
21 | def contains_chinese(text):
22 | return bool(chinese_char_pattern.search(text))
23 |
24 |
25 | # replace special symbols
26 | def replace_corner_mark(text):
27 | text = text.replace('²', '平方')
28 | text = text.replace('³', '立方')
29 | return text
30 |
31 |
32 | # remove meaningless symbols
33 | def remove_bracket(text):
34 | text = text.replace('(', '').replace(')', '')
35 | text = text.replace('【', '').replace('】', '')
36 | text = text.replace('`', '').replace('`', '')
37 | text = text.replace("——", " ")
38 | return text
39 |
40 |
41 | # spell out Arabic numerals
42 | def spell_out_number(text: str, inflect_parser):
43 | new_text = []
44 | st = None
45 | for i, c in enumerate(text):
46 | if not c.isdigit():
47 | if st is not None:
48 | num_str = inflect_parser.number_to_words(text[st: i])
49 | new_text.append(num_str)
50 | st = None
51 | new_text.append(c)
52 | else:
53 | if st is None:
54 | st = i
55 | if st is not None and st < len(text):
56 | num_str = inflect_parser.number_to_words(text[st:])
57 | new_text.append(num_str)
58 | return ''.join(new_text)
59 |
60 |
61 | # split paragraph logic:
62 | # 1. per sentence: max length token_max_n, min length token_min_n; merge the last sentence into the previous one if it is shorter than merge_len
63 | # 2. calculate sentence length according to lang
64 | # 3. split sentences at punctuation
65 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
66 | def calc_utt_length(_text: str):
67 | if lang == "zh":
68 | return len(_text)
69 | else:
70 | return len(tokenize(_text))
71 |
72 | def should_merge(_text: str):
73 | if lang == "zh":
74 | return len(_text) < merge_len
75 | else:
76 | return len(tokenize(_text)) < merge_len
77 |
78 | if lang == "zh":
79 | pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
80 | else:
81 | pounc = ['.', '?', '!', ';', ':']
82 | if comma_split:
83 | pounc.extend([',', ','])
84 |
85 | if text[-1] not in pounc:
86 | if lang == "zh":
87 | text += "。"
88 | else:
89 | text += "."
90 |
91 | st = 0
92 | utts = []
93 | for i, c in enumerate(text):
94 | if c in pounc:
95 | if len(text[st: i]) > 0:
96 | utts.append(text[st: i] + c)
97 | if i + 1 < len(text) and text[i + 1] in ['"', '”']:
98 | tmp = utts.pop(-1)
99 | utts.append(tmp + text[i + 1])
100 | st = i + 2
101 | else:
102 | st = i + 1
103 |
104 | final_utts = []
105 | cur_utt = ""
106 | for utt in utts:
107 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
108 | final_utts.append(cur_utt)
109 | cur_utt = ""
110 | cur_utt = cur_utt + utt
111 | if len(cur_utt) > 0:
112 | if should_merge(cur_utt) and len(final_utts) != 0:
113 | final_utts[-1] = final_utts[-1] + cur_utt
114 | else:
115 | final_utts.append(cur_utt)
116 |
117 | return final_utts
118 |
119 |
120 | # remove blanks between Chinese characters (spaces between ASCII words are kept)
121 | def replace_blank(text: str):
122 | out_str = []
123 | for i, c in enumerate(text):
124 | if c == " ":
125 | if ((text[i + 1].isascii() and text[i + 1] != " ") and
126 | (text[i - 1].isascii() and text[i - 1] != " ")):
127 | out_str.append(c)
128 | else:
129 | out_str.append(c)
130 | return "".join(out_str)
131 |
132 |
133 | def is_only_punctuation(text):
134 | # Regular expression: Match strings that consist only of punctuation marks or are empty.
135 | punctuation_pattern = r'^[\p{P}\p{S}]*$'
136 | return bool(regex.fullmatch(punctuation_pattern, text))
137 |
--------------------------------------------------------------------------------
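A self-contained sketch of how these text-frontend helpers might be chained for an English input; the whisper tokenizer mirrors the get_tokenizer entry in the libritts yaml, and the sample sentence is made up for illustration.

import inflect
from whisper.tokenizer import get_tokenizer
from cosyvoice.utils.frontend_utils import (contains_chinese, spell_out_number,
                                            split_paragraph, replace_blank)

inflect_parser = inflect.engine()
tokenizer = get_tokenizer(multilingual=True, num_languages=100, language='en', task='transcribe')

text = ("CosyVoice was released in 2024. It supports zero-shot voice cloning, "
        "cross-lingual synthesis and fine-grained instruction following.")
assert not contains_chinese(text)
text = spell_out_number(text, inflect_parser)  # digits -> words, e.g. "2024" is spelled out

# split into synthesis-sized sentences, measuring length in whisper tokens for English
sentences = split_paragraph(text,
                            tokenize=lambda s: tokenizer.encode(s, allowed_special='all'),
                            lang="en", token_max_n=80, token_min_n=60, merge_len=20)
for sentence in sentences:
    print(replace_blank(sentence))
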
/cosyvoice/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
6 | loss = 0
7 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
8 | m_DG = torch.median((dr - dg))
9 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
10 | loss += tau - F.relu(tau - L_rel)
11 | return loss
12 |
13 |
14 | def mel_loss(real_speech, generated_speech, mel_transforms):
15 | loss = 0
16 | for transform in mel_transforms:
17 | mel_r = transform(real_speech)
18 | mel_g = transform(generated_speech)
19 | loss += F.l1_loss(mel_g, mel_r)
20 | return loss
21 |
--------------------------------------------------------------------------------
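A hedged, self-contained sketch of these two losses on random tensors. In the recipes the mel transforms come from matcha.utils.audio.mel_spectrogram via the yaml configs; the torchaudio transforms here are stand-ins chosen only so the example runs on its own, and tau=0.04 is just an illustrative value.

import torch
import torchaudio
from cosyvoice.utils.losses import mel_loss, tpr_loss

# stand-in mel transforms at two resolutions (the yaml wires in matcha's mel_spectrogram instead)
mel_transforms = [
    torchaudio.transforms.MelSpectrogram(sample_rate=24000, n_fft=1024, hop_length=256, n_mels=80),
    torchaudio.transforms.MelSpectrogram(sample_rate=24000, n_fft=1920, hop_length=480, n_mels=80),
]
real = torch.randn(2, 24000)   # (batch, samples) real speech
fake = torch.randn(2, 24000)   # generator output of the same shape
print('mel loss:', mel_loss(real, fake, mel_transforms).item())

# tpr_loss takes per-discriminator score tensors for real and generated audio
disc_real = [torch.randn(2, 100), torch.randn(2, 50)]
disc_fake = [torch.randn(2, 100), torch.randn(2, 50)]
print('tpr loss:', tpr_loss(disc_real, disc_fake, tau=0.04).item())
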
/cosyvoice/utils/losses_dpo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from typing import Tuple
4 |
5 |
6 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
7 | loss = 0
8 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
9 | m_DG = torch.median((dr - dg))
10 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
11 | loss += tau - F.relu(tau - L_rel)
12 | return loss
13 |
14 |
15 | def mel_loss(real_speech, generated_speech, mel_transforms):
16 | loss = 0
17 | for transform in mel_transforms:
18 | mel_r = transform(real_speech)
19 | mel_g = transform(generated_speech)
20 | loss += F.l1_loss(mel_g, mel_r)
21 | return loss
22 |
23 |
24 | class DPOLoss(torch.nn.Module):
25 | """
26 | DPO Loss
27 | """
28 |
29 | def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
30 | super().__init__()
31 | self.beta = beta
32 | self.label_smoothing = label_smoothing
33 | self.ipo = ipo
34 |
35 | def forward(
36 | self,
37 | policy_chosen_logps: torch.Tensor,
38 | policy_rejected_logps: torch.Tensor,
39 | reference_chosen_logps: torch.Tensor,
40 | reference_rejected_logps: torch.Tensor,
41 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
42 | pi_logratios = policy_chosen_logps - policy_rejected_logps
43 | ref_logratios = reference_chosen_logps - reference_rejected_logps
44 | logits = pi_logratios - ref_logratios
45 | if self.ipo:
46 | losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
47 | else:
48 | # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
49 | losses = (
50 | -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
51 | - F.logsigmoid(-self.beta * logits) * self.label_smoothing
52 | )
53 | loss = losses.mean()
54 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
55 | rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
56 |
57 | return loss, chosen_rewards, rejected_rewards
58 |
--------------------------------------------------------------------------------
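A minimal sketch of DPOLoss on dummy sequence-level log-probabilities, just to make the expected shapes and outputs concrete; beta=0.01 and the batch size are illustrative values, not ones fixed by this file.

import torch
from cosyvoice.utils.losses_dpo import DPOLoss

dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)

batch = 4
# summed log-probs of chosen/rejected speech-token sequences under the policy and the frozen reference
policy_chosen = torch.randn(batch)
policy_rejected = torch.randn(batch)
reference_chosen = torch.randn(batch)
reference_rejected = torch.randn(batch)

loss, chosen_rewards, rejected_rewards = dpo_loss(
    policy_chosen, policy_rejected, reference_chosen, reference_rejected)
print(loss.item(), chosen_rewards.shape, rejected_rewards.shape)  # scalar, (4,), (4,)
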
/cosyvoice/vllm/cosyvoice2.py:
--------------------------------------------------------------------------------
1 | # SPDX-License-Identifier: Apache-2.0
2 |
3 | # Adapted from
4 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
5 | # Copyright 2024 The Qwen team.
6 | # Copyright 2023 The vLLM team.
7 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
8 | #
9 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10 | # and OPT implementations in this library. It has been modified from its
11 | # original forms to accommodate minor architectural differences compared
12 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13 | #
14 | # Licensed under the Apache License, Version 2.0 (the "License");
15 | # you may not use this file except in compliance with the License.
16 | # You may obtain a copy of the License at
17 | #
18 | # http://www.apache.org/licenses/LICENSE-2.0
19 | #
20 | # Unless required by applicable law or agreed to in writing, software
21 | # distributed under the License is distributed on an "AS IS" BASIS,
22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23 | # See the License for the specific language governing permissions and
24 | # limitations under the License.
25 | """Inference-only Qwen2 model compatible with HuggingFace weights."""
26 | from vllm.model_executor.models.qwen2 import *
27 |
28 |
29 | class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
30 | packed_modules_mapping = {
31 | "qkv_proj": [
32 | "q_proj",
33 | "k_proj",
34 | "v_proj",
35 | ],
36 | "gate_up_proj": [
37 | "gate_proj",
38 | "up_proj",
39 | ],
40 | }
41 |
42 | def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
43 | super().__init__()
44 | config = vllm_config.model_config.hf_config
45 | quant_config = vllm_config.quant_config
46 | lora_config = vllm_config.lora_config
47 |
48 | self.config = config
49 | self.lora_config = lora_config
50 |
51 | self.quant_config = quant_config
52 | self.model = Qwen2Model(vllm_config=vllm_config,
53 | prefix=maybe_prefix(prefix, "model"))
54 |
55 | if get_pp_group().is_last_rank:
56 | if config.tie_word_embeddings:
57 | self.lm_head = self.model.embed_tokens
58 | else:
59 | self.lm_head = ParallelLMHead(config.vocab_size,
60 | config.hidden_size,
61 | True,
62 | quant_config=quant_config,
63 | prefix=maybe_prefix(
64 | prefix, "lm_head"))
65 | else:
66 | self.lm_head = PPMissingLayer()
67 |
68 | self.logits_processor = LogitsProcessor(config.vocab_size)
69 |
70 | self.make_empty_intermediate_tensors = (
71 | self.model.make_empty_intermediate_tensors)
72 |
73 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
74 | return self.model.get_input_embeddings(input_ids)
75 |
76 | def forward(
77 | self,
78 | input_ids: torch.Tensor,
79 | positions: torch.Tensor,
80 | intermediate_tensors: Optional[IntermediateTensors] = None,
81 | inputs_embeds: Optional[torch.Tensor] = None,
82 | ) -> Union[torch.Tensor, IntermediateTensors]:
83 | hidden_states = self.model(input_ids, positions, intermediate_tensors,
84 | inputs_embeds)
85 | return hidden_states
86 |
87 | def compute_logits(
88 | self,
89 | hidden_states: torch.Tensor,
90 | sampling_metadata: SamplingMetadata,
91 | ) -> Optional[torch.Tensor]:
92 | logits = self.logits_processor(self.lm_head, hidden_states,
93 | sampling_metadata, self.lm_head.bias)
94 | return logits
95 |
96 | def load_weights(self, weights: Iterable[tuple[str,
97 | torch.Tensor]]) -> set[str]:
98 | loader = AutoWeightsLoader(
99 | self,
100 | skip_prefixes=(["lm_head."]
101 | if self.config.tie_word_embeddings else None),
102 | )
103 | return loader.load_weights(weights)
104 |
--------------------------------------------------------------------------------
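vLLM only instantiates this class if the CosyVoice2ForCausalLM architecture name (the one export_cosyvoice2_vllm writes into config.json) is registered with it; a plausible sketch of that registration, which may need adjusting across vLLM versions, is:

from vllm import ModelRegistry
from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM

# map the architecture string in the exported config.json to this implementation
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
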
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
2 |
3 | ARG VENV_NAME="cosyvoice"
4 | ENV VENV=$VENV_NAME
5 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
6 |
7 | ENV DEBIAN_FRONTEND=noninteractive
8 | ENV PYTHONUNBUFFERED=1
9 | SHELL ["/bin/bash", "--login", "-c"]
10 |
11 | RUN apt-get update -y --fix-missing
12 | RUN apt-get install -y git build-essential curl wget ffmpeg unzip git-lfs sox libsox-dev && \
13 | apt-get clean && \
14 | git lfs install
15 |
16 | # ==================================================================
17 | # conda install and conda forge channel as default
18 | # ------------------------------------------------------------------
19 | # Install miniforge
20 | RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \
21 | /bin/bash ~/miniforge.sh -b -p /opt/conda && \
22 | rm ~/miniforge.sh && \
23 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
24 | echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \
25 | echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
26 | echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \
27 | echo "conda activate ${VENV}" >> $HOME/.bashrc
28 |
29 | ENV PATH /opt/conda/bin:$PATH
30 |
31 | RUN conda config --add channels conda-forge && \
32 | conda config --set channel_priority strict
33 | # ------------------------------------------------------------------
34 | # ~conda
35 | # ==================================================================
36 |
37 | RUN conda create -y -n ${VENV} python=3.10
38 | ENV CONDA_DEFAULT_ENV=${VENV}
39 | ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH
40 |
41 | WORKDIR /workspace
42 |
43 | ENV PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_party/Matcha-TTS"
44 |
45 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
46 |
47 | RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
48 | RUN conda activate ${VENV} && cd CosyVoice && \
49 | pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
50 |
51 | WORKDIR /workspace/CosyVoice
52 |
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/conf/cosyvoice_dpo.yaml:
--------------------------------------------------------------------------------
1 | # set random seed, so that you may reproduce your result.
2 | __set_seed1: !apply:random.seed [1986]
3 | __set_seed2: !apply:numpy.random.seed [1986]
4 | __set_seed3: !apply:torch.manual_seed [1986]
5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1986]
6 |
7 | # fixed params
8 | sample_rate: 24000 # 16000 for llm, 24000 for cfm
9 | llm_input_size: 896
10 | llm_output_size: 896
11 | spk_embed_dim: 192
12 | qwen_pretrain_path: 'CosyVoice2-0.5B/CosyVoice-BlankEN'
13 |
14 | # model params
15 | # for all classes/functions included in this repo, we use !new: or !name: for initialization, so that users can find every corresponding class/function from this single yaml.
16 | # for system/third_party classes/functions, we do not require this.
17 | llm: !new:cosyvoice.llm.llm_dpo.Qwen2LM
18 | llm_input_size: !ref <llm_input_size>
19 | llm_output_size: !ref <llm_output_size>
20 | speech_token_size: 6561
21 | length_normalized_loss: True
22 | lsm_weight: 0
23 | dpo: True
24 | llm: !new:cosyvoice.llm.llm.Qwen2Encoder
25 | pretrain_path: !ref <qwen_pretrain_path>
26 | sampling: !name:cosyvoice.utils.common.ras_sampling
27 | top_p: 0.8
28 | top_k: 25
29 | win_size: 10
30 | tau_r: 0.1
31 | flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
32 | input_size: 512
33 | output_size: 80
34 | spk_embed_dim: !ref <spk_embed_dim>
35 | output_type: 'mel'
36 | vocab_size: 6561
37 | input_frame_rate: 25
38 | only_mask_loss: True
39 | token_mel_ratio: 2
40 | pre_lookahead_len: 3
41 | encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
42 | output_size: 512
43 | attention_heads: 8
44 | linear_units: 2048
45 | num_blocks: 6
46 | dropout_rate: 0.1
47 | positional_dropout_rate: 0.1
48 | attention_dropout_rate: 0.1
49 | normalize_before: True
50 | input_layer: 'linear'
51 | pos_enc_layer_type: 'rel_pos_espnet'
52 | selfattention_layer_type: 'rel_selfattn'
53 | input_size: 512
54 | use_cnn_module: False
55 | macaron_style: False
56 | decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
57 | in_channels: 240
58 | n_spks: 1
59 | spk_emb_dim: 80
60 | cfm_params: !new:omegaconf.DictConfig
61 | content:
62 | sigma_min: 1e-06
63 | solver: 'euler'
64 | t_scheduler: 'cosine'
65 | training_cfg_rate: 0.2
66 | inference_cfg_rate: 0.7
67 | reg_loss_type: 'l1'
68 | estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
69 | in_channels: 320
70 | out_channels: 80
71 | causal: True
72 | channels: [256]
73 | dropout: 0.0
74 | attention_head_dim: 64
75 | n_blocks: 4
76 | num_mid_blocks: 12
77 | num_heads: 8
78 | act_fn: 'gelu'
79 |
80 | hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
81 | in_channels: 80
82 | base_channels: 512
83 | nb_harmonics: 8
84 | sampling_rate: !ref <sample_rate>
85 | nsf_alpha: 0.1
86 | nsf_sigma: 0.003
87 | nsf_voiced_threshold: 10
88 | upsample_rates: [8, 5, 3]
89 | upsample_kernel_sizes: [16, 11, 7]
90 | istft_params:
91 | n_fft: 16
92 | hop_len: 4
93 | resblock_kernel_sizes: [3, 7, 11]
94 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
95 | source_resblock_kernel_sizes: [7, 7, 11]
96 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
97 | lrelu_slope: 0.1
98 | audio_limit: 0.99
99 | f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
100 | num_class: 1
101 | in_channels: 80
102 | cond_channels: 512
103 |
104 | # gan related module
105 | mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
106 | n_fft: 1024
107 | num_mels: 80
108 | sampling_rate: !ref <sample_rate>
109 | hop_size: 256
110 | win_size: 1024
111 | fmin: 0
112 | fmax: null
113 | center: False
114 | hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
115 | generator: !ref <hift>
116 | discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
117 | mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
118 | mrd: !new:cosyvoice.hifigan.discriminator.MultiResolutionDiscriminator
119 | mel_spec_transform: [
120 | !ref <mel_spec_transform1>
121 | ]
122 |
123 | # processor functions
124 | parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
125 | get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
126 | multilingual: True
127 | num_languages: 100
128 | language: 'en'
129 | task: 'transcribe'
130 | allowed_special: 'all'
131 | tokenize: !name:cosyvoice.dataset.processor.tokenize
132 | get_tokenizer: !ref <get_tokenizer>
133 | allowed_special: !ref <allowed_special>
134 | filter: !name:cosyvoice.dataset.processor.filter
135 | max_length: 40960
136 | min_length: 0
137 | token_max_length: 200
138 | token_min_length: 1
139 | resample: !name:cosyvoice.dataset.processor.resample
140 | resample_rate: !ref <sample_rate>
141 | truncate: !name:cosyvoice.dataset.processor.truncate
142 | truncate_length: 24576 # must be a multiple of hop_size
143 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram
144 | n_fft: 1024
145 | num_mels: 80
146 | sampling_rate: !ref <sample_rate>
147 | hop_size: 256
148 | win_size: 1024
149 | fmin: 0
150 | fmax: 8000
151 | center: False
152 | compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
153 | feat_extractor: !ref <feat_extractor>
154 | compute_f0: !name:cosyvoice.dataset.processor.compute_f0
155 | sample_rate: !ref <sample_rate>
156 | hop_size: 256
157 | parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
158 | normalize: True
159 | shuffle: !name:cosyvoice.dataset.processor.shuffle
160 | shuffle_size: 1000
161 | sort: !name:cosyvoice.dataset.processor.sort
162 | sort_size: 500 # sort_size should be less than shuffle_size
163 | batch: !name:cosyvoice.dataset.processor.batch
164 | batch_type: 'dynamic'
165 | max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
166 | padding: !name:cosyvoice.dataset.processor.padding
167 | use_spk_embedding: True # change to True during sft
168 | dpo: True
169 |
170 | # dataset processor pipeline
171 | data_pipeline: [
172 | !ref <parquet_opener>,
173 | !ref <tokenize>,
174 | !ref <filter>,
175 | !ref <resample>,
176 | !ref <compute_fbank>,
177 | !ref <parse_embedding>,
178 | !ref <shuffle>,
179 | !ref <sort>,
180 | !ref <batch>,
181 | !ref <padding>,
182 | ]
183 | data_pipeline_gan: [
184 | !ref <parquet_opener>,
185 | !ref <tokenize>,
186 | !ref <filter>,
187 | !ref <resample>,
188 | !ref <truncate>,
189 | !ref <compute_fbank>,
190 | !ref <compute_f0>,
191 | !ref <parse_embedding>,
192 | !ref <shuffle>,
193 | !ref <sort>,
194 | !ref <batch>,
195 | !ref <padding>,
196 | ]
197 |
198 | # llm flow train conf
199 | train_conf:
200 | optim: adam
201 | optim_conf:
202 | lr: 0.00001 # change to 1e-5 during sft
203 | scheduler: warmuplr # change to constantlr during sft
204 | scheduler_conf:
205 | warmup_steps: 25000
206 | max_epoch: 200
207 | grad_clip: 5
208 | accum_grad: 2
209 | log_interval: 100
210 | save_per_step: -1
211 |
212 | # gan train conf
213 | train_conf_gan:
214 | optim: adam
215 | optim_conf:
216 | lr: 0.0002 # use small lr for gan training
217 | scheduler: constantlr
218 | optim_d: adam
219 | optim_conf_d:
220 | lr: 0.0002 # use small lr for gan training
221 | scheduler_d: constantlr
222 | max_epoch: 200
223 | grad_clip: 5
224 | accum_grad: 1 # in gan training, accum_grad must be 1
225 | log_interval: 100
226 | save_per_step: -1
--------------------------------------------------------------------------------
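The !new:/!name:/!ref tags above are hyperpyyaml syntax, so the whole recipe can be instantiated in one call. A rough sketch, assuming the hyperpyyaml package and a local CosyVoice2-0.5B download providing the CosyVoice-BlankEN weights:

from hyperpyyaml import load_hyperpyyaml

with open('conf/cosyvoice_dpo.yaml') as f:
    configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': 'CosyVoice2-0.5B/CosyVoice-BlankEN'})

llm = configs['llm']      # Qwen2LM with dpo=True
flow = configs['flow']    # CausalMaskedDiffWithXvec
hift = configs['hift']    # HiFTGenerator vocoder
print(type(llm).__name__, type(flow).__name__, type(hift).__name__)
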
/examples/libritts/cosyvoice/conf/ds_stage2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 1,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 5,
6 | "fp16": {
7 | "enabled": false,
8 | "auto_cast": false,
9 | "loss_scale": 0,
10 | "initial_scale_power": 16,
11 | "loss_scale_window": 256,
12 | "hysteresis": 2,
13 | "consecutive_hysteresis": false,
14 | "min_loss_scale": 1
15 | },
16 | "bf16": {
17 | "enabled": false
18 | },
19 | "zero_force_ds_cpu_optimizer": false,
20 | "zero_optimization": {
21 | "stage": 2,
22 | "offload_optimizer": {
23 | "device": "none",
24 | "pin_memory": true
25 | },
26 | "allgather_partitions": true,
27 | "allgather_bucket_size": 5e8,
28 | "overlap_comm": false,
29 | "reduce_scatter": true,
30 | "reduce_bucket_size": 5e8,
31 | "contiguous_gradients" : true
32 | },
33 | "optimizer": {
34 | "type": "AdamW",
35 | "params": {
36 | "lr": 0.001,
37 | "weight_decay": 0.0001,
38 | "torch_adam": true,
39 | "adam_w_mode": true
40 | }
41 | }
42 | }
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/cosyvoice:
--------------------------------------------------------------------------------
1 | ../../../cosyvoice
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/local/download_and_untar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 | # Apache 2.0
5 |
6 | remove_archive=false
7 |
8 | if [ "$1" == --remove-archive ]; then
9 | remove_archive=true
10 | shift
11 | fi
12 |
13 | if [ $# -ne 3 ]; then
14 | echo "Usage: $0 [--remove-archive] "
15 | echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
16 | echo "With --remove-archive it will remove the archive after successfully un-tarring it."
17 | echo " can be one of: dev-clean, test-clean, dev-other, test-other,"
18 | echo " train-clean-100, train-clean-360, train-other-500."
19 | exit 1
20 | fi
21 |
22 | data=$1
23 | url=$2
24 | part=$3
25 |
26 | if [ ! -d "$data" ]; then
27 | echo "$0: no such directory $data"
28 | exit 1
29 | fi
30 |
31 | part_ok=false
32 | list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
33 | for x in $list; do
34 | if [ "$part" == $x ]; then part_ok=true; fi
35 | done
36 | if ! $part_ok; then
37 | echo "$0: expected to be one of $list, but got '$part'"
38 | exit 1
39 | fi
40 |
41 | if [ -z "$url" ]; then
42 | echo "$0: empty URL base."
43 | exit 1
44 | fi
45 |
46 | if [ -f $data/LibriTTS/$part/.complete ]; then
47 | echo "$0: data part $part was already successfully extracted, nothing to do."
48 | exit 0
49 | fi
50 |
51 |
52 | # sizes of the archive files in bytes. These are from some older versions.
53 | sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
54 | # sizes_new is the archive file sizes of the final release. Some of these sizes are of
55 | # things we probably won't download.
56 | sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
57 |
58 | if [ -f $data/$part.tar.gz ]; then
59 | size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
60 | size_ok=false
61 | for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
62 | if ! $size_ok; then
63 | echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
64 | echo "does not equal the size of one of the archives."
65 | rm $data/$part.tar.gz
66 | else
67 | echo "$data/$part.tar.gz exists and appears to be complete."
68 | fi
69 | fi
70 |
71 | if [ ! -f $data/$part.tar.gz ]; then
72 | if ! which wget >/dev/null; then
73 | echo "$0: wget is not installed."
74 | exit 1
75 | fi
76 | full_url=$url/$part.tar.gz
77 | echo "$0: downloading data from $full_url. This may take some time, please be patient."
78 |
79 | if ! wget -P $data --no-check-certificate $full_url; then
80 | echo "$0: error executing wget $full_url"
81 | exit 1
82 | fi
83 | fi
84 |
85 | if ! tar -C $data -xvzf $data/$part.tar.gz; then
86 | echo "$0: error un-tarring archive $data/$part.tar.gz"
87 | exit 1
88 | fi
89 |
90 | touch $data/LibriTTS/$part/.complete
91 |
92 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
93 |
94 | if $remove_archive; then
95 | echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
96 | rm $data/$part.tar.gz
97 | fi
98 |
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/local/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import glob
4 | import os
5 | from tqdm import tqdm
6 |
7 |
8 | logger = logging.getLogger()
9 |
10 |
11 | def main():
12 | wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
13 |
14 | utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
15 | for wav in tqdm(wavs):
16 | txt = wav.replace('.wav', '.normalized.txt')
17 | if not os.path.exists(txt):
18 | logger.warning('{} does not exist'.format(txt))
19 | continue
20 | with open(txt) as f:
21 | content = ''.join(l.replace('\n', '') for l in f.readlines())
22 | utt = os.path.basename(wav).replace('.wav', '')
23 | spk = utt.split('_')[0]
24 | utt2wav[utt] = wav
25 | utt2text[utt] = content
26 | utt2spk[utt] = spk
27 | if spk not in spk2utt:
28 | spk2utt[spk] = []
29 | spk2utt[spk].append(utt)
30 |
31 | with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
32 | for k, v in utt2wav.items():
33 | f.write('{} {}\n'.format(k, v))
34 | with open('{}/text'.format(args.des_dir), 'w') as f:
35 | for k, v in utt2text.items():
36 | f.write('{} {}\n'.format(k, v))
37 | with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
38 | for k, v in utt2spk.items():
39 | f.write('{} {}\n'.format(k, v))
40 | with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
41 | for k, v in spk2utt.items():
42 | f.write('{} {}\n'.format(k, ' '.join(v)))
43 | return
44 |
45 |
46 | if __name__ == "__main__":
47 | parser = argparse.ArgumentParser()
48 | parser.add_argument('--src_dir',
49 | type=str)
50 | parser.add_argument('--des_dir',
51 | type=str)
52 | args = parser.parse_args()
53 | main()
54 |
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/path.sh:
--------------------------------------------------------------------------------
1 | # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
2 | export PYTHONIOENCODING=UTF-8
3 | export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
4 |
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2024 Alibaba Inc. All Rights Reserved.
3 | . ./path.sh || exit 1;
4 |
5 | stage=-1
6 | stop_stage=3
7 |
8 | data_url=www.openslr.org/resources/60
9 | data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
10 | pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
11 |
12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
13 | echo "Data Download"
14 | for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
15 | local/download_and_untar.sh ${data_dir} ${data_url} ${part}
16 | done
17 | fi
18 |
19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
20 | echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
21 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
22 | mkdir -p data/$x
23 | python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
24 | done
25 | fi
26 |
27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
28 | echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
29 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
30 | tools/extract_embedding.py --dir data/$x \
31 | --onnx_path $pretrained_model_dir/campplus.onnx
32 | done
33 | fi
34 |
35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
36 | echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
37 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
38 | tools/extract_speech_token.py --dir data/$x \
39 | --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
40 | done
41 | fi
42 |
43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
44 | echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
45 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
46 | mkdir -p data/$x/parquet
47 | tools/make_parquet_list.py --num_utts_per_parquet 1000 \
48 | --num_processes 10 \
49 | --src_dir data/$x \
50 | --des_dir data/$x/parquet
51 | done
52 | fi
53 |
54 | # inference
55 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
56 | echo "Run inference. Please make sure utt in tts_text is in prompt_data"
57 | for mode in sft zero_shot; do
58 | python cosyvoice/bin/inference.py --mode $mode \
59 | --gpu 0 \
60 | --config conf/cosyvoice.yaml \
61 | --prompt_data data/test-clean/parquet/data.list \
62 | --prompt_utt2data data/test-clean/parquet/utt2data.list \
63 | --tts_text `pwd`/tts_text.json \
64 | --llm_model $pretrained_model_dir/llm.pt \
65 | --flow_model $pretrained_model_dir/flow.pt \
66 | --hifigan_model $pretrained_model_dir/hift.pt \
67 | --result_dir `pwd`/exp/cosyvoice/test-clean/$mode
68 | done
69 | fi
70 |
71 | # train llm
72 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
73 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
74 | job_id=1986
75 | dist_backend="nccl"
76 | num_workers=2
77 | prefetch=100
78 | train_engine=torch_ddp
79 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
80 | echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
81 | if [ $train_engine == 'deepspeed' ]; then
82 | echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
83 | fi
84 | cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
85 | cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
86 | for model in llm flow hifigan; do
87 | torchrun --nnodes=1 --nproc_per_node=$num_gpus \
88 | --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
89 | cosyvoice/bin/train.py \
90 | --train_engine $train_engine \
91 | --config conf/cosyvoice.yaml \
92 | --train_data data/train.data.list \
93 | --cv_data data/dev.data.list \
94 | --model $model \
95 | --checkpoint $pretrained_model_dir/$model.pt \
96 | --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
97 | --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
98 | --ddp.dist_backend $dist_backend \
99 | --num_workers ${num_workers} \
100 | --prefetch ${prefetch} \
101 | --pin_memory \
102 | --use_amp \
103 | --deepspeed_config ./conf/ds_stage2.json \
104 | --deepspeed.save_states model+optimizer
105 | done
106 | fi
107 |
108 | # average model
109 | average_num=5
110 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
111 | for model in llm flow hifigan; do
112 | decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
113 | echo "do model average and final checkpoint is $decode_checkpoint"
114 | python cosyvoice/bin/average_model.py \
115 | --dst_model $decode_checkpoint \
116 | --src_path `pwd`/exp/cosyvoice/$model/$train_engine \
117 | --num ${average_num} \
118 | --val_best
119 | done
120 | fi
121 |
122 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
123 | echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
124 | python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
125 | python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
126 | fi
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/tools:
--------------------------------------------------------------------------------
1 | ../../../tools
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/tts_text.json:
--------------------------------------------------------------------------------
1 | {
2 | "1089_134686_000002_000000": [
3 | "hello, my name is Jack. What is your name?"
4 | ]
5 | }
--------------------------------------------------------------------------------
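A small sketch of how this file is consumed: inference.py receives it via --tts_text, and each key must also be an utterance id in --prompt_data, with the value listing the texts to synthesize for that prompt. The loop below only prints the mapping and assumes it is run from examples/libritts/cosyvoice.

import json

with open('tts_text.json', encoding='utf8') as f:
    tts_text = json.load(f)

for utt, texts in tts_text.items():
    for i, text in enumerate(texts):
        print(f'{utt}_{i}: {text}')
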
/examples/libritts/cosyvoice2/conf/cosyvoice2.yaml:
--------------------------------------------------------------------------------
1 | # set random seed, so that you may reproduce your result.
2 | __set_seed1: !apply:random.seed [1986]
3 | __set_seed2: !apply:numpy.random.seed [1986]
4 | __set_seed3: !apply:torch.manual_seed [1986]
5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1986]
6 |
7 | # fixed params
8 | sample_rate: 24000
9 | llm_input_size: 896
10 | llm_output_size: 896
11 | spk_embed_dim: 192
12 | qwen_pretrain_path: ''
13 | token_frame_rate: 25
14 | token_mel_ratio: 2
15 |
16 | # stream related params
17 | chunk_size: 25 # streaming inference chunk size, in token
18 | num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks
19 |
20 | # model params
21 | # for all classes/functions included in this repo, we use !new: or !name: for initialization, so that users can find every corresponding class/function from this single yaml.
22 | # for system/third_party classes/functions, we do not require this.
23 | llm: !new:cosyvoice.llm.llm.Qwen2LM
24 | llm_input_size: !ref <llm_input_size>
25 | llm_output_size: !ref <llm_output_size>
26 | speech_token_size: 6561
27 | length_normalized_loss: True
28 | lsm_weight: 0
29 | mix_ratio: [5, 15]
30 | llm: !new:cosyvoice.llm.llm.Qwen2Encoder
31 | pretrain_path: !ref <qwen_pretrain_path>
32 | sampling: !name:cosyvoice.utils.common.ras_sampling
33 | top_p: 0.8
34 | top_k: 25
35 | win_size: 10
36 | tau_r: 0.1
37 |
38 | flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec
39 | input_size: 512
40 | output_size: 80
41 | spk_embed_dim: !ref <spk_embed_dim>
42 | output_type: 'mel'
43 | vocab_size: 6561
44 | input_frame_rate: !ref <token_frame_rate>
45 | only_mask_loss: True
46 | token_mel_ratio: !ref <token_mel_ratio>
47 | pre_lookahead_len: 3
48 | encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder
49 | output_size: 512
50 | attention_heads: 8
51 | linear_units: 2048
52 | num_blocks: 6
53 | dropout_rate: 0.1
54 | positional_dropout_rate: 0.1
55 | attention_dropout_rate: 0.1
56 | normalize_before: True
57 | input_layer: 'linear'
58 | pos_enc_layer_type: 'rel_pos_espnet'
59 | selfattention_layer_type: 'rel_selfattn'
60 | input_size: 512
61 | use_cnn_module: False
62 | macaron_style: False
63 | static_chunk_size: !ref <chunk_size>
64 | decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM
65 | in_channels: 240
66 | n_spks: 1
67 | spk_emb_dim: 80
68 | cfm_params: !new:omegaconf.DictConfig
69 | content:
70 | sigma_min: 1e-06
71 | solver: 'euler'
72 | t_scheduler: 'cosine'
73 | training_cfg_rate: 0.2
74 | inference_cfg_rate: 0.7
75 | reg_loss_type: 'l1'
76 | estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder
77 | in_channels: 320
78 | out_channels: 80
79 | channels: [256]
80 | dropout: 0.0
81 | attention_head_dim: 64
82 | n_blocks: 4
83 | num_mid_blocks: 12
84 | num_heads: 8
85 | act_fn: 'gelu'
86 | static_chunk_size: !ref <chunk_size> * <token_mel_ratio>
87 | num_decoding_left_chunks: !ref <num_decoding_left_chunks>
88 |
89 | hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
90 | in_channels: 80
91 | base_channels: 512
92 | nb_harmonics: 8
93 | sampling_rate: !ref <sample_rate>
94 | nsf_alpha: 0.1
95 | nsf_sigma: 0.003
96 | nsf_voiced_threshold: 10
97 | upsample_rates: [8, 5, 3]
98 | upsample_kernel_sizes: [16, 11, 7]
99 | istft_params:
100 | n_fft: 16
101 | hop_len: 4
102 | resblock_kernel_sizes: [3, 7, 11]
103 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
104 | source_resblock_kernel_sizes: [7, 7, 11]
105 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
106 | lrelu_slope: 0.1
107 | audio_limit: 0.99
108 | f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
109 | num_class: 1
110 | in_channels: 80
111 | cond_channels: 512
112 |
113 | # gan related module
114 | mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
115 | n_fft: 1920
116 | num_mels: 80
117 | sampling_rate: !ref <sample_rate>
118 | hop_size: 480
119 | win_size: 1920
120 | fmin: 0
121 | fmax: null
122 | center: False
123 | hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
124 | generator: !ref <hift>
125 | discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
126 | mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
127 | mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
128 | mel_spec_transform: [
129 | !ref <mel_spec_transform1>
130 | ]
131 |
132 | # processor functions
133 | parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
134 | get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer
135 | token_path: !ref <qwen_pretrain_path>
136 | skip_special_tokens: True
137 | allowed_special: 'all'
138 | tokenize: !name:cosyvoice.dataset.processor.tokenize
139 | get_tokenizer: !ref <get_tokenizer>
140 | allowed_special: !ref <allowed_special>
141 | filter: !name:cosyvoice.dataset.processor.filter
142 | max_length: 40960
143 | min_length: 100
144 | token_max_length: 200
145 | token_min_length: 1
146 | resample: !name:cosyvoice.dataset.processor.resample
147 | resample_rate: !ref <sample_rate>
148 | truncate: !name:cosyvoice.dataset.processor.truncate
149 | truncate_length: 24480 # must be a multiple of hop_size
150 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram
151 | n_fft: 1920
152 | num_mels: 80
153 | sampling_rate: !ref <sample_rate>
154 | hop_size: 480
155 | win_size: 1920
156 | fmin: 0
157 | fmax: 8000
158 | center: False
159 | compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
160 | feat_extractor: !ref <feat_extractor>
161 | token_mel_ratio: 2
162 | compute_f0: !name:cosyvoice.dataset.processor.compute_f0
163 | sample_rate: !ref <sample_rate>
164 | hop_size: 480
165 | parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
166 | normalize: True
167 | shuffle: !name:cosyvoice.dataset.processor.shuffle
168 | shuffle_size: 1000
169 | sort: !name:cosyvoice.dataset.processor.sort
170 | sort_size: 500 # sort_size should be less than shuffle_size
171 | batch: !name:cosyvoice.dataset.processor.batch
172 | batch_type: 'dynamic'
173 | max_frames_in_batch: 2000
174 | padding: !name:cosyvoice.dataset.processor.padding
175 | use_spk_embedding: False # change to True during sft
176 |
177 |
178 | # dataset processor pipeline
179 | data_pipeline: [
180 | !ref <parquet_opener>,
181 | !ref <tokenize>,
182 | !ref <filter>,
183 | !ref <resample>,
184 | !ref <compute_fbank>,
185 | !ref <parse_embedding>,
186 | !ref <shuffle>,
187 | !ref <sort>,
188 | !ref <batch>,
189 | !ref <padding>,
190 | ]
191 | data_pipeline_gan: [
192 | !ref <parquet_opener>,
193 | !ref <tokenize>,
194 | !ref <filter>,
195 | !ref <resample>,
196 | !ref <truncate>,
197 | !ref <compute_fbank>,
198 | !ref <compute_f0>,
199 | !ref <parse_embedding>,
200 | !ref <shuffle>,
201 | !ref <sort>,
202 | !ref <batch>,
203 | !ref <padding>,
204 | ]
205 |
206 | # llm flow train conf
207 | train_conf:
208 | optim: adam
209 | optim_conf:
210 | lr: 1e-5 # change to 1e-5 during sft
211 | scheduler: constantlr # change to constantlr during sft
212 | scheduler_conf:
213 | warmup_steps: 2500
214 | max_epoch: 200
215 | grad_clip: 5
216 | accum_grad: 2
217 | log_interval: 100
218 | save_per_step: -1
219 |
220 | # gan train conf
221 | train_conf_gan:
222 | optim: adam
223 | optim_conf:
224 | lr: 0.0002 # use small lr for gan training
225 | scheduler: constantlr
226 | optim_d: adam
227 | optim_conf_d:
228 | lr: 0.0002 # use small lr for gan training
229 | scheduler_d: constantlr
230 | max_epoch: 200
231 | grad_clip: 5
232 | accum_grad: 1 # in gan training, accum_grad must be 1
233 | log_interval: 100
234 | save_per_step: -1
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/conf/ds_stage2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 1,
3 | "gradient_accumulation_steps": 1,
4 | "steps_per_print": 100,
5 | "gradient_clipping": 5,
6 | "fp16": {
7 | "enabled": false,
8 | "auto_cast": false,
9 | "loss_scale": 0,
10 | "initial_scale_power": 16,
11 | "loss_scale_window": 256,
12 | "hysteresis": 2,
13 | "consecutive_hysteresis": false,
14 | "min_loss_scale": 1
15 | },
16 | "bf16": {
17 | "enabled": false
18 | },
19 | "zero_force_ds_cpu_optimizer": false,
20 | "zero_optimization": {
21 | "stage": 2,
22 | "offload_optimizer": {
23 | "device": "none",
24 | "pin_memory": true
25 | },
26 | "allgather_partitions": true,
27 | "allgather_bucket_size": 5e8,
28 | "overlap_comm": false,
29 | "reduce_scatter": true,
30 | "reduce_bucket_size": 5e8,
31 | "contiguous_gradients" : true
32 | },
33 | "optimizer": {
34 | "type": "AdamW",
35 | "params": {
36 | "lr": 0.001,
37 | "weight_decay": 0.0001,
38 | "torch_adam": true,
39 | "adam_w_mode": true
40 | }
41 | }
42 | }
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/cosyvoice:
--------------------------------------------------------------------------------
1 | ../../../cosyvoice
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/local:
--------------------------------------------------------------------------------
1 | ../cosyvoice/local
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/path.sh:
--------------------------------------------------------------------------------
1 | ../cosyvoice/path.sh
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2024 Alibaba Inc. All Rights Reserved.
3 | . ./path.sh || exit 1;
4 |
5 | stage=-1
6 | stop_stage=3
7 |
8 | data_url=www.openslr.org/resources/60
9 | data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
10 | pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
11 |
12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
13 | echo "Data Download"
14 | for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
15 | local/download_and_untar.sh ${data_dir} ${data_url} ${part}
16 | done
17 | fi
18 |
19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
20 | echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
21 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
22 | mkdir -p data/$x
23 | python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
24 | done
25 | fi
26 |
27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
28 | echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
29 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
30 | tools/extract_embedding.py --dir data/$x \
31 | --onnx_path $pretrained_model_dir/campplus.onnx
32 | done
33 | fi
34 |
35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
36 | echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
37 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
38 | tools/extract_speech_token.py --dir data/$x \
39 | --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
40 | done
41 | fi
42 |
43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
44 | echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
45 | for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
46 | mkdir -p data/$x/parquet
47 | tools/make_parquet_list.py --num_utts_per_parquet 1000 \
48 | --num_processes 10 \
49 | --src_dir data/$x \
50 | --des_dir data/$x/parquet
51 | done
52 | fi
53 |
54 | # inference
55 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
56 | echo "Run inference. Please make sure utt in tts_text is in prompt_data"
57 | # TODO consider removing bin/inference.py, or use a similar initialization method as in the readme
58 | for mode in sft zero_shot; do
59 | python cosyvoice/bin/inference.py --mode $mode \
60 | --gpu 0 \
61 | --config conf/cosyvoice2.yaml \
62 | --prompt_data data/test-clean/parquet/data.list \
63 | --prompt_utt2data data/test-clean/parquet/utt2data.list \
64 | --tts_text `pwd`/tts_text.json \
65 | --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
66 | --llm_model $pretrained_model_dir/llm.pt \
67 | --flow_model $pretrained_model_dir/flow.pt \
68 | --hifigan_model $pretrained_model_dir/hift.pt \
69 | --result_dir `pwd`/exp/cosyvoice/test-clean/$mode
70 | done
71 | fi
72 |
73 | # train llm
74 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
75 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
76 | job_id=1986
77 | dist_backend="nccl"
78 | num_workers=2
79 | prefetch=100
80 | train_engine=torch_ddp
81 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
82 | echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
83 | if [ $train_engine == 'deepspeed' ]; then
84 | echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
85 | fi
86 | cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
87 | cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
88 | # NOTE will update llm/hift training later
89 | for model in llm flow; do
90 | torchrun --nnodes=1 --nproc_per_node=$num_gpus \
91 | --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
92 | cosyvoice/bin/train.py \
93 | --train_engine $train_engine \
94 | --config conf/cosyvoice2.yaml \
95 | --train_data data/train.data.list \
96 | --cv_data data/dev.data.list \
97 | --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
98 | --model $model \
99 | --checkpoint $pretrained_model_dir/$model.pt \
100 | --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
101 | --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
102 | --ddp.dist_backend $dist_backend \
103 | --num_workers ${num_workers} \
104 | --prefetch ${prefetch} \
105 | --pin_memory \
106 | --use_amp \
107 | --deepspeed_config ./conf/ds_stage2.json \
108 | --deepspeed.save_states model+optimizer
109 | done
110 | fi
111 |
112 | # average model
113 | average_num=5
114 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
115 | for model in llm flow hifigan; do
116 | decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
117 | echo "do model average and final checkpoint is $decode_checkpoint"
118 | python cosyvoice/bin/average_model.py \
119 | --dst_model $decode_checkpoint \
120 | --src_path `pwd`/exp/cosyvoice/$model/$train_engine \
121 | --num ${average_num} \
122 | --val_best
123 | done
124 | fi
125 |
126 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
127 | echo "Export your model for inference speedup. Remember copy your llm or flow model to model_dir"
128 | python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
129 | python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
130 | fi
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/tools:
--------------------------------------------------------------------------------
1 | ../../../tools
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/tts_text.json:
--------------------------------------------------------------------------------
1 | ../cosyvoice/tts_text.json
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/conf:
--------------------------------------------------------------------------------
1 | ../../libritts/cosyvoice/conf
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/cosyvoice:
--------------------------------------------------------------------------------
1 | ../../../cosyvoice
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/local/download_and_untar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 | # Apache 2.0
5 |
6 | remove_archive=false
7 |
8 | if [ "$1" == --remove-archive ]; then
9 | remove_archive=true
10 | shift
11 | fi
12 |
13 | if [ $# -ne 3 ]; then
14 | echo "Usage: $0 [--remove-archive] "
15 | echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
16 | echo "With --remove-archive it will remove the archive after successfully un-tarring it."
17 | echo " can be one of: dev-clean, test-clean, dev-other, test-other,"
18 | echo " train-clean-100, train-clean-360, train-other-500."
19 | exit 1
20 | fi
21 |
22 | data=$1
23 | url=$2
24 | part=$3
25 |
26 | if [ ! -d "$data" ]; then
27 | echo "$0: no such directory $data"
28 | exit 1
29 | fi
30 |
31 | part_ok=false
32 | list="dev_set test_set train_set"
33 | for x in $list; do
34 | if [ "$part" == $x ]; then part_ok=true; fi
35 | done
36 | if ! $part_ok; then
37 | echo "$0: expected to be one of $list, but got '$part'"
38 | exit 1
39 | fi
40 |
41 | if [ -z "$url" ]; then
42 | echo "$0: empty URL base."
43 | exit 1
44 | fi
45 |
46 | if [ -f $data/.$part.complete ]; then
47 | echo "$0: data part $part was already successfully extracted, nothing to do."
48 | exit 0
49 | fi
50 |
51 |
52 | # sizes of the archive files in bytes. These are from some older versions.
53 | sizes_old="1035537823 2201936013 52627842921"
54 | # sizes_new is the archive file sizes of the final release. Some of these sizes are of
55 | # things we probably won't download.
56 | sizes_new="3886385"
57 |
58 | if [ -f $data/$part.tar.gz ]; then
59 | size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
60 | size_ok=false
61 | for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
62 | if ! $size_ok; then
63 | echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
64 | echo "does not equal the size of one of the archives."
65 | rm $data/$part.tar.gz
66 | else
67 | echo "$data/$part.tar.gz exists and appears to be complete."
68 | fi
69 | fi
70 |
71 | if [ ! -f $data/$part.tar.gz ]; then
72 | if ! which wget >/dev/null; then
73 | echo "$0: wget is not installed."
74 | exit 1
75 | fi
76 | full_url=$url/$part.tar.gz
77 | echo "$0: downloading data from $full_url. This may take some time, please be patient."
78 |
79 | if ! wget -P $data --no-check-certificate $full_url; then
80 | echo "$0: error executing wget $full_url"
81 | exit 1
82 | fi
83 | fi
84 |
85 | if ! tar -C $data -xvzf $data/$part.tar.gz; then
86 | echo "$0: error un-tarring archive $data/$part.tar.gz"
87 | exit 1
88 | fi
89 |
90 | touch $data/.$part.complete
91 |
92 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
93 |
94 | if $remove_archive; then
95 | echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
96 | rm $data/$part.tar.gz
97 | fi
98 |
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/local/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | from tqdm import tqdm
5 |
6 |
7 | logger = logging.getLogger()
8 |
9 |
10 | def main():
11 | utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
12 | with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
13 | lines = f.readlines()[1:]
14 | lines = [l.split('\t') for l in lines]
15 | for wav, spk, content in tqdm(lines):
16 | wav, spk, content = wav.strip(), spk.strip(), content.strip()
17 | content = content.replace('[FIL]', '')
18 | content = content.replace('[SPK]', '')
19 | wav = os.path.join(args.src_dir, spk, wav)
20 | if not os.path.exists(wav):
21 | continue
22 | utt = os.path.basename(wav).replace('.wav', '')
23 | utt2wav[utt] = wav
24 | utt2text[utt] = content
25 | utt2spk[utt] = spk
26 | if spk not in spk2utt:
27 | spk2utt[spk] = []
28 | spk2utt[spk].append(utt)
29 |
30 | with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
31 | for k, v in utt2wav.items():
32 | f.write('{} {}\n'.format(k, v))
33 | with open('{}/text'.format(args.des_dir), 'w') as f:
34 | for k, v in utt2text.items():
35 | f.write('{} {}\n'.format(k, v))
36 | with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
37 | for k, v in utt2spk.items():
38 | f.write('{} {}\n'.format(k, v))
39 | with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
40 | for k, v in spk2utt.items():
41 | f.write('{} {}\n'.format(k, ' '.join(v)))
42 | return
43 |
44 |
45 | if __name__ == "__main__":
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('--src_dir',
48 | type=str)
49 | parser.add_argument('--des_dir',
50 | type=str)
51 | args = parser.parse_args()
52 | main()
53 |
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/path.sh:
--------------------------------------------------------------------------------
1 | ../../libritts/cosyvoice/path.sh
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # Copyright 2024 Alibaba Inc. All Rights Reserved.
3 | . ./path.sh || exit 1;
4 |
5 | stage=-1
6 | stop_stage=3
7 |
8 | data_url=www.openslr.org/resources/68
9 | data_dir=/mnt/hengwu.zty/data/tts/openslr/magicdata-read
10 | pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
11 |
12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
13 | echo "Data Download"
14 | for part in dev_set test_set train_set; do
15 | local/download_and_untar.sh ${data_dir} ${data_url} ${part}
16 | done
17 | fi
18 |
19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
20 | echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
21 | for x in dev test train; do
22 | mkdir -p data/$x
23 | python local/prepare_data.py --src_dir $data_dir/$x --des_dir data/$x
24 | done
25 | fi
26 |
27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
28 | echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
29 | for x in dev test train; do
30 | tools/extract_embedding.py --dir data/$x \
31 | --onnx_path $pretrained_model_dir/campplus.onnx
32 | done
33 | fi
34 |
35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
36 | echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
37 | for x in dev test train; do
38 | tools/extract_speech_token.py --dir data/$x \
39 | --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
40 | done
41 | fi
42 |
43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
44 | echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
45 | for x in dev test train; do
46 | mkdir -p data/$x/parquet
47 | tools/make_parquet_list.py --num_utts_per_parquet 1000 \
48 | --num_processes 10 \
49 | --src_dir data/$x \
50 | --des_dir data/$x/parquet
51 | done
52 | fi
53 |
54 | # inference
55 | if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
56 | echo "Run inference. Please make sure utt in tts_text is in prompt_data"
57 | for mode in sft zero_shot; do
58 | python cosyvoice/bin/inference.py --mode $mode \
59 | --gpu 0 \
60 | --config conf/cosyvoice.yaml \
61 | --prompt_data data/test/parquet/data.list \
62 | --prompt_utt2data data/test/parquet/utt2data.list \
63 | --tts_text `pwd`/tts_text.json \
64 | --llm_model $pretrained_model_dir/llm.pt \
65 | --flow_model $pretrained_model_dir/flow.pt \
66 | --hifigan_model $pretrained_model_dir/hift.pt \
67 | --result_dir `pwd`/exp/cosyvoice/test/$mode
68 | done
69 | fi
70 |
71 | # train llm
72 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
73 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
74 | job_id=1986
75 | dist_backend="nccl"
76 | num_workers=2
77 | prefetch=100
78 | train_engine=torch_ddp
79 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
80 | echo "Run train. We only support llm traning for now. If your want to train from scratch, please use conf/cosyvoice.fromscratch.yaml"
81 | if [ $train_engine == 'deepspeed' ]; then
82 | echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
83 | fi
84 | cp data/train/parquet/data.list data/train.data.list
85 | cp data/dev/parquet/data.list data/dev.data.list
86 | for model in llm flow hifigan; do
87 | torchrun --nnodes=1 --nproc_per_node=$num_gpus \
88 | --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
89 | cosyvoice/bin/train.py \
90 | --train_engine $train_engine \
91 | --config conf/cosyvoice.yaml \
92 | --train_data data/train.data.list \
93 | --cv_data data/dev.data.list \
94 | --model $model \
95 | --checkpoint $pretrained_model_dir/$model.pt \
96 | --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
97 | --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
98 | --ddp.dist_backend $dist_backend \
99 | --num_workers ${num_workers} \
100 | --prefetch ${prefetch} \
101 | --pin_memory \
102 | --use_amp \
103 | --deepspeed_config ./conf/ds_stage2.json \
104 | --deepspeed.save_states model+optimizer
105 | done
106 | fi
107 |
108 | # average model
109 | average_num=5
110 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
111 | for model in llm flow hifigan; do
112 | decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
113 | echo "do model average and final checkpoint is $decode_checkpoint"
114 | python cosyvoice/bin/average_model.py \
115 | --dst_model $decode_checkpoint \
116 | --src_path `pwd`/exp/cosyvoice/$model/$train_engine \
117 | --num ${average_num} \
118 | --val_best
119 | done
120 | fi
121 |
122 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
123 |   echo "Export your model for inference speedup. Remember to copy your llm or flow model to model_dir"
124 | python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
125 | python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
126 | fi
--------------------------------------------------------------------------------
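
Stage 4 of run.sh assumes every utt key in tts_text.json also exists in the prompt data produced at stage 3. A minimal sanity-check sketch, assuming utt2data.list lists the utt2parquet_*.json maps written by tools/make_parquet_list.py (the paths below are illustrative):

import json

with open('tts_text.json') as f:
    tts_utts = set(json.load(f).keys())

prompt_utts = set()
with open('data/test/parquet/utt2data.list') as f:
    for line in f:
        # each line is the path of a utt2parquet_*.json map: {utt: parquet_file}
        with open(line.strip()) as g:
            prompt_utts.update(json.load(g).keys())

missing = tts_utts - prompt_utts
print('missing utts:', sorted(missing) if missing else 'none')
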
/examples/magicdata-read/cosyvoice/tools:
--------------------------------------------------------------------------------
1 | ../../../tools
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/tts_text.json:
--------------------------------------------------------------------------------
1 | {
2 | "38_5718_20170915093303": [
3 | "我想这出最好歌曲把歌词发到网上请别人帮我作曲急急",
4 | "叫他明天早上差五分儿九点去机场"
5 | ],
6 | "38_5721_20170915091235": [
7 | "变温室调到零下两度档",
8 | "交谈中请勿轻信汇款信息陌生电话请勿使用外挂软件"
9 | ],
10 | "38_5733_20170915130323": [
11 | "这是老鹰乐队的一首经典歌曲",
12 | "我急用这段音乐我自己找到一段但是有现场杂音"
13 | ],
14 | "38_5836_20170916221414": [
15 | "给我播一个陶喆的专辑",
16 | "这套餐好贵呀我发这么多短信贵死了"
17 | ]
18 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu121
2 | --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684
3 | conformer==0.3.2
4 | deepspeed==0.15.1; sys_platform == 'linux'
5 | diffusers==0.29.0
6 | fastapi==0.115.6
7 | fastapi-cli==0.0.4
8 | gdown==5.1.0
9 | gradio==5.4.0
10 | grpcio==1.57.0
11 | grpcio-tools==1.57.0
12 | hydra-core==1.3.2
13 | HyperPyYAML==1.2.2
14 | inflect==7.3.1
15 | librosa==0.10.2
16 | lightning==2.2.4
17 | matplotlib==3.7.5
18 | modelscope==1.20.0
19 | networkx==3.1
20 | omegaconf==2.3.0
21 | onnx==1.16.0
22 | onnxruntime-gpu==1.18.0; sys_platform == 'linux'
23 | onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32'
24 | openai-whisper==20231117
25 | protobuf==4.25
26 | pyarrow==18.1.0
27 | pydantic==2.7.0
28 | pyworld==0.3.4
29 | rich==13.7.1
30 | soundfile==0.12.1
31 | tensorboard==2.14.0
32 | tensorrt-cu12==10.0.1; sys_platform == 'linux'
33 | tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux'
34 | tensorrt-cu12-libs==10.0.1; sys_platform == 'linux'
35 | torch==2.3.1
36 | torchaudio==2.3.1
37 | transformers==4.40.1
38 | uvicorn==0.30.0
39 | WeTextProcessing==1.0.3
40 | wget==3.2
41 |
--------------------------------------------------------------------------------
/runtime/python/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
2 | ENV DEBIAN_FRONTEND=noninteractive
3 |
4 | WORKDIR /opt/CosyVoice
5 |
6 | RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
7 | RUN apt-get update -y
8 | RUN apt-get -y install git unzip git-lfs g++
9 | RUN git lfs install
10 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
11 | # here we use python==3.10 because we cannot find an image that has both python3.8 and torch2.0.1-cu118 installed
12 | RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
13 | RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto
--------------------------------------------------------------------------------
/runtime/python/fastapi/client.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import argparse
15 | import logging
16 | import requests
17 | import torch
18 | import torchaudio
19 | import numpy as np
20 |
21 |
22 | def main():
23 | url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode)
24 | if args.mode == 'sft':
25 | payload = {
26 | 'tts_text': args.tts_text,
27 | 'spk_id': args.spk_id
28 | }
29 | response = requests.request("GET", url, data=payload, stream=True)
30 | elif args.mode == 'zero_shot':
31 | payload = {
32 | 'tts_text': args.tts_text,
33 | 'prompt_text': args.prompt_text
34 | }
35 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
36 | response = requests.request("GET", url, data=payload, files=files, stream=True)
37 | elif args.mode == 'cross_lingual':
38 | payload = {
39 | 'tts_text': args.tts_text,
40 | }
41 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
42 | response = requests.request("GET", url, data=payload, files=files, stream=True)
43 | else:
44 | payload = {
45 | 'tts_text': args.tts_text,
46 | 'spk_id': args.spk_id,
47 | 'instruct_text': args.instruct_text
48 | }
49 | response = requests.request("GET", url, data=payload, stream=True)
50 | tts_audio = b''
51 | for r in response.iter_content(chunk_size=16000):
52 | tts_audio += r
53 | tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
54 | logging.info('save response to {}'.format(args.tts_wav))
55 | torchaudio.save(args.tts_wav, tts_speech, target_sr)
56 | logging.info('get response')
57 |
58 |
59 | if __name__ == "__main__":
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument('--host',
62 | type=str,
63 | default='0.0.0.0')
64 | parser.add_argument('--port',
65 | type=int,
66 |                         default=50000)
67 | parser.add_argument('--mode',
68 | default='sft',
69 | choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
70 | help='request mode')
71 | parser.add_argument('--tts_text',
72 | type=str,
73 | default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
74 | parser.add_argument('--spk_id',
75 | type=str,
76 | default='中文女')
77 | parser.add_argument('--prompt_text',
78 | type=str,
79 | default='希望你以后能够做的比我还好呦。')
80 | parser.add_argument('--prompt_wav',
81 | type=str,
82 | default='../../../asset/zero_shot_prompt.wav')
83 | parser.add_argument('--instruct_text',
84 | type=str,
85 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
86 | Fights with fervor for justice, but struggles with impulsiveness.')
87 | parser.add_argument('--tts_wav',
88 | type=str,
89 | default='demo.wav')
90 | args = parser.parse_args()
91 | prompt_sr, target_sr = 16000, 22050
92 | main()
93 |
--------------------------------------------------------------------------------
/runtime/python/fastapi/server.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import sys
16 | import argparse
17 | import logging
18 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
19 | from fastapi import FastAPI, UploadFile, Form, File
20 | from fastapi.responses import StreamingResponse
21 | from fastapi.middleware.cors import CORSMiddleware
22 | import uvicorn
23 | import numpy as np
24 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
25 | sys.path.append('{}/../../..'.format(ROOT_DIR))
26 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
27 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
28 | from cosyvoice.utils.file_utils import load_wav
29 |
30 | app = FastAPI()
31 | # set cross region allowance
32 | app.add_middleware(
33 | CORSMiddleware,
34 | allow_origins=["*"],
35 | allow_credentials=True,
36 | allow_methods=["*"],
37 | allow_headers=["*"])
38 |
39 |
40 | def generate_data(model_output):
41 | for i in model_output:
42 | tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
43 | yield tts_audio
44 |
45 |
46 | @app.get("/inference_sft")
47 | @app.post("/inference_sft")
48 | async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
49 | model_output = cosyvoice.inference_sft(tts_text, spk_id)
50 | return StreamingResponse(generate_data(model_output))
51 |
52 |
53 | @app.get("/inference_zero_shot")
54 | @app.post("/inference_zero_shot")
55 | async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
56 | prompt_speech_16k = load_wav(prompt_wav.file, 16000)
57 | model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
58 | return StreamingResponse(generate_data(model_output))
59 |
60 |
61 | @app.get("/inference_cross_lingual")
62 | @app.post("/inference_cross_lingual")
63 | async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
64 | prompt_speech_16k = load_wav(prompt_wav.file, 16000)
65 | model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
66 | return StreamingResponse(generate_data(model_output))
67 |
68 |
69 | @app.get("/inference_instruct")
70 | @app.post("/inference_instruct")
71 | async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
72 | model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
73 | return StreamingResponse(generate_data(model_output))
74 |
75 |
76 | @app.get("/inference_instruct2")
77 | @app.post("/inference_instruct2")
78 | async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()):
79 | prompt_speech_16k = load_wav(prompt_wav.file, 16000)
80 | model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k)
81 | return StreamingResponse(generate_data(model_output))
82 |
83 |
84 | if __name__ == '__main__':
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument('--port',
87 | type=int,
88 | default=50000)
89 | parser.add_argument('--model_dir',
90 | type=str,
91 | default='iic/CosyVoice-300M',
92 | help='local path or modelscope repo id')
93 | args = parser.parse_args()
94 | try:
95 | cosyvoice = CosyVoice(args.model_dir)
96 | except Exception:
97 | try:
98 | cosyvoice = CosyVoice2(args.model_dir)
99 | except Exception:
100 | raise TypeError('no valid model_type!')
101 | uvicorn.run(app, host="0.0.0.0", port=args.port)
102 |
--------------------------------------------------------------------------------
/runtime/python/grpc/client.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import sys
16 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
17 | sys.path.append('{}/../../..'.format(ROOT_DIR))
18 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
19 | import logging
20 | import argparse
21 | import torchaudio
22 | import cosyvoice_pb2
23 | import cosyvoice_pb2_grpc
24 | import grpc
25 | import torch
26 | import numpy as np
27 | from cosyvoice.utils.file_utils import load_wav
28 |
29 |
30 | def main():
31 | with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel:
32 | stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
33 | request = cosyvoice_pb2.Request()
34 | if args.mode == 'sft':
35 | logging.info('send sft request')
36 | sft_request = cosyvoice_pb2.sftRequest()
37 | sft_request.spk_id = args.spk_id
38 | sft_request.tts_text = args.tts_text
39 | request.sft_request.CopyFrom(sft_request)
40 | elif args.mode == 'zero_shot':
41 | logging.info('send zero_shot request')
42 | zero_shot_request = cosyvoice_pb2.zeroshotRequest()
43 | zero_shot_request.tts_text = args.tts_text
44 | zero_shot_request.prompt_text = args.prompt_text
45 | prompt_speech = load_wav(args.prompt_wav, 16000)
46 | zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
47 | request.zero_shot_request.CopyFrom(zero_shot_request)
48 | elif args.mode == 'cross_lingual':
49 | logging.info('send cross_lingual request')
50 | cross_lingual_request = cosyvoice_pb2.crosslingualRequest()
51 | cross_lingual_request.tts_text = args.tts_text
52 | prompt_speech = load_wav(args.prompt_wav, 16000)
53 | cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
54 | request.cross_lingual_request.CopyFrom(cross_lingual_request)
55 | else:
56 | logging.info('send instruct request')
57 | instruct_request = cosyvoice_pb2.instructRequest()
58 | instruct_request.tts_text = args.tts_text
59 | instruct_request.spk_id = args.spk_id
60 | instruct_request.instruct_text = args.instruct_text
61 | request.instruct_request.CopyFrom(instruct_request)
62 |
63 | response = stub.Inference(request)
64 | tts_audio = b''
65 | for r in response:
66 | tts_audio += r.tts_audio
67 | tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
68 | logging.info('save response to {}'.format(args.tts_wav))
69 | torchaudio.save(args.tts_wav, tts_speech, target_sr)
70 | logging.info('get response')
71 |
72 |
73 | if __name__ == "__main__":
74 | parser = argparse.ArgumentParser()
75 | parser.add_argument('--host',
76 | type=str,
77 | default='0.0.0.0')
78 | parser.add_argument('--port',
79 | type=int,
80 |                         default=50000)
81 | parser.add_argument('--mode',
82 | default='sft',
83 | choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
84 | help='request mode')
85 | parser.add_argument('--tts_text',
86 | type=str,
87 | default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
88 | parser.add_argument('--spk_id',
89 | type=str,
90 | default='中文女')
91 | parser.add_argument('--prompt_text',
92 | type=str,
93 | default='希望你以后能够做的比我还好呦。')
94 | parser.add_argument('--prompt_wav',
95 | type=str,
96 | default='../../../asset/zero_shot_prompt.wav')
97 | parser.add_argument('--instruct_text',
98 | type=str,
99 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
100 | Fights with fervor for justice, but struggles with impulsiveness.')
101 | parser.add_argument('--tts_wav',
102 | type=str,
103 | default='demo.wav')
104 | args = parser.parse_args()
105 | prompt_sr, target_sr = 16000, 22050
106 | main()
107 |
--------------------------------------------------------------------------------
/runtime/python/grpc/cosyvoice.proto:
--------------------------------------------------------------------------------
1 | syntax = "proto3";
2 |
3 | package cosyvoice;
4 | option go_package = "protos/";
5 |
6 | service CosyVoice{
7 | rpc Inference(Request) returns (stream Response) {}
8 | }
9 |
10 | message Request{
11 | oneof RequestPayload {
12 | sftRequest sft_request = 1;
13 | zeroshotRequest zero_shot_request = 2;
14 | crosslingualRequest cross_lingual_request = 3;
15 | instructRequest instruct_request = 4;
16 | }
17 | }
18 |
19 | message sftRequest{
20 | string spk_id = 1;
21 | string tts_text = 2;
22 | }
23 |
24 | message zeroshotRequest{
25 | string tts_text = 1;
26 | string prompt_text = 2;
27 | bytes prompt_audio = 3;
28 | }
29 |
30 | message crosslingualRequest{
31 | string tts_text = 1;
32 | bytes prompt_audio = 2;
33 | }
34 |
35 | message instructRequest{
36 | string tts_text = 1;
37 | string spk_id = 2;
38 | string instruct_text = 3;
39 | }
40 |
41 | message Response{
42 | bytes tts_audio = 1;
43 | }
--------------------------------------------------------------------------------
/runtime/python/grpc/server.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import os
15 | import sys
16 | from concurrent import futures
17 | import argparse
18 | import cosyvoice_pb2
19 | import cosyvoice_pb2_grpc
20 | import logging
21 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
22 | import grpc
23 | import torch
24 | import numpy as np
25 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
26 | sys.path.append('{}/../../..'.format(ROOT_DIR))
27 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
28 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2
29 |
30 | logging.basicConfig(level=logging.DEBUG,
31 | format='%(asctime)s %(levelname)s %(message)s')
32 |
33 |
34 | class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
35 | def __init__(self, args):
36 | try:
37 | self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc)
38 | except Exception:
39 | try:
40 | self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc)
41 | except Exception:
42 | raise TypeError('no valid model_type!')
43 | logging.info('grpc service initialized')
44 |
45 | def Inference(self, request, context):
46 | if request.HasField('sft_request'):
47 | logging.info('get sft inference request')
48 | model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id)
49 | elif request.HasField('zero_shot_request'):
50 | logging.info('get zero_shot inference request')
51 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
52 | prompt_speech_16k = prompt_speech_16k.float() / (2**15)
53 | model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
54 | request.zero_shot_request.prompt_text,
55 | prompt_speech_16k)
56 | elif request.HasField('cross_lingual_request'):
57 | logging.info('get cross_lingual inference request')
58 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
59 | prompt_speech_16k = prompt_speech_16k.float() / (2**15)
60 | model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
61 | else:
62 | logging.info('get instruct inference request')
63 | model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
64 | request.instruct_request.spk_id,
65 | request.instruct_request.instruct_text)
66 |
67 | logging.info('send inference response')
68 | for i in model_output:
69 | response = cosyvoice_pb2.Response()
70 | response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
71 | yield response
72 |
73 |
74 | def main():
75 | grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
76 | cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
77 | grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port))
78 | grpcServer.start()
79 | logging.info("server listening on 0.0.0.0:{}".format(args.port))
80 | grpcServer.wait_for_termination()
81 |
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--port',
86 | type=int,
87 | default=50000)
88 | parser.add_argument('--max_conc',
89 | type=int,
90 | default=4)
91 | parser.add_argument('--model_dir',
92 | type=str,
93 | default='iic/CosyVoice-300M',
94 | help='local path or modelscope repo id')
95 | args = parser.parse_args()
96 | main()
97 |
--------------------------------------------------------------------------------
/tools/extract_embedding.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import argparse
16 | from concurrent.futures import ThreadPoolExecutor, as_completed
17 | import onnxruntime
18 | import torch
19 | import torchaudio
20 | import torchaudio.compliance.kaldi as kaldi
21 | from tqdm import tqdm
22 |
23 |
24 | def single_job(utt):
25 | audio, sample_rate = torchaudio.load(utt2wav[utt])
26 | if sample_rate != 16000:
27 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
28 | feat = kaldi.fbank(audio,
29 | num_mel_bins=80,
30 | dither=0,
31 | sample_frequency=16000)
32 | feat = feat - feat.mean(dim=0, keepdim=True)
33 | embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
34 | return utt, embedding
35 |
36 |
37 | def main(args):
38 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
39 | utt2embedding, spk2embedding = {}, {}
40 | for future in tqdm(as_completed(all_task)):
41 | utt, embedding = future.result()
42 | utt2embedding[utt] = embedding
43 | spk = utt2spk[utt]
44 | if spk not in spk2embedding:
45 | spk2embedding[spk] = []
46 | spk2embedding[spk].append(embedding)
47 | for k, v in spk2embedding.items():
48 | spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist()
49 | torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir))
50 | torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir))
51 |
52 |
53 | if __name__ == "__main__":
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument("--dir", type=str)
56 | parser.add_argument("--onnx_path", type=str)
57 | parser.add_argument("--num_thread", type=int, default=8)
58 | args = parser.parse_args()
59 |
60 | utt2wav, utt2spk = {}, {}
61 | with open('{}/wav.scp'.format(args.dir)) as f:
62 | for l in f:
63 | l = l.replace('\n', '').split()
64 | utt2wav[l[0]] = l[1]
65 | with open('{}/utt2spk'.format(args.dir)) as f:
66 | for l in f:
67 | l = l.replace('\n', '').split()
68 | utt2spk[l[0]] = l[1]
69 |
70 | option = onnxruntime.SessionOptions()
71 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
72 | option.intra_op_num_threads = 1
73 | providers = ["CPUExecutionProvider"]
74 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
75 | executor = ThreadPoolExecutor(max_workers=args.num_thread)
76 |
77 | main(args)
78 |
--------------------------------------------------------------------------------
/tools/extract_speech_token.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import argparse
16 | from concurrent.futures import ThreadPoolExecutor, as_completed
17 | import logging
18 | import torch
19 | from tqdm import tqdm
20 | import onnxruntime
21 | import numpy as np
22 | import torchaudio
23 | import whisper
24 |
25 |
26 | def single_job(utt):
27 | audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile')
28 | if sample_rate != 16000:
29 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio)
30 | # Convert audio to mono
31 | if audio.shape[0] > 1:
32 | audio = audio.mean(dim=0, keepdim=True)
33 | if audio.shape[1] / 16000 > 30:
34 |         logging.warning('extracting speech tokens from audio longer than 30s is not supported')
35 | speech_token = []
36 | else:
37 | feat = whisper.log_mel_spectrogram(audio, n_mels=128)
38 | speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
39 | ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
40 | return utt, speech_token
41 |
42 |
43 | def main(args):
44 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()]
45 | utt2speech_token = {}
46 | for future in tqdm(as_completed(all_task)):
47 | utt, speech_token = future.result()
48 | utt2speech_token[utt] = speech_token
49 | torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir))
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--dir", type=str)
55 | parser.add_argument("--onnx_path", type=str)
56 | parser.add_argument("--num_thread", type=int, default=8)
57 | args = parser.parse_args()
58 |
59 | utt2wav = {}
60 | with open('{}/wav.scp'.format(args.dir)) as f:
61 | for l in f:
62 | l = l.replace('\n', '').split()
63 | utt2wav[l[0]] = l[1]
64 |
65 | option = onnxruntime.SessionOptions()
66 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
67 | option.intra_op_num_threads = 1
68 | providers = ["CUDAExecutionProvider"]
69 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers)
70 | executor = ThreadPoolExecutor(max_workers=args.num_thread)
71 |
72 | main(args)
73 |
--------------------------------------------------------------------------------
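
Both extraction tools persist plain dicts keyed by utt via torch.save (utt2embedding.pt and spk2embedding.pt from extract_embedding.py, utt2speech_token.pt from extract_speech_token.py), and make_parquet_list.py loads them back with torch.load. A small inspection sketch under that layout (the data/dev directory is illustrative):

import torch

d = 'data/dev'
utt2embedding = torch.load('{}/utt2embedding.pt'.format(d))
spk2embedding = torch.load('{}/spk2embedding.pt'.format(d))
utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(d))

utt = next(iter(utt2embedding))
print('utts:', len(utt2embedding), 'spks:', len(spk2embedding))
print('embedding dim:', len(utt2embedding[utt]))           # stored as a flat list of floats
print('speech token length:', len(utt2speech_token[utt]))  # stored as a list of ints
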
/tools/make_parquet_list.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import argparse
16 | import logging
17 | import os
18 | import json
19 | from tqdm import tqdm
20 | import pandas as pd
21 | import multiprocessing
22 | import time
23 | import torch
24 |
25 |
26 | def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
27 | start_time = time.time()
28 | data_list = []
29 | for utt in tqdm(utt_list):
30 | data = open(utt2wav[utt], 'rb').read()
31 | data_list.append(data)
32 | wav_list = [utt2wav[utt] for utt in utt_list]
33 | text_list = [utt2text[utt] for utt in utt_list]
34 | spk_list = [utt2spk[utt] for utt in utt_list]
35 | uttembedding_list = [utt2embedding[utt] for utt in utt_list]
36 | spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
37 | speech_token_list = [utt2speech_token[utt] for utt in utt_list]
38 |
39 |     # save to the parquet file, and write the utt2parquet_file / spk2parquet_file json maps
40 | df = pd.DataFrame()
41 | df['utt'] = utt_list
42 | df['wav'] = wav_list
43 | df['audio_data'] = data_list
44 | df['text'] = text_list
45 | df['spk'] = spk_list
46 | df['utt_embedding'] = uttembedding_list
47 | df['spk_embedding'] = spkembedding_list
48 | df['speech_token'] = speech_token_list
49 | df.to_parquet(parquet_file)
50 | with open(utt2parquet_file, 'w') as f:
51 | json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
52 | with open(spk2parquet_file, 'w') as f:
53 | json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
54 | logging.info('spend time {}'.format(time.time() - start_time))
55 |
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument('--num_utts_per_parquet',
60 | type=int,
61 | default=1000,
62 | help='num utts per parquet')
63 | parser.add_argument('--num_processes',
64 | type=int,
65 | default=1,
66 | help='num processes for make parquets')
67 | parser.add_argument('--src_dir',
68 | type=str)
69 | parser.add_argument('--des_dir',
70 | type=str)
71 | args = parser.parse_args()
72 |
73 | utt2wav, utt2text, utt2spk = {}, {}, {}
74 | with open('{}/wav.scp'.format(args.src_dir)) as f:
75 | for l in f:
76 | l = l.replace('\n', '').split()
77 | utt2wav[l[0]] = l[1]
78 | with open('{}/text'.format(args.src_dir)) as f:
79 | for l in f:
80 | l = l.replace('\n', '').split()
81 | utt2text[l[0]] = ' '.join(l[1:])
82 | with open('{}/utt2spk'.format(args.src_dir)) as f:
83 | for l in f:
84 | l = l.replace('\n', '').split()
85 | utt2spk[l[0]] = l[1]
86 | utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
87 | spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
88 | utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
89 | utts = list(utt2wav.keys())
90 |
91 | # Using process pool to speedup
92 | pool = multiprocessing.Pool(processes=args.num_processes)
93 | parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
94 | for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
95 | parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
96 | utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
97 | spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
98 | parquet_list.append(parquet_file)
99 | utt2parquet_list.append(utt2parquet_file)
100 | spk2parquet_list.append(spk2parquet_file)
101 | pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
102 | pool.close()
103 | pool.join()
104 |
105 | with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
106 | open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
107 | open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
108 | for name in parquet_list:
109 | f1.write(name + '\n')
110 | for name in utt2parquet_list:
111 | f2.write(name + '\n')
112 | for name in spk2parquet_list:
113 | f3.write(name + '\n')
114 |
--------------------------------------------------------------------------------
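
Each shard produced by make_parquet_list.py is an ordinary parquet file (despite the .tar extension) holding the columns assembled in job(). A quick inspection sketch with pandas (the path is illustrative):

import pandas as pd

df = pd.read_parquet('data/dev/parquet/parquet_000000000.tar')
print(df.columns.tolist())
# expected columns: utt, wav, audio_data, text, spk,
#                   utt_embedding, spk_embedding, speech_token
row = df.iloc[0]
print(row['utt'], row['spk'], len(row['speech_token']))
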
/tools/make_parquet_list_dpo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import argparse
16 | import logging
17 | import os
18 | import json
19 | from tqdm import tqdm
20 | import pandas as pd
21 | import multiprocessing
22 | import time
23 | import torch
24 |
25 |
26 | def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
27 | start_time = time.time()
28 | data_list = []
29 | for utt in tqdm(utt_list):
30 | data = open(utt2wav[utt], 'rb').read()
31 | data_list.append(data)
32 | wav_list = [utt2wav[utt] for utt in utt_list]
33 | text_list = [utt2text[utt] for utt in utt_list]
34 | spk_list = [utt2spk[utt] for utt in utt_list]
35 | uttembedding_list = [utt2embedding[utt] for utt in utt_list]
36 | spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
37 | speech_token_list = [utt2speech_token[utt] for utt in utt_list]
38 | if utt2reject_speech_token:
39 | reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
40 |
41 |     # save to the parquet file, and write the utt2parquet_file / spk2parquet_file json maps
42 | df = pd.DataFrame()
43 | df['utt'] = utt_list
44 | df['wav'] = wav_list
45 | df['audio_data'] = data_list
46 | df['text'] = text_list
47 | df['spk'] = spk_list
48 | df['utt_embedding'] = uttembedding_list
49 | df['spk_embedding'] = spkembedding_list
50 | df['speech_token'] = speech_token_list
51 | if utt2reject_speech_token:
52 | df['reject_speech_token'] = reject_speech_token_list
53 | df.to_parquet(parquet_file)
54 | with open(utt2parquet_file, 'w') as f:
55 | json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
56 | with open(spk2parquet_file, 'w') as f:
57 | json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
58 | logging.info('spend time {}'.format(time.time() - start_time))
59 |
60 |
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument('--num_utts_per_parquet',
64 | type=int,
65 | default=1000,
66 | help='num utts per parquet')
67 | parser.add_argument('--num_processes',
68 | type=int,
69 | default=1,
70 | help='num processes for make parquets')
71 | parser.add_argument('--src_dir',
72 | type=str)
73 | parser.add_argument('--des_dir',
74 | type=str)
75 | parser.add_argument('--dpo',
76 | action='store_true',
77 | default=False,
78 | help='Use Direct Preference Optimization')
79 | args = parser.parse_args()
80 |
81 | utt2wav, utt2text, utt2spk = {}, {}, {}
82 | with open('{}/wav.scp'.format(args.src_dir)) as f:
83 | for l in f:
84 | l = l.replace('\n', '').split()
85 | utt2wav[l[0]] = l[1]
86 | with open('{}/text'.format(args.src_dir)) as f:
87 | for l in f:
88 | l = l.replace('\n', '').split()
89 | utt2text[l[0]] = ' '.join(l[1:])
90 | with open('{}/utt2spk'.format(args.src_dir)) as f:
91 | for l in f:
92 | l = l.replace('\n', '').split()
93 | utt2spk[l[0]] = l[1]
94 | utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
95 | spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
96 | utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
97 | if args.dpo:
98 | utt2reject_speech_token = torch.load('{}/utt2reject_speech_token.pt'.format(args.src_dir))
99 | else:
100 | utt2reject_speech_token = None
101 | utts = list(utt2wav.keys())
102 |
103 | # Using process pool to speedup
104 | pool = multiprocessing.Pool(processes=args.num_processes)
105 | parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
106 | for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
107 | parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
108 | utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i))
109 | spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i))
110 | parquet_list.append(parquet_file)
111 | utt2parquet_list.append(utt2parquet_file)
112 | spk2parquet_list.append(spk2parquet_file)
113 | pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file))
114 | pool.close()
115 | pool.join()
116 |
117 | with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \
118 | open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \
119 | open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3:
120 | for name in parquet_list:
121 | f1.write(name + '\n')
122 | for name in utt2parquet_list:
123 | f2.write(name + '\n')
124 | for name in spk2parquet_list:
125 | f3.write(name + '\n')
126 |
--------------------------------------------------------------------------------
/vllm_example.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('third_party/Matcha-TTS')
3 | from vllm import ModelRegistry
4 | from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
5 | ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
6 |
7 | from cosyvoice.cli.cosyvoice import CosyVoice2
8 | from cosyvoice.utils.file_utils import load_wav
9 | from cosyvoice.utils.common import set_all_random_seed
10 | from tqdm import tqdm
11 |
12 | def main():
13 | cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True)
14 | prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000)
15 | for i in tqdm(range(100)):
16 | set_all_random_seed(i)
17 | for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
18 | continue
19 |
20 | if __name__ == '__main__':
21 | main()
22 |
--------------------------------------------------------------------------------
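
vllm_example.py only iterates over the generator to benchmark synthesis and discards the audio. Each yielded dict carries a 'tts_speech' tensor (the same key the runtime servers consume), so the chunks can be written to disk instead; a short sketch, assuming the CosyVoice2 object exposes a sample_rate attribute as the CLI class in this repo does:

import torchaudio


def synthesize_and_save(cosyvoice, tts_text, prompt_text, prompt_speech_16k, prefix='vllm_out'):
    # consume the zero-shot generator and write each chunk as a wav file
    for idx, out in enumerate(cosyvoice.inference_zero_shot(tts_text, prompt_text,
                                                            prompt_speech_16k, stream=False)):
        torchaudio.save('{}_{}.wav'.format(prefix, idx), out['tts_speech'], cosyvoice.sample_rate)
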