├── cosyvoice
│   ├── __init__.py
│   ├── cli
│   │   └── __init__.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── losses.py
│   │   ├── class_utils.py
│   │   ├── frontend_utils.py
│   │   ├── file_utils.py
│   │   └── common.py
│   ├── dataset
│   │   ├── __init__.py
│   │   └── dataset.py
│   ├── transformer
│   │   ├── __init__.py
│   │   ├── activation.py
│   │   ├── label_smoothing_loss.py
│   │   ├── positionwise_feed_forward.py
│   │   ├── decoder_layer.py
│   │   └── convolution.py
│   ├── hifigan
│   │   ├── f0_predictor.py
│   │   └── hifigan.py
│   ├── flow
│   │   └── length_regulator.py
│   ├── bin
│   │   ├── average_model.py
│   │   ├── export_jit.py
│   │   ├── export_onnx.py
│   │   └── inference_deprecated.py
│   ├── vllm
│   │   └── cosyvoice2.py
│   └── tokenizer
│       └── tokenizer.py
├── examples
│   ├── libritts
│   │   ├── cosyvoice
│   │   │   ├── tools
│   │   │   ├── cosyvoice
│   │   │   ├── tts_text.json
│   │   │   ├── path.sh
│   │   │   ├── conf
│   │   │   │   └── ds_stage2.json
│   │   │   ├── local
│   │   │   │   ├── prepare_data.py
│   │   │   │   ├── prepare_reject_sample.py
│   │   │   │   └── download_and_untar.sh
│   │   │   └── run.sh
│   │   └── cosyvoice2
│   │       ├── tools
│   │       ├── local
│   │       ├── cosyvoice
│   │       ├── path.sh
│   │       ├── tts_text.json
│   │       ├── conf
│   │       │   ├── ds_stage2.json
│   │       │   └── cosyvoice2.yaml
│   │       ├── run.sh
│   │       └── run_dpo.sh
│   └── magicdata-read
│       └── cosyvoice
│           ├── tools
│           ├── cosyvoice
│           ├── conf
│           ├── path.sh
│           ├── tts_text.json
│           ├── local
│           │   ├── prepare_data.py
│           │   └── download_and_untar.sh
│           └── run.sh
├── runtime
│   ├── triton_trtllm
│   │   ├── model_repo
│   │   │   ├── tensorrt_llm
│   │   │   │   └── 1
│   │   │   │       └── .gitkeep
│   │   │   ├── audio_tokenizer
│   │   │   │   ├── 1
│   │   │   │   │   └── model.py
│   │   │   │   └── config.pbtxt
│   │   │   ├── token2wav
│   │   │   │   └── config.pbtxt
│   │   │   └── cosyvoice2
│   │   │       └── config.pbtxt
│   │   ├── requirements.txt
│   │   ├── Dockerfile.server
│   │   ├── docker-compose.yml
│   │   ├── scripts
│   │   │   ├── fill_template.py
│   │   │   └── test_llm.py
│   │   ├── README.md
│   │   ├── run.sh
│   │   └── client_http.py
│   └── python
│       ├── Dockerfile
│       ├── grpc
│       │   ├── cosyvoice.proto
│       │   ├── server.py
│       │   └── client.py
│       └── fastapi
│           ├── client.py
│           └── server.py
├── asset
│   ├── dingding.png
│   ├── zero_shot_prompt.wav
│   └── cross_lingual_prompt.wav
├── .gitmodules
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── feature_request.md
│   │   └── bug_report.md
│   └── workflows
│       ├── stale-issues.yml
│       └── lint.yml
├── FAQ.md
├── .gitignore
├── vllm_example.py
├── requirements.txt
├── docker
│   └── Dockerfile
├── tools
│   ├── extract_speech_token.py
│   ├── extract_embedding.py
│   └── make_parquet_list.py
└── CODE_OF_CONDUCT.md

/cosyvoice/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/cosyvoice/cli/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/cosyvoice/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/cosyvoice/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/cosyvoice/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/tools:
--------------------------------------------------------------------------------
1 | ../../../tools
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/tools:
--------------------------------------------------------------------------------
1 | ../../../tools
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/local:
--------------------------------------------------------------------------------
1 | ../cosyvoice/local
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/tools:
--------------------------------------------------------------------------------
1 | ../../../tools
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/cosyvoice:
--------------------------------------------------------------------------------
1 | ../../../cosyvoice
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/cosyvoice:
--------------------------------------------------------------------------------
1 | ../../../cosyvoice
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/path.sh:
--------------------------------------------------------------------------------
1 | ../cosyvoice/path.sh
--------------------------------------------------------------------------------
/runtime/triton_trtllm/model_repo/tensorrt_llm/1/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/cosyvoice:
--------------------------------------------------------------------------------
1 | ../../../cosyvoice
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/tts_text.json:
--------------------------------------------------------------------------------
1 | ../cosyvoice/tts_text.json
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/conf:
--------------------------------------------------------------------------------
1 | ../../libritts/cosyvoice/conf
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/path.sh:
--------------------------------------------------------------------------------
1 | ../../libritts/cosyvoice/path.sh
--------------------------------------------------------------------------------
/asset/dingding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devops/CosyVoice/main/asset/dingding.png
--------------------------------------------------------------------------------
/asset/zero_shot_prompt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devops/CosyVoice/main/asset/zero_shot_prompt.wav
--------------------------------------------------------------------------------
/asset/cross_lingual_prompt.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/devops/CosyVoice/main/asset/cross_lingual_prompt.wav
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/Matcha-TTS"]
2 | 	path = third_party/Matcha-TTS
3 | 	url = https://github.com/shivammehta25/Matcha-TTS.git
4 | 
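# Matcha-TTS is vendored through this submodule; after a plain clone it is an
# empty directory until fetched with the same command FAQ.md below recommends:
#   git submodule update --init --recursive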
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/tts_text.json:
--------------------------------------------------------------------------------
1 | {
2 |   "1089_134686_000002_000000": [
3 |     "hello, my name is Jack. What is your name?"
4 |   ]
5 | }
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/path.sh:
--------------------------------------------------------------------------------
1 | # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
2 | export PYTHONIOENCODING=UTF-8
3 | export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
4 | 
--------------------------------------------------------------------------------
/runtime/triton_trtllm/requirements.txt:
--------------------------------------------------------------------------------
1 | hyperpyyaml
2 | s3tokenizer
3 | onnxruntime-gpu
4 | omegaconf
5 | conformer
6 | hydra-core
7 | lightning
8 | gdown
9 | wget
10 | librosa
11 | pyworld
12 | openai-whisper
13 | tritonclient
14 | modelscope
15 | 
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/tts_text.json:
--------------------------------------------------------------------------------
1 | {
2 |   "38_5718_20170915093303": [
3 |     "我想这出最好歌曲把歌词发到网上请别人帮我作曲急急",
4 |     "叫他明天早上差五分儿九点去机场"
5 |   ],
6 |   "38_5721_20170915091235": [
7 |     "变温室调到零下两度档",
8 |     "交谈中请勿轻信汇款信息陌生电话请勿使用外挂软件"
9 |   ],
10 |   "38_5733_20170915130323": [
11 |     "这是老鹰乐队的一首经典歌曲",
12 |     "我急用这段音乐我自己找到一段但是有现场杂音"
13 |   ],
14 |   "38_5836_20170916221414": [
15 |     "给我播一个陶喆的专辑",
16 |     "这套餐好贵呀我发这么多短信贵死了"
17 |   ]
18 | }
--------------------------------------------------------------------------------
/runtime/triton_trtllm/Dockerfile.server:
--------------------------------------------------------------------------------
1 | FROM nvcr.io/nvidia/tritonserver:25.06-trtllm-python-py3
2 | LABEL maintainer="zhangyuekai@foxmail.com"
3 | 
4 | RUN apt-get update && apt-get install -y cmake
5 | RUN git clone https://github.com/pytorch/audio.git && cd audio && git checkout c670ad8 && PATH=/usr/local/cuda/bin:$PATH python3 setup.py develop
6 | COPY ./requirements.txt /workspace/requirements.txt
7 | RUN pip install -r /workspace/requirements.txt
8 | WORKDIR /workspace
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 | 
8 | ---
9 | 
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 | 
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 | 
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 | 
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 | 
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | ## ModuleNotFoundError: No module named 'matcha'
2 | 
3 | Matcha-TTS is a third_party module. Please check the `third_party` directory.
If there is no `Matcha-TTS` there, execute `git submodule update --init --recursive`.
4 | 
5 | Run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in a Python script.
6 | 
7 | ## cannot find resource.zip or cannot unzip resource.zip
8 | 
9 | Please make sure you have git-lfs installed. Execute
10 | 
11 | ```sh
12 | git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
13 | cd pretrained_models/CosyVoice-ttsfrd/
14 | unzip resource.zip -d .
15 | pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
16 | ```
17 | 
--------------------------------------------------------------------------------
/runtime/triton_trtllm/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 |   tts:
3 |     image: soar97/triton-cosyvoice:25.06
4 |     shm_size: '1gb'
5 |     ports:
6 |       - "8000:8000"
7 |       - "8001:8001"
8 |       - "8002:8002"
9 |     environment:
10 |       - PYTHONIOENCODING=utf-8
11 |       - MODEL_ID=${MODEL_ID}
12 |     deploy:
13 |       resources:
14 |         reservations:
15 |           devices:
16 |             - driver: nvidia
17 |               device_ids: ['0']
18 |               capabilities: [gpu]
19 |     command: >
20 |       /bin/bash -c "pip install modelscope && cd /workspace && git clone https://github.com/FunAudioLLM/CosyVoice.git && cd CosyVoice && git submodule update --init --recursive && cd runtime/triton_trtllm && bash run.sh 0 3"
--------------------------------------------------------------------------------
/runtime/python/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime
2 | ENV DEBIAN_FRONTEND=noninteractive
3 | 
4 | WORKDIR /opt/CosyVoice
5 | 
6 | RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list
7 | RUN apt-get update -y
8 | RUN apt-get -y install git unzip git-lfs g++
9 | RUN git lfs install
10 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
11 | # here we use python==3.10 because we cannot find an image that has both python3.8 and torch2.0.1-cu118 installed
12 | RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
13 | RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=.
cosyvoice.proto -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Visual Studio Code files 7 | .vscode 8 | .vs 9 | 10 | # PyCharm files 11 | .idea 12 | 13 | # Eclipse Project settings 14 | *.*project 15 | .settings 16 | 17 | # Sublime Text settings 18 | *.sublime-workspace 19 | *.sublime-project 20 | 21 | # Editor temporaries 22 | *.swn 23 | *.swo 24 | *.swp 25 | *.swm 26 | *~ 27 | 28 | # IPython notebook checkpoints 29 | .ipynb_checkpoints 30 | 31 | # macOS dir files 32 | .DS_Store 33 | 34 | exp 35 | data 36 | raw_wav 37 | tensorboard 38 | **/*build* 39 | 40 | # Clangd files 41 | .cache 42 | compile_commands.json 43 | 44 | # train/inference files 45 | *.wav 46 | *.m4a 47 | *.aac 48 | *.pt 49 | pretrained_models/* 50 | *_pb2_grpc.py 51 | *_pb2.py 52 | *.tar -------------------------------------------------------------------------------- /.github/workflows/stale-issues.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v5 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 14 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been open for 30 days with no activity." 19 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /runtime/python/grpc/cosyvoice.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package cosyvoice; 4 | option go_package = "protos/"; 5 | 6 | service CosyVoice{ 7 | rpc Inference(Request) returns (stream Response) {} 8 | } 9 | 10 | message Request{ 11 | oneof RequestPayload { 12 | sftRequest sft_request = 1; 13 | zeroshotRequest zero_shot_request = 2; 14 | crosslingualRequest cross_lingual_request = 3; 15 | instructRequest instruct_request = 4; 16 | } 17 | } 18 | 19 | message sftRequest{ 20 | string spk_id = 1; 21 | string tts_text = 2; 22 | } 23 | 24 | message zeroshotRequest{ 25 | string tts_text = 1; 26 | string prompt_text = 2; 27 | bytes prompt_audio = 3; 28 | } 29 | 30 | message crosslingualRequest{ 31 | string tts_text = 1; 32 | bytes prompt_audio = 2; 33 | } 34 | 35 | message instructRequest{ 36 | string tts_text = 1; 37 | string spk_id = 2; 38 | string instruct_text = 3; 39 | } 40 | 41 | message Response{ 42 | bytes tts_audio = 1; 43 | } -------------------------------------------------------------------------------- /vllm_example.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('third_party/Matcha-TTS') 3 | from vllm import ModelRegistry 4 | from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM 5 | ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM) 6 | 7 | from cosyvoice.cli.cosyvoice import CosyVoice2 8 | from cosyvoice.utils.file_utils import load_wav 9 | from cosyvoice.utils.common import 
set_all_random_seed 10 | from tqdm import tqdm 11 | 12 | 13 | def main(): 14 | cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True) 15 | prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) 16 | for i in tqdm(range(100)): 17 | set_all_random_seed(i) 18 | for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): 19 | continue 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/conf/ds_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 5, 6 | "fp16": { 7 | "enabled": false, 8 | "auto_cast": false, 9 | "loss_scale": 0, 10 | "initial_scale_power": 16, 11 | "loss_scale_window": 256, 12 | "hysteresis": 2, 13 | "consecutive_hysteresis": false, 14 | "min_loss_scale": 1 15 | }, 16 | "bf16": { 17 | "enabled": false 18 | }, 19 | "zero_force_ds_cpu_optimizer": false, 20 | "zero_optimization": { 21 | "stage": 2, 22 | "offload_optimizer": { 23 | "device": "none", 24 | "pin_memory": true 25 | }, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "overlap_comm": false, 29 | "reduce_scatter": true, 30 | "reduce_bucket_size": 5e8, 31 | "contiguous_gradients" : true 32 | }, 33 | "optimizer": { 34 | "type": "AdamW", 35 | "params": { 36 | "lr": 0.001, 37 | "weight_decay": 0.0001, 38 | "torch_adam": true, 39 | "adam_w_mode": true 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/conf/ds_stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 1, 3 | "gradient_accumulation_steps": 1, 4 | "steps_per_print": 100, 5 | "gradient_clipping": 5, 6 | "fp16": { 7 | "enabled": false, 8 | "auto_cast": false, 9 | "loss_scale": 0, 10 | "initial_scale_power": 16, 11 | "loss_scale_window": 256, 12 | "hysteresis": 2, 13 | "consecutive_hysteresis": false, 14 | "min_loss_scale": 1 15 | }, 16 | 
"bf16": { 17 | "enabled": false 18 | }, 19 | "zero_force_ds_cpu_optimizer": false, 20 | "zero_optimization": { 21 | "stage": 2, 22 | "offload_optimizer": { 23 | "device": "none", 24 | "pin_memory": true 25 | }, 26 | "allgather_partitions": true, 27 | "allgather_bucket_size": 5e8, 28 | "overlap_comm": false, 29 | "reduce_scatter": true, 30 | "reduce_bucket_size": 5e8, 31 | "contiguous_gradients" : true 32 | }, 33 | "optimizer": { 34 | "type": "AdamW", 35 | "params": { 36 | "lr": 0.001, 37 | "weight_decay": 0.0001, 38 | "torch_adam": true, 39 | "adam_w_mode": true 40 | } 41 | } 42 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu121 2 | --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684 3 | conformer==0.3.2 4 | deepspeed==0.15.1; sys_platform == 'linux' 5 | diffusers==0.29.0 6 | fastapi==0.115.6 7 | fastapi-cli==0.0.4 8 | gdown==5.1.0 9 | gradio==5.4.0 10 | grpcio==1.57.0 11 | grpcio-tools==1.57.0 12 | hydra-core==1.3.2 13 | HyperPyYAML==1.2.2 14 | inflect==7.3.1 15 | librosa==0.10.2 16 | lightning==2.2.4 17 | matplotlib==3.7.5 18 | modelscope==1.20.0 19 | networkx==3.1 20 | omegaconf==2.3.0 21 | onnx==1.16.0 22 | onnxruntime-gpu==1.18.0; sys_platform == 'linux' 23 | onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32' 24 | openai-whisper==20231117 25 | protobuf==4.25 26 | pyarrow==18.1.0 27 | pydantic==2.7.0 28 | pyworld==0.3.4 29 | rich==13.7.1 30 | soundfile==0.12.1 31 | tensorboard==2.14.0 32 | tensorrt-cu12==10.0.1; sys_platform == 'linux' 33 | tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux' 34 | tensorrt-cu12-libs==10.0.1; sys_platform == 'linux' 35 | torch==2.3.1 36 | torchaudio==2.3.1 37 | transformers==4.51.3 38 | uvicorn==0.30.0 39 | wetext==0.0.4 40 | wget==3.2 41 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/audio_tokenizer/config.pbtxt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | name: "audio_tokenizer" 16 | backend: "python" 17 | max_batch_size: ${triton_max_batch_size} 18 | dynamic_batching { 19 | max_queue_delay_microseconds: ${max_queue_delay_microseconds} 20 | } 21 | parameters [ 22 | { 23 | key: "model_dir", 24 | value: {string_value:"${model_dir}"} 25 | } 26 | ] 27 | 28 | input [ 29 | { 30 | name: "reference_wav" 31 | data_type: TYPE_FP32 32 | dims: [-1] 33 | }, 34 | { 35 | name: "reference_wav_len" 36 | data_type: TYPE_INT32 37 | dims: [1] 38 | } 39 | ] 40 | output [ 41 | { 42 | name: "prompt_speech_tokens" 43 | data_type: TYPE_INT32 44 | dims: [-1] 45 | } 46 | ] 47 | 48 | instance_group [ 49 | { 50 | count: 1 51 | kind: KIND_CPU 52 | } 53 | ] -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/token2wav/config.pbtxt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: "token2wav" 16 | backend: "python" 17 | max_batch_size: ${triton_max_batch_size} 18 | dynamic_batching { 19 | max_queue_delay_microseconds: ${max_queue_delay_microseconds} 20 | } 21 | parameters [ 22 | { 23 | key: "model_dir", 24 | value: {string_value:"${model_dir}"} 25 | } 26 | ] 27 | 28 | input [ 29 | { 30 | name: "target_speech_tokens" 31 | data_type: TYPE_INT32 32 | dims: [-1] 33 | }, 34 | { 35 | name: "prompt_speech_tokens" 36 | data_type: TYPE_INT32 37 | dims: [-1] 38 | }, 39 | { 40 | name: "prompt_speech_feat" 41 | data_type: TYPE_FP16 42 | dims: [-1, 80] 43 | }, 44 | { 45 | name: "prompt_spk_embedding" 46 | data_type: TYPE_FP16 47 | dims: [-1] 48 | } 49 | ] 50 | output [ 51 | { 52 | name: "waveform" 53 | data_type: TYPE_FP32 54 | dims: [ -1 ] 55 | } 56 | ] 57 | 58 | instance_group [ 59 | { 60 | count: 1 61 | kind: KIND_CPU 62 | } 63 | ] -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/cosyvoice2/config.pbtxt: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
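# ${decoupled_mode} below toggles Triton's decoupled transaction policy: when true,
# this model may return several waveform chunks per request (streaming) instead of a
# single response. A sketch of filling the template with scripts/fill_template.py
# (all values illustrative):
#   python3 scripts/fill_template.py -i model_repo/cosyvoice2/config.pbtxt decoupled_mode:True,triton_max_batch_size:16,max_queue_delay_microseconds:0,bls_instance_num:4,llm_tokenizer_dir:/models/llm,model_dir:/models/cosyvoice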
14 | 15 | name: "cosyvoice2" 16 | backend: "python" 17 | max_batch_size: ${triton_max_batch_size} 18 | dynamic_batching { 19 | max_queue_delay_microseconds: ${max_queue_delay_microseconds} 20 | } 21 | model_transaction_policy { 22 | decoupled: ${decoupled_mode} 23 | } 24 | parameters [ 25 | { 26 | key: "llm_tokenizer_dir", 27 | value: {string_value:"${llm_tokenizer_dir}"} 28 | }, 29 | { 30 | key: "model_dir", 31 | value: {string_value:"${model_dir}"} 32 | } 33 | ] 34 | 35 | input [ 36 | { 37 | name: "reference_wav" 38 | data_type: TYPE_FP32 39 | dims: [-1] 40 | }, 41 | { 42 | name: "reference_wav_len" 43 | data_type: TYPE_INT32 44 | dims: [1] 45 | }, 46 | { 47 | name: "reference_text" 48 | data_type: TYPE_STRING 49 | dims: [1] 50 | }, 51 | { 52 | name: "target_text" 53 | data_type: TYPE_STRING 54 | dims: [1] 55 | } 56 | ] 57 | output [ 58 | { 59 | name: "waveform" 60 | data_type: TYPE_FP32 61 | dims: [ -1 ] 62 | } 63 | ] 64 | 65 | instance_group [ 66 | { 67 | count: ${bls_instance_num} 68 | kind: KIND_CPU 69 | } 70 | ] -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/local/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from tqdm import tqdm 5 | 6 | 7 | logger = logging.getLogger() 8 | 9 | 10 | def main(): 11 | utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {} 12 | with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f: 13 | lines = f.readlines()[1:] 14 | lines = [l.split('\t') for l in lines] 15 | for wav, spk, content in tqdm(lines): 16 | wav, spk, content = wav.strip(), spk.strip(), content.strip() 17 | content = content.replace('[FIL]', '') 18 | content = content.replace('[SPK]', '') 19 | wav = os.path.join(args.src_dir, spk, wav) 20 | if not os.path.exists(wav): 21 | continue 22 | utt = os.path.basename(wav).replace('.wav', '') 23 | utt2wav[utt] = wav 24 | utt2text[utt] = content 25 | utt2spk[utt] = spk 26 | if spk not in spk2utt: 27 | spk2utt[spk] = [] 28 | spk2utt[spk].append(utt) 29 | 30 | with open('{}/wav.scp'.format(args.des_dir), 'w') as f: 31 | for k, v in utt2wav.items(): 32 | f.write('{} {}\n'.format(k, v)) 33 | with open('{}/text'.format(args.des_dir), 'w') as f: 34 | for k, v in utt2text.items(): 35 | f.write('{} {}\n'.format(k, v)) 36 | with open('{}/utt2spk'.format(args.des_dir), 'w') as f: 37 | for k, v in utt2spk.items(): 38 | f.write('{} {}\n'.format(k, v)) 39 | with open('{}/spk2utt'.format(args.des_dir), 'w') as f: 40 | for k, v in spk2utt.items(): 41 | f.write('{} {}\n'.format(k, ' '.join(v))) 42 | return 43 | 44 | 45 | if __name__ == "__main__": 46 | parser = argparse.ArgumentParser() 47 | parser.add_argument('--src_dir', 48 | type=str) 49 | parser.add_argument('--des_dir', 50 | type=str) 51 | args = parser.parse_args() 52 | main() 53 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/local/prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import glob 4 | import os 5 | from tqdm import tqdm 6 | 7 | 8 | logger = logging.getLogger() 9 | 10 | 11 | def main(): 12 | wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir))) 13 | 14 | utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {} 15 | for wav in tqdm(wavs): 16 | txt = wav.replace('.wav', '.normalized.txt') 17 | if not os.path.exists(txt): 18 | logger.warning('{} do not 
exist'.format(txt))
19 |             continue
20 |         with open(txt) as f:
21 |             content = ''.join(l.replace('\n', '') for l in f.readlines())
22 |         utt = os.path.basename(wav).replace('.wav', '')
23 |         spk = utt.split('_')[0]
24 |         utt2wav[utt] = wav
25 |         utt2text[utt] = content
26 |         utt2spk[utt] = spk
27 |         if spk not in spk2utt:
28 |             spk2utt[spk] = []
29 |         spk2utt[spk].append(utt)
30 | 
31 |     with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
32 |         for k, v in utt2wav.items():
33 |             f.write('{} {}\n'.format(k, v))
34 |     with open('{}/text'.format(args.des_dir), 'w') as f:
35 |         for k, v in utt2text.items():
36 |             f.write('{} {}\n'.format(k, v))
37 |     with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
38 |         for k, v in utt2spk.items():
39 |             f.write('{} {}\n'.format(k, v))
40 |     with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
41 |         for k, v in spk2utt.items():
42 |             f.write('{} {}\n'.format(k, ' '.join(v)))
43 |     return
44 | 
45 | 
46 | if __name__ == "__main__":
47 |     parser = argparse.ArgumentParser()
48 |     parser.add_argument('--src_dir',
49 |                         type=str)
50 |     parser.add_argument('--des_dir',
51 |                         type=str)
52 |     parser.add_argument('--ref_model',
53 |                         type=str)
54 |     args = parser.parse_args()
55 |     main()
56 | 
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/local/prepare_reject_sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | from tqdm import tqdm
5 | import torch
6 | import torchaudio
7 | from cosyvoice.cli.cosyvoice import CosyVoice2
8 | from cosyvoice.utils.file_utils import load_wav
9 | 
10 | 
11 | logger = logging.getLogger()
12 | 
13 | 
14 | def main():
15 |     cosyvoice = CosyVoice2(args.ref_model)
16 | 
17 |     utt2wav, utt2text = {}, {}
18 |     with open('{}/wav.scp'.format(args.src_dir)) as f:
19 |         for l in f:
20 |             l = l.split('\n')[0].split()
21 |             utt2wav[l[0]] = l[1]
22 |     with open('{}/text'.format(args.src_dir)) as f:
23 |         for l in f:
24 |             l = l.split('\n')[0].split()
25 |             utt2text[l[0]] = ' '.join(l[1:])
26 | 
27 |     os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
28 |     with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
29 |         for utt, wav in tqdm(utt2wav.items()):
30 |             prompt_speech_16k = load_wav(wav, 16000)
31 |             if prompt_speech_16k.shape[1] >= 30 * 16000:
32 |                 continue
33 |             speech_list = []
34 |             for _, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
35 |                 speech_list.append(j['tts_speech'])
36 |             negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
37 |             torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
38 |             f.write('{} {}\n'.format(utt, negative_wav))
39 | 
40 | 
41 | if __name__ == "__main__":
42 |     parser = argparse.ArgumentParser()
43 |     parser.add_argument('--src_dir',
44 |                         type=str)
45 |     parser.add_argument('--des_dir',
46 |                         type=str)
47 |     parser.add_argument('--ref_model',
48 |                         type=str)
49 |     args = parser.parse_args()
50 |     main()
51 | 
--------------------------------------------------------------------------------
/runtime/triton_trtllm/scripts/fill_template.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | from argparse import ArgumentParser
3 | from string import Template
4 | 
5 | 
6 | def split(string, delimiter):
7 |     """Split a string using delimiter. Supports escaping.
8 | 
9 |     Args:
10 |         string (str): The string to split.
11 |         delimiter (str): The delimiter to split the string with.
12 | 
13 |     Returns:
14 |         list: A list of strings.
15 |     """
16 |     result = []
17 |     current = ""
18 |     escape = False
19 |     for char in string:
20 |         if escape:
21 |             current += char
22 |             escape = False
23 |         elif char == delimiter:
24 |             result.append(current)
25 |             current = ""
26 |         elif char == "\\":
27 |             escape = True
28 |         else:
29 |             current += char
30 |     result.append(current)
31 |     return result
32 | 
33 | 
34 | def main(file_path, substitutions, in_place):
35 |     with open(file_path) as f:
36 |         pbtxt = Template(f.read())
37 | 
38 |     sub_dict = {
39 |         "max_queue_size": 0,
40 |         'max_queue_delay_microseconds': 0,
41 |     }
42 |     for sub in split(substitutions, ","):
43 |         key, value = split(sub, ":")
44 |         sub_dict[key] = value
45 | 
46 |         assert key in pbtxt.template, f"key '{key}' does not exist in the file {file_path}."
47 | 
48 |     pbtxt = pbtxt.safe_substitute(sub_dict)
49 | 
50 |     if in_place:
51 |         with open(file_path, "w") as f:
52 |             f.write(pbtxt)
53 |     else:
54 |         print(pbtxt)
55 | 
56 | 
57 | if __name__ == "__main__":
58 |     parser = ArgumentParser()
59 |     parser.add_argument("file_path", help="path of the .pbtxt to modify")
60 |     parser.add_argument(
61 |         "substitutions",
62 |         help="substitutions to perform, in the format variable_name_1:value_1,variable_name_2:value_2..."
63 |     )
64 |     parser.add_argument("--in_place",
65 |                         "-i",
66 |                         action="store_true",
67 |                         help="do the operation in-place")
68 |     args = parser.parse_args()
69 |     main(**vars(args))
70 | 
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
2 | 
3 | ARG VENV_NAME="cosyvoice"
4 | ENV VENV=$VENV_NAME
5 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
6 | 
7 | ENV DEBIAN_FRONTEND=noninteractive
8 | ENV PYTHONUNBUFFERED=1
9 | SHELL ["/bin/bash", "--login", "-c"]
10 | 
11 | RUN apt-get update -y --fix-missing
12 | RUN apt-get install -y git build-essential curl wget ffmpeg unzip git git-lfs sox libsox-dev && \
13 |     apt-get clean && \
14 |     git lfs install
15 | 
16 | # ==================================================================
17 | # conda install and conda forge channel as default
18 | # ------------------------------------------------------------------
19 | # Install miniforge
20 | RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \
21 |     /bin/bash ~/miniforge.sh -b -p /opt/conda && \
22 |     rm ~/miniforge.sh && \
23 |     ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
24 |     echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \
25 |     echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
26 |     echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \
27 |     echo "conda activate ${VENV}" >> $HOME/.bashrc
28 | 
29 | ENV PATH /opt/conda/bin:$PATH
30 | 
31 | RUN conda config --add channels conda-forge && \
32 |     conda config --set channel_priority strict
33 | # ------------------------------------------------------------------
34 | # ~conda
35 | # ==================================================================
36 | 
37 | RUN conda create -y -n ${VENV} python=3.10
38 | ENV CONDA_DEFAULT_ENV=${VENV}
39 | ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH
40 | 
41 | WORKDIR /workspace
42 | 
43 | ENV
PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_party/Matcha-TTS" 44 | 45 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git 46 | 47 | RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5 48 | RUN conda activate ${VENV} && cd CosyVoice && \ 49 | pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com 50 | 51 | WORKDIR /workspace/CosyVoice 52 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import torch.nn as nn 16 | try: 17 | from torch.nn.utils.parametrizations import weight_norm 18 | except ImportError: 19 | from torch.nn.utils import weight_norm 20 | 21 | 22 | class ConvRNNF0Predictor(nn.Module): 23 | def __init__(self, 24 | num_class: int = 1, 25 | in_channels: int = 80, 26 | cond_channels: int = 512 27 | ): 28 | super().__init__() 29 | 30 | self.num_class = num_class 31 | self.condnet = nn.Sequential( 32 | weight_norm( 33 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 34 | ), 35 | nn.ELU(), 36 | weight_norm( 37 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 38 | ), 39 | nn.ELU(), 40 | weight_norm( 41 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 42 | ), 43 | nn.ELU(), 44 | weight_norm( 45 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 46 | ), 47 | nn.ELU(), 48 | weight_norm( 49 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 50 | ), 51 | nn.ELU(), 52 | ) 53 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 54 | 55 | def forward(self, x: torch.Tensor) -> torch.Tensor: 56 | x = self.condnet(x) 57 | x = x.transpose(1, 2) 58 | return torch.abs(self.classifier(x).squeeze(-1)) 59 | -------------------------------------------------------------------------------- /cosyvoice/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Tuple 4 | 5 | 6 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): 7 | loss = 0 8 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 9 | m_DG = torch.median((dr - dg)) 10 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) 11 | loss += tau - F.relu(tau - L_rel) 12 | return loss 13 | 14 | 15 | def mel_loss(real_speech, generated_speech, mel_transforms): 16 | loss = 0 17 | for transform in mel_transforms: 18 | mel_r = transform(real_speech) 19 | mel_g = transform(generated_speech) 20 | loss += F.l1_loss(mel_g, mel_r) 21 | return loss 22 | 23 | 24 | class DPOLoss(torch.nn.Module): 25 | """ 26 | DPO Loss 27 | """ 28 | 29 | def __init__(self, 
beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None: 30 | super().__init__() 31 | self.beta = beta 32 | self.label_smoothing = label_smoothing 33 | self.ipo = ipo 34 | 35 | def forward( 36 | self, 37 | policy_chosen_logps: torch.Tensor, 38 | policy_rejected_logps: torch.Tensor, 39 | reference_chosen_logps: torch.Tensor, 40 | reference_rejected_logps: torch.Tensor, 41 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 42 | pi_logratios = policy_chosen_logps - policy_rejected_logps 43 | ref_logratios = reference_chosen_logps - reference_rejected_logps 44 | logits = pi_logratios - ref_logratios 45 | if self.ipo: 46 | losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf 47 | else: 48 | # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf) 49 | losses = ( 50 | -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) 51 | - F.logsigmoid(-self.beta * logits) * self.label_smoothing 52 | ) 53 | loss = losses.mean() 54 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() 55 | rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() 56 | 57 | return loss, chosen_rewards, rejected_rewards 58 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | pull_request: 5 | push: 6 | 7 | jobs: 8 | quick-checks: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Fetch CosyVoice 12 | uses: actions/checkout@v1 13 | - name: Checkout PR tip 14 | run: | 15 | set -eux 16 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then 17 | # We are on a PR, so actions/checkout leaves us on a merge commit. 18 | # Check out the actual tip of the branch. 19 | git checkout ${{ github.event.pull_request.head.sha }} 20 | fi 21 | echo ::set-output name=commit_sha::$(git rev-parse HEAD) 22 | id: get_pr_tip 23 | - name: Ensure no tabs 24 | run: | 25 | (! git grep -I -l $'\t' -- . ':(exclude)*.txt' ':(exclude)*.svg' ':(exclude)**Makefile' ':(exclude)**/contrib/**' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have tabs; please convert them to spaces"; false)) 26 | - name: Ensure no trailing whitespace 27 | run: | 28 | (! git grep -I -n $' $' -- . ':(exclude)*.txt' ':(exclude)third_party' ':(exclude).gitattributes' ':(exclude).gitmodules' || (echo "The above files have trailing whitespace; please remove them"; false)) 29 | 30 | flake8-py3: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - name: Setup Python 34 | uses: actions/setup-python@v1 35 | with: 36 | python-version: 3.9 37 | architecture: x64 38 | - name: Fetch CosyVoice 39 | uses: actions/checkout@v1 40 | - name: Checkout PR tip 41 | run: | 42 | set -eux 43 | if [[ "${{ github.event_name }}" == "pull_request" ]]; then 44 | # We are on a PR, so actions/checkout leaves us on a merge commit. 45 | # Check out the actual tip of the branch. 
46 |             git checkout ${{ github.event.pull_request.head.sha }}
47 |           fi
48 |           echo ::set-output name=commit_sha::$(git rev-parse HEAD)
49 |         id: get_pr_tip
50 |       - name: Run flake8
51 |         run: |
52 |           set -eux
53 |           pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
54 |           flake8 --version
55 |           flake8 --max-line-length 180 --ignore B006,B008,B905,C408,E402,E731,E741,W503,W504,F401,F403,F405,F841 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
56 |           if [ $? != 0 ]; then exit 1; fi
--------------------------------------------------------------------------------
/examples/magicdata-read/cosyvoice/local/download_and_untar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 | # Apache 2.0
5 | 
6 | remove_archive=false
7 | 
8 | if [ "$1" == --remove-archive ]; then
9 |   remove_archive=true
10 |   shift
11 | fi
12 | 
13 | if [ $# -ne 3 ]; then
14 |   echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
15 |   echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
16 |   echo "With --remove-archive it will remove the archive after successfully un-tarring it."
17 |   echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
18 |   echo "          train-clean-100, train-clean-360, train-other-500."
19 |   exit 1
20 | fi
21 | 
22 | data=$1
23 | url=$2
24 | part=$3
25 | 
26 | if [ ! -d "$data" ]; then
27 |   echo "$0: no such directory $data"
28 |   exit 1
29 | fi
30 | 
31 | part_ok=false
32 | list="dev_set test_set train_set"
33 | for x in $list; do
34 |   if [ "$part" == $x ]; then part_ok=true; fi
35 | done
36 | if ! $part_ok; then
37 |   echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
38 |   exit 1
39 | fi
40 | 
41 | if [ -z "$url" ]; then
42 |   echo "$0: empty URL base."
43 |   exit 1
44 | fi
45 | 
46 | if [ -f $data/.$part.complete ]; then
47 |   echo "$0: data part $part was already successfully extracted, nothing to do."
48 |   exit 0
49 | fi
50 | 
51 | 
52 | # sizes of the archive files in bytes. This is some older versions.
53 | sizes_old="1035537823 2201936013 52627842921"
54 | # sizes_new is the archive file sizes of the final release. Some of these sizes are of
55 | # things we probably won't download.
56 | sizes_new="3886385"
57 | 
58 | if [ -f $data/$part.tar.gz ]; then
59 |   size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
60 |   size_ok=false
61 |   for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
62 |   if ! $size_ok; then
63 |     echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
64 |     echo "does not equal the size of one of the archives."
65 |     rm $data/$part.tar.gz
66 |   else
67 |     echo "$data/$part.tar.gz exists and appears to be complete."
68 |   fi
69 | fi
70 | 
71 | if [ ! -f $data/$part.tar.gz ]; then
72 |   if ! which wget >/dev/null; then
73 |     echo "$0: wget is not installed."
74 |     exit 1
75 |   fi
76 |   full_url=$url/$part.tar.gz
77 |   echo "$0: downloading data from $full_url. This may take some time, please be patient."
78 | 
79 |   if ! wget -P $data --no-check-certificate $full_url; then
80 |     echo "$0: error executing wget $full_url"
81 |     exit 1
82 |   fi
83 | fi
84 | 
85 | if !
tar -C $data -xvzf $data/$part.tar.gz; then 86 | echo "$0: error un-tarring archive $data/$part.tar.gz" 87 | exit 1 88 | fi 89 | 90 | touch $data/.$part.complete 91 | 92 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz" 93 | 94 | if $remove_archive; then 95 | echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied." 96 | rm $data/$part.tar.gz 97 | fi 98 | -------------------------------------------------------------------------------- /tools/extract_speech_token.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | from concurrent.futures import ThreadPoolExecutor, as_completed 17 | import logging 18 | import torch 19 | from tqdm import tqdm 20 | import onnxruntime 21 | import numpy as np 22 | import torchaudio 23 | import whisper 24 | 25 | 26 | def single_job(utt): 27 | audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile') 28 | if sample_rate != 16000: 29 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 30 | # Convert audio to mono 31 | if audio.shape[0] > 1: 32 | audio = audio.mean(dim=0, keepdim=True) 33 | if audio.shape[1] / 16000 > 30: 34 | logging.warning('do not support extract speech token for audio longer than 30s') 35 | speech_token = [] 36 | else: 37 | feat = whisper.log_mel_spectrogram(audio, n_mels=128) 38 | speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), 39 | ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() 40 | return utt, speech_token 41 | 42 | 43 | def main(args): 44 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] 45 | utt2speech_token = {} 46 | for future in tqdm(as_completed(all_task)): 47 | utt, speech_token = future.result() 48 | utt2speech_token[utt] = speech_token 49 | torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--dir", type=str) 55 | parser.add_argument("--onnx_path", type=str) 56 | parser.add_argument("--num_thread", type=int, default=8) 57 | args = parser.parse_args() 58 | 59 | utt2wav = {} 60 | with open('{}/wav.scp'.format(args.dir)) as f: 61 | for l in f: 62 | l = l.replace('\n', '').split() 63 | utt2wav[l[0]] = l[1] 64 | 65 | option = onnxruntime.SessionOptions() 66 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 67 | option.intra_op_num_threads = 1 68 | providers = ["CUDAExecutionProvider"] 69 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 70 | executor = ThreadPoolExecutor(max_workers=args.num_thread) 71 | 72 | main(args) 73 | 
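# Usage sketch (paths are illustrative; the speech tokenizer ONNX file ships with
# the pretrained models, and its exact filename depends on the model version you
# downloaded):
#   python tools/extract_speech_token.py --dir data/train \
#       --onnx_path pretrained_models/CosyVoice-300M/speech_tokenizer_v1.onnx --num_thread 8
# Reads {dir}/wav.scp and writes the utt -> token-list mapping to {dir}/utt2speech_token.pt.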
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice/local/download_and_untar.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | # Copyright 2014 Johns Hopkins University (author: Daniel Povey)
4 | # Apache 2.0
5 | 
6 | remove_archive=false
7 | 
8 | if [ "$1" == --remove-archive ]; then
9 |   remove_archive=true
10 |   shift
11 | fi
12 | 
13 | if [ $# -ne 3 ]; then
14 |   echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
15 |   echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
16 |   echo "With --remove-archive it will remove the archive after successfully un-tarring it."
17 |   echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
18 |   echo "          train-clean-100, train-clean-360, train-other-500."
19 |   exit 1
20 | fi
21 | 
22 | data=$1
23 | url=$2
24 | part=$3
25 | 
26 | if [ ! -d "$data" ]; then
27 |   echo "$0: no such directory $data"
28 |   exit 1
29 | fi
30 | 
31 | part_ok=false
32 | list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
33 | for x in $list; do
34 |   if [ "$part" == $x ]; then part_ok=true; fi
35 | done
36 | if ! $part_ok; then
37 |   echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
38 |   exit 1
39 | fi
40 | 
41 | if [ -z "$url" ]; then
42 |   echo "$0: empty URL base."
43 |   exit 1
44 | fi
45 | 
46 | if [ -f $data/LibriTTS/$part/.complete ]; then
47 |   echo "$0: data part $part was already successfully extracted, nothing to do."
48 |   exit 0
49 | fi
50 | 
51 | 
52 | # sizes of the archive files in bytes. This is some older versions.
53 | sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
54 | # sizes_new is the archive file sizes of the final release. Some of these sizes are of
55 | # things we probably won't download.
56 | sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
57 | 
58 | if [ -f $data/$part.tar.gz ]; then
59 |   size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
60 |   size_ok=false
61 |   for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
62 |   if ! $size_ok; then
63 |     echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
64 |     echo "does not equal the size of one of the archives."
65 |     rm $data/$part.tar.gz
66 |   else
67 |     echo "$data/$part.tar.gz exists and appears to be complete."
68 |   fi
69 | fi
70 | 
71 | if [ ! -f $data/$part.tar.gz ]; then
72 |   if ! which wget >/dev/null; then
73 |     echo "$0: wget is not installed."
74 |     exit 1
75 |   fi
76 |   full_url=$url/$part.tar.gz
77 |   echo "$0: downloading data from $full_url. This may take some time, please be patient."
78 | 
79 |   if ! wget -P $data --no-check-certificate $full_url; then
80 |     echo "$0: error executing wget $full_url"
81 |     exit 1
82 |   fi
83 | fi
84 | 
85 | if ! tar -C $data -xvzf $data/$part.tar.gz; then
86 |   echo "$0: error un-tarring archive $data/$part.tar.gz"
87 |   exit 1
88 | fi
89 | 
90 | touch $data/LibriTTS/$part/.complete
91 | 
92 | echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
93 | 
94 | if $remove_archive; then
95 |   echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
96 | rm $data/$part.tar.gz 97 | fi 98 | -------------------------------------------------------------------------------- /tools/extract_embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | from concurrent.futures import ThreadPoolExecutor, as_completed 17 | import onnxruntime 18 | import torch 19 | import torchaudio 20 | import torchaudio.compliance.kaldi as kaldi 21 | from tqdm import tqdm 22 | 23 | 24 | def single_job(utt): 25 | audio, sample_rate = torchaudio.load(utt2wav[utt]) 26 | if sample_rate != 16000: 27 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 28 | feat = kaldi.fbank(audio, 29 | num_mel_bins=80, 30 | dither=0, 31 | sample_frequency=16000) 32 | feat = feat - feat.mean(dim=0, keepdim=True) 33 | embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() 34 | return utt, embedding 35 | 36 | 37 | def main(args): 38 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] 39 | utt2embedding, spk2embedding = {}, {} 40 | for future in tqdm(as_completed(all_task)): 41 | utt, embedding = future.result() 42 | utt2embedding[utt] = embedding 43 | spk = utt2spk[utt] 44 | if spk not in spk2embedding: 45 | spk2embedding[spk] = [] 46 | spk2embedding[spk].append(embedding) 47 | for k, v in spk2embedding.items(): 48 | spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() 49 | torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir)) 50 | torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir)) 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--dir", type=str) 56 | parser.add_argument("--onnx_path", type=str) 57 | parser.add_argument("--num_thread", type=int, default=8) 58 | args = parser.parse_args() 59 | 60 | utt2wav, utt2spk = {}, {} 61 | with open('{}/wav.scp'.format(args.dir)) as f: 62 | for l in f: 63 | l = l.replace('\n', '').split() 64 | utt2wav[l[0]] = l[1] 65 | with open('{}/utt2spk'.format(args.dir)) as f: 66 | for l in f: 67 | l = l.replace('\n', '').split() 68 | utt2spk[l[0]] = l[1] 69 | 70 | option = onnxruntime.SessionOptions() 71 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 72 | option.intra_op_num_threads = 1 73 | providers = ["CPUExecutionProvider"] 74 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 75 | executor = ThreadPoolExecutor(max_workers=args.num_thread) 76 | 77 | main(args) 78 | -------------------------------------------------------------------------------- /cosyvoice/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao 
Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Tuple 15 | import torch.nn as nn 16 | import torch 17 | from torch.nn import functional as F 18 | from cosyvoice.utils.mask import make_pad_mask 19 | 20 | 21 | class InterpolateRegulator(nn.Module): 22 | def __init__( 23 | self, 24 | channels: int, 25 | sampling_ratios: Tuple, 26 | out_channels: int = None, 27 | groups: int = 1, 28 | ): 29 | super().__init__() 30 | self.sampling_ratios = sampling_ratios 31 | out_channels = out_channels or channels 32 | model = nn.ModuleList([]) 33 | if len(sampling_ratios) > 0: 34 | for _ in sampling_ratios: 35 | module = nn.Conv1d(channels, channels, 3, 1, 1) 36 | norm = nn.GroupNorm(groups, channels) 37 | act = nn.Mish() 38 | model.extend([module, norm, act]) 39 | model.append( 40 | nn.Conv1d(channels, out_channels, 1, 1) 41 | ) 42 | self.model = nn.Sequential(*model) 43 | 44 | def forward(self, x, ylens=None): 45 | # x in (B, T, D) 46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') 48 | out = self.model(x).transpose(1, 2).contiguous() 49 | olens = ylens 50 | return out * mask, olens 51 | 52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): 53 | # in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel 54 | # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py 55 | # x in (B, T, D) 56 | if x2.shape[1] > 40: 57 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 58 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, 59 | mode='linear') 60 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 61 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) 62 | else: 63 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') 64 | if x1.shape[1] != 0: 65 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') 66 | x = torch.concat([x1, x2], dim=2) 67 | else: 68 | x = x2 69 | out = self.model(x).transpose(1, 2).contiguous() 70 | return out, mel_len1 + mel_len2 71 | -------------------------------------------------------------------------------- /cosyvoice/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct a Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | ''' 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = Snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | ''' 50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 51 | ''' 52 | Initialization. 53 | INPUT: 54 | - in_features: shape of the input 55 | - alpha: trainable parameter 56 | alpha is initialized to 1 by default, higher values = higher frequency. 57 | alpha will be trained along with the rest of your model. 58 | ''' 59 | super(Snake, self).__init__() 60 | self.in_features = in_features 61 | 62 | # initialize alpha 63 | self.alpha_logscale = alpha_logscale 64 | if self.alpha_logscale: # log scale alphas initialized to zeros 65 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 66 | else: # linear scale alphas initialized to ones 67 | self.alpha = Parameter(torch.ones(in_features) * alpha) 68 | 69 | self.alpha.requires_grad = alpha_trainable 70 | 71 | self.no_div_by_zero = 0.000000001 72 | 73 | def forward(self, x): 74 | ''' 75 | Forward pass of the function. 76 | Applies the function to the input elementwise.
77 | Snake(x) := x + (1/alpha) * sin^2(alpha * x) 78 | ''' 79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 80 | if self.alpha_logscale: 81 | alpha = torch.exp(alpha) 82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/hifigan.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss 6 | from cosyvoice.utils.losses import tpr_loss, mel_loss 7 | 8 | 9 | class HiFiGan(nn.Module): 10 | def __init__(self, generator, discriminator, mel_spec_transform, 11 | multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, 12 | tpr_loss_weight=1.0, tpr_loss_tau=0.04): 13 | super(HiFiGan, self).__init__() 14 | self.generator = generator 15 | self.discriminator = discriminator 16 | self.mel_spec_transform = mel_spec_transform 17 | self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight 18 | self.feat_match_loss_weight = feat_match_loss_weight 19 | self.tpr_loss_weight = tpr_loss_weight 20 | self.tpr_loss_tau = tpr_loss_tau 21 | 22 | def forward( 23 | self, 24 | batch: dict, 25 | device: torch.device, 26 | ) -> Dict[str, Optional[torch.Tensor]]: 27 | if batch['turn'] == 'generator': 28 | return self.forward_generator(batch, device) 29 | else: 30 | return self.forward_discriminator(batch, device) 31 | 32 | def forward_generator(self, batch, device): 33 | real_speech = batch['speech'].to(device) 34 | pitch_feat = batch['pitch_feat'].to(device) 35 | # 1. calculate generator outputs 36 | generated_speech, generated_f0 = self.generator(batch, device) 37 | # 2. calculate discriminator outputs 38 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) 39 | # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional] 40 | loss_gen, _ = generator_loss(y_d_gs) 41 | loss_fm = feature_loss(fmap_rs, fmap_gs) 42 | loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) 43 | if self.tpr_loss_weight != 0: 44 | loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau) 45 | else: 46 | loss_tpr = torch.zeros(1).to(device) 47 | loss_f0 = F.l1_loss(generated_f0, pitch_feat) 48 | loss = loss_gen + self.feat_match_loss_weight * loss_fm + \ 49 | self.multi_mel_spectral_recon_loss_weight * loss_mel + \ 50 | self.tpr_loss_weight * loss_tpr + loss_f0 51 | return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} 52 | 53 | def forward_discriminator(self, batch, device): 54 | real_speech = batch['speech'].to(device) 55 | # 1. calculate generator outputs 56 | with torch.no_grad(): 57 | generated_speech, generated_f0 = self.generator(batch, device) 58 | # 2. calculate discriminator outputs 59 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach()) 60 | # 3.
calculate discriminator losses, tpr losses [Optional] 61 | loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs) 62 | if self.tpr_loss_weight != 0: 63 | loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau) 64 | else: 65 | loss_tpr = torch.zeros(1).to(device) 66 | loss = loss_disc + self.tpr_loss_weight * loss_tpr 67 | return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr} 68 | -------------------------------------------------------------------------------- /cosyvoice/bin/average_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Di Wu) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import argparse 18 | import glob 19 | 20 | import yaml 21 | import torch 22 | 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser(description='average model') 26 | parser.add_argument('--dst_model', required=True, help='path to save the averaged model') 27 | parser.add_argument('--src_path', 28 | required=True, 29 | help='src model path for average') 30 | parser.add_argument('--val_best', 31 | action="store_true", 32 | help='average the checkpoints with the best validation loss') 33 | parser.add_argument('--num', 34 | default=5, 35 | type=int, 36 | help='number of checkpoints to average') 37 | 38 | args = parser.parse_args() 39 | print(args) 40 | return args 41 | 42 | 43 | def main(): 44 | args = get_args() 45 | val_scores = [] 46 | if args.val_best: 47 | yamls = glob.glob('{}/*.yaml'.format(args.src_path)) 48 | yamls = [ 49 | f for f in yamls 50 | if not (os.path.basename(f).startswith('train') 51 | or os.path.basename(f).startswith('init')) 52 | ] 53 | for y in yamls: 54 | with open(y, 'r') as f: 55 | dic_yaml = yaml.load(f, Loader=yaml.BaseLoader) 56 | loss = float(dic_yaml['loss_dict']['loss']) 57 | epoch = int(dic_yaml['epoch']) 58 | step = int(dic_yaml['step']) 59 | tag = dic_yaml['tag'] 60 | val_scores += [[epoch, step, loss, tag]] 61 | sorted_val_scores = sorted(val_scores, 62 | key=lambda x: x[2], 63 | reverse=False) 64 | print("best val (epoch, step, loss, tag) = " + 65 | str(sorted_val_scores[:args.num])) 66 | path_list = [ 67 | args.src_path + '/epoch_{}_whole.pt'.format(score[0]) 68 | for score in sorted_val_scores[:args.num] 69 | ] 70 | print(path_list) 71 | avg = {} 72 | num = args.num 73 | assert num == len(path_list) 74 | for path in path_list: 75 | print('Processing {}'.format(path)) 76 | states = torch.load(path, map_location=torch.device('cpu')) 77 | for k in states.keys(): 78 | if k not in ['step', 'epoch']: 79 | if k not in avg.keys(): 80 | avg[k] = states[k].clone() 81 | else: 82 | avg[k] += states[k] 83 | # average 84 | for k in avg.keys(): 85 | if avg[k] is not None: 86 | # pytorch 1.6 use true_divide instead of /= 87 | avg[k] = torch.true_divide(avg[k], num) 88 | print('Saving to {}'.format(args.dst_model)) 89 | torch.save(avg, args.dst_model) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 |
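A minimal invocation sketch for the checkpoint-averaging script above, mirroring how the example `run.sh` recipes call it (the experiment path is illustrative):

```sh
# Average the 5 checkpoints with the lowest validation loss into llm.pt
python cosyvoice/bin/average_model.py \
  --dst_model exp/cosyvoice/llm/torch_ddp/llm.pt \
  --src_path exp/cosyvoice/llm/torch_ddp \
  --num 5 \
  --val_best
```

This expects the per-epoch `epoch_*_whole.pt` checkpoints and their companion YAML files (carrying `loss_dict`, `epoch`, `step`, and `tag`) to be present in `--src_path`, as written by the training script.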
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at mikelei@mobvoi.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 
67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /cosyvoice/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import torch 16 | 17 | from cosyvoice.transformer.activation import Swish 18 | from cosyvoice.transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from cosyvoice.transformer.embedding import (PositionalEncoding, 27 | RelPositionalEncoding, 28 | WhisperPositionalEncoding, 29 | LearnablePositionalEncoding, 30 | NoPositionalEncoding) 31 | from cosyvoice.transformer.attention import (MultiHeadedAttention, 32 | RelPositionMultiHeadedAttention) 33 | from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding 34 | from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling 35 | from cosyvoice.llm.llm import TransformerLM, Qwen2LM 36 | from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec 37 | from cosyvoice.hifigan.generator import HiFTGenerator 38 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model 39 | 40 | 41 | COSYVOICE_ACTIVATION_CLASSES = { 42 | "hardtanh": torch.nn.Hardtanh, 43 | "tanh": torch.nn.Tanh, 44 | "relu": torch.nn.ReLU, 45 | "selu": torch.nn.SELU, 46 | "swish": getattr(torch.nn, "SiLU", Swish), 47 | "gelu": torch.nn.GELU, 48 | } 49 | 50 | COSYVOICE_SUBSAMPLE_CLASSES = { 51 | "linear": LinearNoSubsampling, 52 | "linear_legacy": LegacyLinearNoSubsampling, 53 | "embed": EmbedinigNoSubsampling, 54 | "conv1d2": Conv1dSubsampling2, 55 | "conv2d": Conv2dSubsampling4, 56 | "conv2d6": Conv2dSubsampling6, 57 | "conv2d8": Conv2dSubsampling8, 58 | 'paraformer_dummy': torch.nn.Identity 59 | } 60 | 61 | COSYVOICE_EMB_CLASSES = { 62 | "embed": PositionalEncoding, 63 | "abs_pos": PositionalEncoding, 64 | "rel_pos": RelPositionalEncoding, 65 | "rel_pos_espnet": EspnetRelPositionalEncoding, 66 | "no_pos": NoPositionalEncoding, 67 | "abs_pos_whisper": WhisperPositionalEncoding, 68 | "embed_learnable_pe": LearnablePositionalEncoding, 69 | } 70 | 71 | COSYVOICE_ATTENTION_CLASSES = { 72 | "selfattn": MultiHeadedAttention, 73 | "rel_selfattn": RelPositionMultiHeadedAttention, 74 | } 75 | 76 | 77 | def get_model_type(configs): 78 | # NOTE CosyVoice2Model inherits CosyVoiceModel 79 | if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], 
MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 80 | return CosyVoiceModel 81 | if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 82 | return CosyVoice2Model 83 | raise TypeError('No valid model type found!') 84 | -------------------------------------------------------------------------------- /cosyvoice/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothed version of the CE loss, some probability mass 33 | is taken from the true label prob (1.0) and divided 34 | among the other labels. 35 | 36 | e.g. 37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of classes 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct a LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model output and data label tensors are flattened to 72 | (batch*seqlen, class) shape, and a mask is applied to the 73 | padding part, which should not contribute to the loss.
74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /runtime/python/fastapi/client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
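Before the implementation below, a hedged usage sketch for this FastAPI client, assuming the companion server from `runtime/python/fastapi/server.py` is already listening; the host, port, and synthesis text are illustrative:

```sh
# Zero-shot voice cloning: reuse the bundled prompt audio and its transcript
python client.py --host 127.0.0.1 --port 50000 \
  --mode zero_shot \
  --tts_text "收到好友从远方寄来的生日礼物,我感到非常开心。" \
  --prompt_text "希望你以后能够做的比我还好呦。" \
  --prompt_wav ../../../asset/zero_shot_prompt.wav \
  --tts_wav demo.wav
```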
14 | import argparse 15 | import logging 16 | import requests 17 | import torch 18 | import torchaudio 19 | import numpy as np 20 | 21 | 22 | def main(): 23 | url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode) 24 | if args.mode == 'sft': 25 | payload = { 26 | 'tts_text': args.tts_text, 27 | 'spk_id': args.spk_id 28 | } 29 | response = requests.request("GET", url, data=payload, stream=True) 30 | elif args.mode == 'zero_shot': 31 | payload = { 32 | 'tts_text': args.tts_text, 33 | 'prompt_text': args.prompt_text 34 | } 35 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] 36 | response = requests.request("GET", url, data=payload, files=files, stream=True) 37 | elif args.mode == 'cross_lingual': 38 | payload = { 39 | 'tts_text': args.tts_text, 40 | } 41 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] 42 | response = requests.request("GET", url, data=payload, files=files, stream=True) 43 | else: 44 | payload = { 45 | 'tts_text': args.tts_text, 46 | 'spk_id': args.spk_id, 47 | 'instruct_text': args.instruct_text 48 | } 49 | response = requests.request("GET", url, data=payload, stream=True) 50 | tts_audio = b'' 51 | for r in response.iter_content(chunk_size=16000): 52 | tts_audio += r 53 | tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) 54 | logging.info('save response to {}'.format(args.tts_wav)) 55 | torchaudio.save(args.tts_wav, tts_speech, target_sr) 56 | logging.info('get response') 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('--host', 62 | type=str, 63 | default='0.0.0.0') 64 | parser.add_argument('--port', 65 | type=int, 66 | default=50000) 67 | parser.add_argument('--mode', 68 | default='sft', 69 | choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], 70 | help='request mode') 71 | parser.add_argument('--tts_text', 72 | type=str, 73 | default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?') 74 | parser.add_argument('--spk_id', 75 | type=str, 76 | default='中文女') 77 | parser.add_argument('--prompt_text', 78 | type=str, 79 | default='希望你以后能够做的比我还好呦。') 80 | parser.add_argument('--prompt_wav', 81 | type=str, 82 | default='../../../asset/zero_shot_prompt.wav') 83 | parser.add_argument('--instruct_text', 84 | type=str, 85 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \ 86 | Fights with fervor for justice, but struggles with impulsiveness.') 87 | parser.add_argument('--tts_wav', 88 | type=str, 89 | default='demo.wav') 90 | args = parser.parse_args() 91 | prompt_sr, target_sr = 16000, 22050 92 | main() 93 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/model_repo/audio_tokenizer/1/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution.
11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | import json 27 | import torch 28 | from torch.utils.dlpack import to_dlpack 29 | 30 | import triton_python_backend_utils as pb_utils 31 | 32 | import os 33 | import numpy as np 34 | import s3tokenizer 35 | 36 | ORIGINAL_VOCAB_SIZE = 151663 37 | 38 | 39 | class TritonPythonModel: 40 | """Triton Python model for audio tokenization. 41 | 42 | This model takes reference audio input and extracts semantic tokens 43 | using s3tokenizer. 44 | """ 45 | 46 | def initialize(self, args): 47 | """Initialize the model. 48 | 49 | Args: 50 | args: Dictionary containing model configuration 51 | """ 52 | # Parse model parameters 53 | parameters = json.loads(args['model_config'])['parameters'] 54 | model_params = {k: v["string_value"] for k, v in parameters.items()} 55 | 56 | self.device = torch.device("cuda") 57 | model_path = os.path.join(model_params["model_dir"], "speech_tokenizer_v2.onnx") 58 | self.audio_tokenizer = s3tokenizer.load_model(model_path).to(self.device) 59 | 60 | def execute(self, requests): 61 | """Execute inference on the batched requests. 
62 | 63 | Args: 64 | requests: List of inference requests 65 | 66 | Returns: 67 | List of inference responses containing tokenized outputs 68 | """ 69 | mels = [] 70 | 71 | # Process each request in batch 72 | for request in requests: 73 | # Extract input tensors 74 | wav_array = pb_utils.get_input_tensor_by_name( 75 | request, "reference_wav").as_numpy() 76 | wav_len = pb_utils.get_input_tensor_by_name( 77 | request, "reference_wav_len").as_numpy().item() 78 | 79 | wav_array = torch.from_numpy(wav_array).to(self.device) 80 | # Prepare inputs 81 | wav = wav_array[:, :wav_len].squeeze(0) 82 | mels.append(s3tokenizer.log_mel_spectrogram(wav)) 83 | 84 | mels, mels_lens = s3tokenizer.padding(mels) 85 | codes, codes_lens = self.audio_tokenizer.quantize(mels.to(self.device), mels_lens.to(self.device)) 86 | codes = codes.clone() + ORIGINAL_VOCAB_SIZE 87 | 88 | responses = [] 89 | for i in range(len(requests)): 90 | prompt_speech_tokens = codes[i, :codes_lens[i].item()] 91 | prompt_speech_tokens_tensor = pb_utils.Tensor.from_dlpack( 92 | "prompt_speech_tokens", to_dlpack(prompt_speech_tokens)) 93 | inference_response = pb_utils.InferenceResponse( 94 | output_tensors=[prompt_speech_tokens_tensor]) 95 | responses.append(inference_response) 96 | 97 | return responses 98 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
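A usage sketch for the JIT export script that follows; this is the same invocation the example `run.sh` recipes use, assuming the pretrained model has already been downloaded to the script's default directory:

```sh
# Writes TorchScript archives (llm.*.zip / flow.encoder.*.zip) into model_dir
python cosyvoice/bin/export_jit.py --model_dir pretrained_models/CosyVoice-300M
```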
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import sys 22 | import torch 23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.append('{}/../..'.format(ROOT_DIR)) 25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 26 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 27 | from cosyvoice.utils.file_utils import logging 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='export your model for deployment') 32 | parser.add_argument('--model_dir', 33 | type=str, 34 | default='pretrained_models/CosyVoice-300M', 35 | help='local path') 36 | args = parser.parse_args() 37 | print(args) 38 | return args 39 | 40 | 41 | def get_optimized_script(model, preserved_attrs=[]): 42 | script = torch.jit.script(model) 43 | if preserved_attrs != []: 44 | script = torch.jit.freeze(script, preserved_attrs=preserved_attrs) 45 | else: 46 | script = torch.jit.freeze(script) 47 | script = torch.jit.optimize_for_inference(script) 48 | return script 49 | 50 | 51 | def main(): 52 | args = get_args() 53 | logging.basicConfig(level=logging.DEBUG, 54 | format='%(asctime)s %(levelname)s %(message)s') 55 | 56 | torch._C._jit_set_fusion_strategy([('STATIC', 1)]) 57 | torch._C._jit_set_profiling_mode(False) 58 | torch._C._jit_set_profiling_executor(False) 59 | 60 | try: 61 | model = CosyVoice(args.model_dir) 62 | except Exception: 63 | try: 64 | model = CosyVoice2(args.model_dir) 65 | except Exception: 66 | raise TypeError('no valid model_type!') 67 | 68 | if not isinstance(model, CosyVoice2): 69 | # 1. export llm text_encoder 70 | llm_text_encoder = model.model.llm.text_encoder 71 | script = get_optimized_script(llm_text_encoder) 72 | script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir)) 73 | script = get_optimized_script(llm_text_encoder.half()) 74 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) 75 | logging.info('successfully export llm_text_encoder') 76 | 77 | # 2. export llm llm 78 | llm_llm = model.model.llm.llm 79 | script = get_optimized_script(llm_llm, ['forward_chunk']) 80 | script.save('{}/llm.llm.fp32.zip'.format(args.model_dir)) 81 | script = get_optimized_script(llm_llm.half(), ['forward_chunk']) 82 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) 83 | logging.info('successfully export llm_llm') 84 | 85 | # 3. export flow encoder 86 | flow_encoder = model.model.flow.encoder 87 | script = get_optimized_script(flow_encoder) 88 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 89 | script = get_optimized_script(flow_encoder.half()) 90 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) 91 | logging.info('successfully export flow_encoder') 92 | else: 93 | # 3. 
export flow encoder 94 | flow_encoder = model.model.flow.encoder 95 | script = get_optimized_script(flow_encoder) 96 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 97 | script = get_optimized_script(flow_encoder.half()) 98 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) 99 | logging.info('successfully export flow_encoder') 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /runtime/python/fastapi/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import argparse 17 | import logging 18 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 19 | from fastapi import FastAPI, UploadFile, Form, File 20 | from fastapi.responses import StreamingResponse 21 | from fastapi.middleware.cors import CORSMiddleware 22 | import uvicorn 23 | import numpy as np 24 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 25 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 26 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 27 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 28 | from cosyvoice.utils.file_utils import load_wav 29 | 30 | app = FastAPI() 31 | # set cross region allowance 32 | app.add_middleware( 33 | CORSMiddleware, 34 | allow_origins=["*"], 35 | allow_credentials=True, 36 | allow_methods=["*"], 37 | allow_headers=["*"]) 38 | 39 | 40 | def generate_data(model_output): 41 | for i in model_output: 42 | tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() 43 | yield tts_audio 44 | 45 | 46 | @app.get("/inference_sft") 47 | @app.post("/inference_sft") 48 | async def inference_sft(tts_text: str = Form(), spk_id: str = Form()): 49 | model_output = cosyvoice.inference_sft(tts_text, spk_id) 50 | return StreamingResponse(generate_data(model_output)) 51 | 52 | 53 | @app.get("/inference_zero_shot") 54 | @app.post("/inference_zero_shot") 55 | async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()): 56 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 57 | model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k) 58 | return StreamingResponse(generate_data(model_output)) 59 | 60 | 61 | @app.get("/inference_cross_lingual") 62 | @app.post("/inference_cross_lingual") 63 | async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()): 64 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 65 | model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) 66 | return StreamingResponse(generate_data(model_output)) 67 | 68 | 69 | @app.get("/inference_instruct") 70 | @app.post("/inference_instruct") 71 | async def inference_instruct(tts_text: str = Form(), spk_id: str = 
Form(), instruct_text: str = Form()): 72 | model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) 73 | return StreamingResponse(generate_data(model_output)) 74 | 75 | 76 | @app.get("/inference_instruct2") 77 | @app.post("/inference_instruct2") 78 | async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()): 79 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 80 | model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k) 81 | return StreamingResponse(generate_data(model_output)) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--port', 87 | type=int, 88 | default=50000) 89 | parser.add_argument('--model_dir', 90 | type=str, 91 | default='iic/CosyVoice-300M', 92 | help='local path or modelscope repo id') 93 | args = parser.parse_args() 94 | try: 95 | cosyvoice = CosyVoice(args.model_dir) 96 | except Exception: 97 | try: 98 | cosyvoice = CosyVoice2(args.model_dir) 99 | except Exception: 100 | raise TypeError('no valid model_type!') 101 | uvicorn.run(app, host="0.0.0.0", port=args.port) 102 | -------------------------------------------------------------------------------- /examples/magicdata-read/cosyvoice/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | . ./path.sh || exit 1; 4 | 5 | stage=-1 6 | stop_stage=3 7 | 8 | data_url=www.openslr.org/resources/68 9 | data_dir=/mnt/hengwu.zty/data/tts/openslr/magicdata-read 10 | pretrained_model_dir=../../../pretrained_models/CosyVoice-300M 11 | 12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then 13 | echo "Data Download" 14 | for part in dev_set test_set train_set; do 15 | local/download_and_untar.sh ${data_dir} ${data_url} ${part} 16 | done 17 | fi 18 | 19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then 20 | echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt" 21 | for x in dev test train; do 22 | mkdir -p data/$x 23 | python local/prepare_data.py --src_dir $data_dir/$x --des_dir data/$x 24 | done 25 | fi 26 | 27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then 28 | echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir" 29 | for x in dev test train; do 30 | tools/extract_embedding.py --dir data/$x \ 31 | --onnx_path $pretrained_model_dir/campplus.onnx 32 | done 33 | fi 34 | 35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then 36 | echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir" 37 | for x in dev test train; do 38 | tools/extract_speech_token.py --dir data/$x \ 39 | --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx 40 | done 41 | fi 42 | 43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then 44 | echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt" 45 | for x in dev test train; do 46 | mkdir -p data/$x/parquet 47 | tools/make_parquet_list.py --num_utts_per_parquet 1000 \ 48 | --num_processes 10 \ 49 | --src_dir data/$x \ 50 | --des_dir data/$x/parquet 51 | done 52 | fi 53 | 54 | # train llm 55 | export CUDA_VISIBLE_DEVICES="0,1,2,3" 56 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') 57 | job_id=1986 58 | dist_backend="nccl" 59 | num_workers=2 60 | prefetch=100 61 | train_engine=torch_ddp 62 | if [ 
${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then 63 | echo "Run train. We only support llm training for now. If you want to train from scratch, please use conf/cosyvoice.fromscratch.yaml" 64 | if [ $train_engine == 'deepspeed' ]; then 65 | echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary" 66 | fi 67 | cp data/train/parquet/data.list data/train.data.list 68 | cp data/dev/parquet/data.list data/dev.data.list 69 | for model in llm flow hifigan; do 70 | torchrun --nnodes=1 --nproc_per_node=$num_gpus \ 71 | --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \ 72 | cosyvoice/bin/train.py \ 73 | --train_engine $train_engine \ 74 | --config conf/cosyvoice.yaml \ 75 | --train_data data/train.data.list \ 76 | --cv_data data/dev.data.list \ 77 | --model $model \ 78 | --checkpoint $pretrained_model_dir/$model.pt \ 79 | --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \ 80 | --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \ 81 | --ddp.dist_backend $dist_backend \ 82 | --num_workers ${num_workers} \ 83 | --prefetch ${prefetch} \ 84 | --pin_memory \ 85 | --use_amp \ 86 | --deepspeed_config ./conf/ds_stage2.json \ 87 | --deepspeed.save_states model+optimizer 88 | done 89 | fi 90 | 91 | # average model 92 | average_num=5 93 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then 94 | for model in llm flow hifigan; do 95 | decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt 96 | echo "do model average and final checkpoint is $decode_checkpoint" 97 | python cosyvoice/bin/average_model.py \ 98 | --dst_model $decode_checkpoint \ 99 | --src_path `pwd`/exp/cosyvoice/$model/$train_engine \ 100 | --num ${average_num} \ 101 | --val_best 102 | done 103 | fi 104 | 105 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then 106 | echo "Export your model for inference speedup. Remember to copy your llm or flow model to model_dir" 107 | python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir 108 | python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir 109 | fi -------------------------------------------------------------------------------- /cosyvoice/vllm/cosyvoice2.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # Adapted from 4 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py 5 | # Copyright 2024 The Qwen team. 6 | # Copyright 2023 The vLLM team. 7 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 8 | # 9 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 10 | # and OPT implementations in this library. It has been modified from its 11 | # original forms to accommodate minor architectural differences compared 12 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 13 | # 14 | # Licensed under the Apache License, Version 2.0 (the "License"); 15 | # you may not use this file except in compliance with the License. 16 | # You may obtain a copy of the License at 17 | # 18 | # http://www.apache.org/licenses/LICENSE-2.0 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License.
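The module below adapts vLLM's Qwen2 implementation for the CosyVoice2 LLM. As a rough sketch of how such an out-of-tree architecture is typically made visible to vLLM, one can register it under the architecture name from the checkpoint's `config.json`; the registration call and the `LLM(...)` arguments here are assumptions for illustration, not code from this repository:

```python
# Hypothetical registration sketch (not part of cosyvoice2.py)
from vllm import LLM, ModelRegistry

from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM

# Make vLLM resolve the "CosyVoice2ForCausalLM" architecture to our class.
ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)

# The converted CosyVoice2 LLM can then be loaded like any other vLLM model;
# speech tokens are produced upstream, so the HF tokenizer is skipped here.
llm = LLM(model="pretrained_models/CosyVoice2-0.5B", skip_tokenizer_init=True)
```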
25 | """Inference-only Qwen2 model compatible with HuggingFace weights.""" 26 | from vllm.model_executor.models.qwen2 import * 27 | 28 | 29 | class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): 30 | packed_modules_mapping = { 31 | "qkv_proj": [ 32 | "q_proj", 33 | "k_proj", 34 | "v_proj", 35 | ], 36 | "gate_up_proj": [ 37 | "gate_proj", 38 | "up_proj", 39 | ], 40 | } 41 | 42 | def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 43 | super().__init__() 44 | config = vllm_config.model_config.hf_config 45 | quant_config = vllm_config.quant_config 46 | lora_config = vllm_config.lora_config 47 | 48 | self.config = config 49 | self.lora_config = lora_config 50 | 51 | self.quant_config = quant_config 52 | self.model = Qwen2Model(vllm_config=vllm_config, 53 | prefix=maybe_prefix(prefix, "model")) 54 | 55 | if get_pp_group().is_last_rank: 56 | if config.tie_word_embeddings: 57 | self.lm_head = self.model.embed_tokens 58 | else: 59 | self.lm_head = ParallelLMHead(config.vocab_size, 60 | config.hidden_size, 61 | True, 62 | quant_config=quant_config, 63 | prefix=maybe_prefix( 64 | prefix, "lm_head")) 65 | else: 66 | self.lm_head = PPMissingLayer() 67 | 68 | self.logits_processor = LogitsProcessor(config.vocab_size) 69 | 70 | self.make_empty_intermediate_tensors = ( 71 | self.model.make_empty_intermediate_tensors) 72 | 73 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: 74 | return self.model.get_input_embeddings(input_ids) 75 | 76 | def forward( 77 | self, 78 | input_ids: torch.Tensor, 79 | positions: torch.Tensor, 80 | intermediate_tensors: Optional[IntermediateTensors] = None, 81 | inputs_embeds: Optional[torch.Tensor] = None, 82 | ) -> Union[torch.Tensor, IntermediateTensors]: 83 | hidden_states = self.model(input_ids, positions, intermediate_tensors, 84 | inputs_embeds) 85 | return hidden_states 86 | 87 | def compute_logits( 88 | self, 89 | hidden_states: torch.Tensor, 90 | sampling_metadata: SamplingMetadata, 91 | ) -> Optional[torch.Tensor]: 92 | logits = self.logits_processor(self.lm_head, hidden_states, 93 | sampling_metadata, self.lm_head.bias) 94 | return logits 95 | 96 | def load_weights(self, weights: Iterable[tuple[str, 97 | torch.Tensor]]) -> set[str]: 98 | loader = AutoWeightsLoader( 99 | self, 100 | skip_prefixes=(["lm_head."] 101 | if self.config.tie_word_embeddings else None), 102 | ) 103 | return loader.load_weights(weights) 104 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/README.md: -------------------------------------------------------------------------------- 1 | ## Best Practices for Serving CosyVoice with NVIDIA Triton Inference Server 2 | 3 | Thanks to the contribution from NVIDIA Yuekai Zhang. 4 | 5 | ### Quick Start 6 | Launch the service directly with Docker Compose: 7 | ```sh 8 | docker compose up 9 | ``` 10 | 11 | ### Build the Docker Image 12 | Build the image from scratch: 13 | ```sh 14 | docker build . -f Dockerfile.server -t soar97/triton-cosyvoice:25.06 15 | ``` 16 | 17 | ### Run a Docker Container 18 | ```sh 19 | your_mount_dir=/mnt:/mnt 20 | docker run -it --name "cosyvoice-server" --gpus all --net host -v $your_mount_dir --shm-size=2g soar97/triton-cosyvoice:25.06 21 | ``` 22 | 23 | ### Understanding `run.sh` 24 | The `run.sh` script orchestrates the entire workflow through numbered stages. 25 | 26 | Run a subset of stages with: 27 | ```sh 28 | bash run.sh [service_type] 29 | ``` 30 | - `` – stage to start from (0-5). 
31 | - `<stop_stage>` – stage to stop after (0-5). 32 | 33 | Stages: 34 | - **Stage 0** – Download the cosyvoice-2 0.5B model from HuggingFace. 35 | - **Stage 1** – Convert the HuggingFace checkpoint to TensorRT-LLM format and build TensorRT engines. 36 | - **Stage 2** – Create the Triton model repository and configure the model files (adjusts depending on whether `Decoupled=True/False` will be used later). 37 | - **Stage 3** – Launch the Triton Inference Server. 38 | - **Stage 4** – Run the single-utterance HTTP client. 39 | - **Stage 5** – Run the gRPC benchmark client. 40 | 41 | ### Export Models to TensorRT-LLM and Launch the Server 42 | Inside the Docker container, prepare the models and start the Triton server by running stages 0-3: 43 | ```sh 44 | # Runs stages 0, 1, 2, and 3 45 | bash run.sh 0 3 46 | ``` 47 | *Note: Stage 2 prepares the model repository differently depending on whether you intend to run with `Decoupled=False` or `Decoupled=True`. Rerun stage 2 if you switch the service type.* 48 | 49 | ### Single-Utterance HTTP Client 50 | Send a single HTTP inference request: 51 | ```sh 52 | bash run.sh 4 4 53 | ``` 54 | 55 | ### Benchmark with a Dataset 56 | Benchmark the running Triton server. Pass either `streaming` or `offline` as the third argument. 57 | ```sh 58 | bash run.sh 5 5 59 | 60 | # You can also customise parameters such as num_task and dataset split directly: 61 | # python3 client_grpc.py --num-tasks 2 --huggingface-dataset yuekai/seed_tts_cosy2 --split-name test_zh --mode [streaming|offline] 62 | ``` 63 | > [!TIP] 64 | > Only offline CosyVoice TTS is currently supported. Setting the client to `streaming` simply enables NVIDIA Triton’s decoupled mode so that responses are returned as soon as they are ready. 65 | 66 | ### Benchmark Results 67 | Decoding on a single L20 GPU with 26 prompt_audio/target_text [pairs](https://huggingface.co/datasets/yuekai/seed_tts) (≈221 s of audio): 68 | 69 | | Mode | Note | Concurrency | Avg Latency (ms) | P50 Latency (ms) | RTF | 70 | |------|------|-------------|------------------|------------------|-----| 71 | | Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 758.04 | 615.79 | 0.0891 | 72 | | Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1025.93 | 901.68 | 0.0657 | 73 | | Decoupled=False | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1914.13 | 1783.58 | 0.0610 | 74 | | Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 1 | 659.87 | 655.63 | 0.0891 | 75 | | Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 2 | 1103.16 | 992.96 | 0.0693 | 76 | | Decoupled=True | [Commit](https://github.com/yuekaizhang/CosyVoice/commit/b44f12110224cb11c03aee4084b1597e7b9331cb) | 4 | 1790.91 | 1668.63 | 0.0604 | 77 | 78 | ### OpenAI-Compatible Server 79 | To launch an OpenAI-compatible service, run: 80 | ```sh 81 | git clone https://github.com/yuekaizhang/Triton-OpenAI-Speech.git 82 | pip install -r requirements.txt 83 | # After the Triton service is up, start the FastAPI bridge: 84 | python3 tts_server.py --url http://localhost:8000 --ref_audios_dir ./ref_audios/ --port 10086 --default_sample_rate 24000 85 | # Test with curl 86 | bash test/test_cosyvoice.sh 87 | ``` 88 | 89 | ### Acknowledgements 90 | This section
originates from the NVIDIA CISI project. We also provide other multimodal resources—see [mair-hub](https://github.com/nvidia-china-sae/mair-hub) for details. 91 | 92 | -------------------------------------------------------------------------------- /cosyvoice/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 22 | 23 | FeedForward is applied at each position of the sequence. 24 | The output dim is the same as the input dim. 25 | 26 | Args: 27 | idim (int): Input dimension. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of experts with positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is the same as the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of experts. 68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimension. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate.
72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Forward function. 93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # logits: (B*L, n_expert_per_token), indices: (B*L, n_expert_per_token) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2024 Alibaba Inc. All Rights Reserved. 3 | .
  3 | . ./path.sh || exit 1;
  4 |
  5 | stage=-1
  6 | stop_stage=3
  7 |
  8 | data_url=www.openslr.org/resources/60
  9 | data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
 10 | pretrained_model_dir=../../../pretrained_models/CosyVoice-300M
 11 |
 12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
 13 |   echo "Data Download"
 14 |   for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
 15 |     local/download_and_untar.sh ${data_dir} ${data_url} ${part}
 16 |   done
 17 | fi
 18 |
 19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
 20 |   echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
 21 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 22 |     mkdir -p data/$x
 23 |     python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
 24 |   done
 25 | fi
 26 |
 27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
 28 |   echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
 29 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 30 |     tools/extract_embedding.py --dir data/$x \
 31 |       --onnx_path $pretrained_model_dir/campplus.onnx
 32 |   done
 33 | fi
 34 |
 35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
 36 |   echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
 37 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 38 |     tools/extract_speech_token.py --dir data/$x \
 39 |       --onnx_path $pretrained_model_dir/speech_tokenizer_v1.onnx
 40 |   done
 41 | fi
 42 |
 43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
 44 |   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
 45 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 46 |     mkdir -p data/$x/parquet
 47 |     tools/make_parquet_list.py --num_utts_per_parquet 1000 \
 48 |       --num_processes 10 \
 49 |       --src_dir data/$x \
 50 |       --des_dir data/$x/parquet
 51 |   done
 52 | fi
 53 |
 54 | # train llm
 55 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
 56 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 57 | job_id=1986
 58 | dist_backend="nccl"
 59 | num_workers=2
 60 | prefetch=100
 61 | train_engine=torch_ddp
 62 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 63 |   echo "Run train. We only support llm training for now"
 64 |   if [ $train_engine == 'deepspeed' ]; then
 65 |     echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
 66 |   fi
 67 |   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
 68 |   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
 69 |   for model in llm flow hifigan; do
 70 |     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
 71 |       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
 72 |       cosyvoice/bin/train.py \
 73 |       --train_engine $train_engine \
 74 |       --config conf/cosyvoice.yaml \
 75 |       --train_data data/train.data.list \
 76 |       --cv_data data/dev.data.list \
 77 |       --model $model \
 78 |       --checkpoint $pretrained_model_dir/$model.pt \
 79 |       --model_dir `pwd`/exp/cosyvoice/$model/$train_engine \
 80 |       --tensorboard_dir `pwd`/tensorboard/cosyvoice/$model/$train_engine \
 81 |       --ddp.dist_backend $dist_backend \
 82 |       --num_workers ${num_workers} \
 83 |       --prefetch ${prefetch} \
 84 |       --pin_memory \
 85 |       --use_amp \
 86 |       --deepspeed_config ./conf/ds_stage2.json \
 87 |       --deepspeed.save_states model+optimizer
 88 |   done
 89 | fi
 90 |
 91 | # average model
 92 | average_num=5
 93 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
 94 |   for model in llm flow hifigan; do
 95 |     decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
 96 |     echo "do model average and final checkpoint is $decode_checkpoint"
 97 |     python cosyvoice/bin/average_model.py \
 98 |       --dst_model $decode_checkpoint \
 99 |       --src_path `pwd`/exp/cosyvoice/$model/$train_engine \
100 |       --num ${average_num} \
101 |       --val_best
102 |   done
103 | fi
104 |
105 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
106 |   echo "Export your model for inference speedup. Remember to copy your llm or flow model to model_dir"
107 |   python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
108 |   python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
109 | fi
--------------------------------------------------------------------------------
/runtime/python/grpc/server.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #   http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
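# NOTE: the service below exposes a single server-streaming Inference RPC.
# Prompt audio arrives as raw int16 PCM bytes sampled at 16 kHz and is
# converted to a float32 tensor in [-1, 1]; synthesized speech is converted
# back to int16 bytes and yielded one Response message at a time.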
14 | import os 15 | import sys 16 | from concurrent import futures 17 | import argparse 18 | import cosyvoice_pb2 19 | import cosyvoice_pb2_grpc 20 | import logging 21 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 22 | import grpc 23 | import torch 24 | import numpy as np 25 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 26 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 27 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 28 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 29 | 30 | logging.basicConfig(level=logging.DEBUG, 31 | format='%(asctime)s %(levelname)s %(message)s') 32 | 33 | 34 | class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer): 35 | def __init__(self, args): 36 | try: 37 | self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc) 38 | except Exception: 39 | try: 40 | self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc) 41 | except Exception: 42 | raise TypeError('no valid model_type!') 43 | logging.info('grpc service initialized') 44 | 45 | def Inference(self, request, context): 46 | if request.HasField('sft_request'): 47 | logging.info('get sft inference request') 48 | model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id) 49 | elif request.HasField('zero_shot_request'): 50 | logging.info('get zero_shot inference request') 51 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) 52 | prompt_speech_16k = prompt_speech_16k.float() / (2**15) 53 | model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, 54 | request.zero_shot_request.prompt_text, 55 | prompt_speech_16k) 56 | elif request.HasField('cross_lingual_request'): 57 | logging.info('get cross_lingual inference request') 58 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) 59 | prompt_speech_16k = prompt_speech_16k.float() / (2**15) 60 | model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k) 61 | else: 62 | logging.info('get instruct inference request') 63 | model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, 64 | request.instruct_request.spk_id, 65 | request.instruct_request.instruct_text) 66 | 67 | logging.info('send inference response') 68 | for i in model_output: 69 | response = cosyvoice_pb2.Response() 70 | response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() 71 | yield response 72 | 73 | 74 | def main(): 75 | grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc) 76 | cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer) 77 | grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port)) 78 | grpcServer.start() 79 | logging.info("server listening on 0.0.0.0:{}".format(args.port)) 80 | grpcServer.wait_for_termination() 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--port', 86 | type=int, 87 | default=50000) 88 | parser.add_argument('--max_conc', 89 | type=int, 90 | default=4) 91 | parser.add_argument('--model_dir', 92 | type=str, 93 | default='iic/CosyVoice-300M', 94 | help='local path or modelscope repo id') 95 | args = parser.parse_args() 96 | main() 97 | 
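The server above and the client later in this dump share one wire convention: waveforms travel as raw int16 PCM bytes, scaled by 2**15 relative to float32 tensors in [-1, 1]. A minimal round-trip sketch of that convention (the helper names are illustrative, not part of the codebase):

    import numpy as np
    import torch

    def pcm16_encode(speech: torch.Tensor) -> bytes:
        # float32 waveform, shape (1, T), values in [-1, 1] -> int16 bytes
        return (speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()

    def pcm16_decode(data: bytes) -> torch.Tensor:
        # int16 bytes -> float32 waveform, shape (1, T); copy() makes the
        # read-only frombuffer result writable before handing it to torch
        arr = np.frombuffer(data, dtype=np.int16).copy()
        return torch.from_numpy(arr).unsqueeze(0).float() / (2 ** 15)

This mirrors the np.frombuffer / tobytes conversions used in server.py, so concatenating the streamed tts_audio chunks and decoding them once at the end, as the gRPC client does, is equivalent to decoding chunk by chunk.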
--------------------------------------------------------------------------------
/cosyvoice/utils/frontend_utils.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #   http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 |
 15 | import re
 16 | import regex
 17 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
 18 |
 19 |
 20 | # whether the text contains Chinese characters
 21 | def contains_chinese(text):
 22 |     return bool(chinese_char_pattern.search(text))
 23 |
 24 |
 25 | # replace special symbols
 26 | def replace_corner_mark(text):
 27 |     text = text.replace('²', '平方')
 28 |     text = text.replace('³', '立方')
 29 |     return text
 30 |
 31 |
 32 | # remove meaningless symbols
 33 | def remove_bracket(text):
 34 |     text = text.replace('(', '').replace(')', '')
 35 |     text = text.replace('【', '').replace('】', '')
 36 |     text = text.replace('`', '').replace('`', '')
 37 |     text = text.replace("——", " ")
 38 |     return text
 39 |
 40 |
 41 | # spell out Arabic numerals
 42 | def spell_out_number(text: str, inflect_parser):
 43 |     new_text = []
 44 |     st = None
 45 |     for i, c in enumerate(text):
 46 |         if not c.isdigit():
 47 |             if st is not None:
 48 |                 num_str = inflect_parser.number_to_words(text[st: i])
 49 |                 new_text.append(num_str)
 50 |                 st = None
 51 |             new_text.append(c)
 52 |         else:
 53 |             if st is None:
 54 |                 st = i
 55 |     if st is not None and st < len(text):
 56 |         num_str = inflect_parser.number_to_words(text[st:])
 57 |         new_text.append(num_str)
 58 |     return ''.join(new_text)
 59 |
 60 |
 61 | # split paragraph logic:
 62 | # 1. per-sentence max length token_max_n, min length token_min_n; merge if the last sentence is shorter than merge_len
 63 | # 2. calculate sentence length according to lang
 64 | # 3. split sentences according to punctuation
 65 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
 66 |     def calc_utt_length(_text: str):
 67 |         if lang == "zh":
 68 |             return len(_text)
 69 |         else:
 70 |             return len(tokenize(_text))
 71 |
 72 |     def should_merge(_text: str):
 73 |         if lang == "zh":
 74 |             return len(_text) < merge_len
 75 |         else:
 76 |             return len(tokenize(_text)) < merge_len
 77 |
 78 |     if lang == "zh":
 79 |         pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
 80 |     else:
 81 |         pounc = ['.', '?', '!', ';', ':']
 82 |     if comma_split:
 83 |         pounc.extend([',', ','])
 84 |
 85 |     if text[-1] not in pounc:
 86 |         if lang == "zh":
 87 |             text += "。"
 88 |         else:
 89 |             text += "."
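    # the loop below scans for sentence-final punctuation, emits the span up
    # to and including the mark, and attaches a directly following closing
    # quote ('"' or '”') to the sentence it terminates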
 90 |
 91 |     st = 0
 92 |     utts = []
 93 |     for i, c in enumerate(text):
 94 |         if c in pounc:
 95 |             if len(text[st: i]) > 0:
 96 |                 utts.append(text[st: i] + c)
 97 |             if i + 1 < len(text) and text[i + 1] in ['"', '”']:
 98 |                 tmp = utts.pop(-1)
 99 |                 utts.append(tmp + text[i + 1])
100 |                 st = i + 2
101 |             else:
102 |                 st = i + 1
103 |
104 |     final_utts = []
105 |     cur_utt = ""
106 |     for utt in utts:
107 |         if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
108 |             final_utts.append(cur_utt)
109 |             cur_utt = ""
110 |         cur_utt = cur_utt + utt
111 |     if len(cur_utt) > 0:
112 |         if should_merge(cur_utt) and len(final_utts) != 0:
113 |             final_utts[-1] = final_utts[-1] + cur_utt
114 |         else:
115 |             final_utts.append(cur_utt)
116 |
117 |     return final_utts
118 |
119 |
120 | # remove blanks between Chinese characters
121 | def replace_blank(text: str):
122 |     out_str = []
123 |     for i, c in enumerate(text):
124 |         if c == " ":
125 |             if ((text[i + 1].isascii() and text[i + 1] != " ") and
126 |                     (text[i - 1].isascii() and text[i - 1] != " ")):
127 |                 out_str.append(c)
128 |         else:
129 |             out_str.append(c)
130 |     return "".join(out_str)
131 |
132 |
133 | def is_only_punctuation(text):
134 |     # Regular expression: Match strings that consist only of punctuation marks or are empty.
135 |     punctuation_pattern = r'^[\p{P}\p{S}]*$'
136 |     return bool(regex.fullmatch(punctuation_pattern, text))
137 |
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/run.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | # Copyright 2024 Alibaba Inc. All Rights Reserved.
  3 | . ./path.sh || exit 1;
  4 |
  5 | stage=-1
  6 | stop_stage=3
  7 |
  8 | data_url=www.openslr.org/resources/60
  9 | data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
 10 | pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
 11 |
 12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
 13 |   echo "Data Download"
 14 |   for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
 15 |     local/download_and_untar.sh ${data_dir} ${data_url} ${part}
 16 |   done
 17 | fi
 18 |
 19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
 20 |   echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
 21 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 22 |     mkdir -p data/$x
 23 |     python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
 24 |   done
 25 | fi
 26 |
 27 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
 28 |   echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
 29 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 30 |     tools/extract_embedding.py --dir data/$x \
 31 |       --onnx_path $pretrained_model_dir/campplus.onnx
 32 |   done
 33 | fi
 34 |
 35 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
 36 |   echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
 37 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 38 |     tools/extract_speech_token.py --dir data/$x \
 39 |       --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
 40 |   done
 41 | fi
 42 |
 43 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
 44 |   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
 45 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 46 |     mkdir -p data/$x/parquet
 47 |     tools/make_parquet_list.py --num_utts_per_parquet 1000 \
 48 |       --num_processes 10 \
 49 |       --src_dir data/$x \
 50 |       --des_dir data/$x/parquet
 51 |   done
 52 | fi
 53 |
 54 | # train llm
 55 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
 56 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 57 | job_id=1986
 58 | dist_backend="nccl"
 59 | num_workers=2
 60 | prefetch=100
 61 | train_engine=torch_ddp
 62 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 63 |   echo "Run train. We only support llm training for now"
 64 |   if [ $train_engine == 'deepspeed' ]; then
 65 |     echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
 66 |   fi
 67 |   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
 68 |   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
 69 |   # NOTE will update llm/hift training later
 70 |   for model in llm flow hifigan; do
 71 |     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
 72 |       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
 73 |       cosyvoice/bin/train.py \
 74 |       --train_engine $train_engine \
 75 |       --config conf/cosyvoice2.yaml \
 76 |       --train_data data/train.data.list \
 77 |       --cv_data data/dev.data.list \
 78 |       --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
 79 |       --model $model \
 80 |       --checkpoint $pretrained_model_dir/$model.pt \
 81 |       --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
 82 |       --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
 83 |       --ddp.dist_backend $dist_backend \
 84 |       --num_workers ${num_workers} \
 85 |       --prefetch ${prefetch} \
 86 |       --pin_memory \
 87 |       --use_amp \
 88 |       --deepspeed_config ./conf/ds_stage2.json \
 89 |       --deepspeed.save_states model+optimizer
 90 |   done
 91 | fi
 92 |
 93 | # average model
 94 | average_num=5
 95 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
 96 |   for model in llm flow hifigan; do
 97 |     decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
 98 |     echo "do model average and final checkpoint is $decode_checkpoint"
 99 |     python cosyvoice/bin/average_model.py \
100 |       --dst_model $decode_checkpoint \
101 |       --src_path `pwd`/exp/cosyvoice/$model/$train_engine \
102 |       --num ${average_num} \
103 |       --val_best
104 |   done
105 | fi
106 |
107 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
108 |   echo "Export your model for inference speedup. Remember to copy your llm or flow model to model_dir"
109 |   python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
110 |   python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
111 | fi
--------------------------------------------------------------------------------
/runtime/python/grpc/client.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #   http://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12 | # See the License for the specific language governing permissions and
 13 | # limitations under the License.
 14 | import os
 15 | import sys
 16 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 17 | sys.path.append('{}/../../..'.format(ROOT_DIR))
 18 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 19 | import logging
 20 | import argparse
 21 | import torchaudio
 22 | import cosyvoice_pb2
 23 | import cosyvoice_pb2_grpc
 24 | import grpc
 25 | import torch
 26 | import numpy as np
 27 | from cosyvoice.utils.file_utils import load_wav
 28 |
 29 |
 30 | def main():
 31 |     with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel:
 32 |         stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel)
 33 |         request = cosyvoice_pb2.Request()
 34 |         if args.mode == 'sft':
 35 |             logging.info('send sft request')
 36 |             sft_request = cosyvoice_pb2.sftRequest()
 37 |             sft_request.spk_id = args.spk_id
 38 |             sft_request.tts_text = args.tts_text
 39 |             request.sft_request.CopyFrom(sft_request)
 40 |         elif args.mode == 'zero_shot':
 41 |             logging.info('send zero_shot request')
 42 |             zero_shot_request = cosyvoice_pb2.zeroshotRequest()
 43 |             zero_shot_request.tts_text = args.tts_text
 44 |             zero_shot_request.prompt_text = args.prompt_text
 45 |             prompt_speech = load_wav(args.prompt_wav, 16000)
 46 |             zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
 47 |             request.zero_shot_request.CopyFrom(zero_shot_request)
 48 |         elif args.mode == 'cross_lingual':
 49 |             logging.info('send cross_lingual request')
 50 |             cross_lingual_request = cosyvoice_pb2.crosslingualRequest()
 51 |             cross_lingual_request.tts_text = args.tts_text
 52 |             prompt_speech = load_wav(args.prompt_wav, 16000)
 53 |             cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes()
 54 |             request.cross_lingual_request.CopyFrom(cross_lingual_request)
 55 |         else:
 56 |             logging.info('send instruct request')
 57 |             instruct_request = cosyvoice_pb2.instructRequest()
 58 |             instruct_request.tts_text = args.tts_text
 59 |             instruct_request.spk_id = args.spk_id
 60 |             instruct_request.instruct_text = args.instruct_text
 61 |             request.instruct_request.CopyFrom(instruct_request)
 62 |
 63 |         response = stub.Inference(request)
 64 |         tts_audio = b''
 65 |         for r in response:
 66 |             tts_audio += r.tts_audio
 67 |         tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
 68 |         logging.info('save response to {}'.format(args.tts_wav))
 69 |         torchaudio.save(args.tts_wav, tts_speech, target_sr)
 70 |         logging.info('get response')
 71 |
 72 |
 73 | if __name__ == "__main__":
 74 |     parser = argparse.ArgumentParser()
 75 |     parser.add_argument('--host',
 76 |                         type=str,
 77 |                         default='0.0.0.0')
 78 |     parser.add_argument('--port',
 79 |                         type=int,
 80 |                         default=50000)
 81 |     parser.add_argument('--mode',
 82 |                         default='sft',
 83 |                         choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'],
 84 |                         help='request mode')
 85 |     parser.add_argument('--tts_text',
 86 |                         type=str,
 87 |                         default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?')
 88 |     parser.add_argument('--spk_id',
 89 |                         type=str,
 90 |                         default='中文女')
 91 |     parser.add_argument('--prompt_text',
 92 |                         type=str,
 93 |                         default='希望你以后能够做的比我还好呦。')
 94 | 
parser.add_argument('--prompt_wav', 95 | type=str, 96 | default='../../../asset/zero_shot_prompt.wav') 97 | parser.add_argument('--instruct_text', 98 | type=str, 99 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \ 100 | Fights with fervor for justice, but struggles with impulsiveness.') 101 | parser.add_argument('--tts_wav', 102 | type=str, 103 | default='demo.wav') 104 | args = parser.parse_args() 105 | prompt_sr, target_sr = 16000, 22050 106 | main() 107 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import print_function 17 | 18 | import argparse 19 | import logging 20 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 21 | import os 22 | import sys 23 | import onnxruntime 24 | import random 25 | import torch 26 | from tqdm import tqdm 27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | sys.path.append('{}/../..'.format(ROOT_DIR)) 29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 30 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 31 | from cosyvoice.utils.file_utils import logging 32 | 33 | 34 | def get_dummy_input(batch_size, seq_len, out_channels, device): 35 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 36 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) 37 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 38 | t = torch.rand((batch_size), dtype=torch.float32, device=device) 39 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) 40 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 41 | return x, mask, mu, t, spks, cond 42 | 43 | 44 | def get_args(): 45 | parser = argparse.ArgumentParser(description='export your model for deployment') 46 | parser.add_argument('--model_dir', 47 | type=str, 48 | default='pretrained_models/CosyVoice-300M', 49 | help='local path') 50 | args = parser.parse_args() 51 | print(args) 52 | return args 53 | 54 | 55 | @torch.no_grad() 56 | def main(): 57 | args = get_args() 58 | logging.basicConfig(level=logging.DEBUG, 59 | format='%(asctime)s %(levelname)s %(message)s') 60 | 61 | try: 62 | model = CosyVoice(args.model_dir) 63 | except Exception: 64 | try: 65 | model = CosyVoice2(args.model_dir) 66 | except Exception: 67 | raise TypeError('no valid model_type!') 68 | 69 | # 1. 
export flow decoder estimator 70 | estimator = model.model.flow.decoder.estimator 71 | estimator.eval() 72 | 73 | device = model.model.device 74 | batch_size, seq_len = 2, 256 75 | out_channels = model.model.flow.decoder.estimator.out_channels 76 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) 77 | torch.onnx.export( 78 | estimator, 79 | (x, mask, mu, t, spks, cond), 80 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 81 | export_params=True, 82 | opset_version=18, 83 | do_constant_folding=True, 84 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], 85 | output_names=['estimator_out'], 86 | dynamic_axes={ 87 | 'x': {2: 'seq_len'}, 88 | 'mask': {2: 'seq_len'}, 89 | 'mu': {2: 'seq_len'}, 90 | 'cond': {2: 'seq_len'}, 91 | 'estimator_out': {2: 'seq_len'}, 92 | } 93 | ) 94 | 95 | # 2. test computation consistency 96 | option = onnxruntime.SessionOptions() 97 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 98 | option.intra_op_num_threads = 1 99 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] 100 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 101 | sess_options=option, providers=providers) 102 | 103 | for _ in tqdm(range(10)): 104 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) 105 | output_pytorch = estimator(x, mask, mu, t, spks, cond) 106 | ort_inputs = { 107 | 'x': x.cpu().numpy(), 108 | 'mask': mask.cpu().numpy(), 109 | 'mu': mu.cpu().numpy(), 110 | 't': t.cpu().numpy(), 111 | 'spks': spks.cpu().numpy(), 112 | 'cond': cond.cpu().numpy() 113 | } 114 | output_onnx = estimator_onnx.run(None, ort_inputs)[0] 115 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) 116 | logging.info('successfully export estimator') 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright (c) 2025 NVIDIA (authors: Yuekai Zhang) 3 | export CUDA_VISIBLE_DEVICES=0 4 | cosyvoice_path=/workspace/CosyVoice 5 | export PYTHONPATH=${cosyvoice_path}:$PYTHONPATH 6 | export PYTHONPATH=${cosyvoice_path}/third_party/Matcha-TTS:$PYTHONPATH 7 | stage=$1 8 | stop_stage=$2 9 | 10 | huggingface_model_local_dir=./cosyvoice2_llm 11 | model_scope_model_local_dir=./CosyVoice2-0.5B 12 | trt_dtype=bfloat16 13 | trt_weights_dir=./trt_weights_${trt_dtype} 14 | trt_engines_dir=./trt_engines_${trt_dtype} 15 | 16 | model_repo=./model_repo_cosyvoice2 17 | 18 | if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then 19 | echo "Cloning CosyVoice" 20 | git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git $cosyvoice_path 21 | cd $cosyvoice_path 22 | git submodule update --init --recursive 23 | cd runtime/triton_trtllm 24 | fi 25 | 26 | if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then 27 | echo "Downloading CosyVoice2-0.5B" 28 | huggingface-cli download --local-dir $huggingface_model_local_dir yuekai/cosyvoice2_llm 29 | modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_local_dir 30 | fi 31 | 32 | 33 | if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then 34 | echo "Converting checkpoint to TensorRT weights" 35 | python3 scripts/convert_checkpoint.py --model_dir 
$huggingface_model_local_dir \ 36 | --output_dir $trt_weights_dir \ 37 | --dtype $trt_dtype || exit 1 38 | 39 | echo "Building TensorRT engines" 40 | trtllm-build --checkpoint_dir $trt_weights_dir \ 41 | --output_dir $trt_engines_dir \ 42 | --max_batch_size 16 \ 43 | --max_num_tokens 32768 \ 44 | --gemm_plugin $trt_dtype || exit 1 45 | 46 | echo "Testing TensorRT engines" 47 | python3 ./scripts/test_llm.py --input_text "你好,请问你叫什么?" \ 48 | --tokenizer_dir $huggingface_model_local_dir \ 49 | --top_k 50 --top_p 0.95 --temperature 0.8 \ 50 | --engine_dir=$trt_engines_dir || exit 1 51 | fi 52 | 53 | if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then 54 | echo "Creating model repository" 55 | rm -rf $model_repo 56 | mkdir -p $model_repo 57 | cosyvoice2_dir="cosyvoice2" 58 | 59 | cp -r ./model_repo/${cosyvoice2_dir} $model_repo 60 | cp -r ./model_repo/audio_tokenizer $model_repo 61 | cp -r ./model_repo/tensorrt_llm $model_repo 62 | cp -r ./model_repo/token2wav $model_repo 63 | 64 | ENGINE_PATH=$trt_engines_dir 65 | MAX_QUEUE_DELAY_MICROSECONDS=0 66 | MODEL_DIR=$model_scope_model_local_dir 67 | LLM_TOKENIZER_DIR=$huggingface_model_local_dir 68 | BLS_INSTANCE_NUM=4 69 | TRITON_MAX_BATCH_SIZE=16 70 | DECOUPLED_MODE=False 71 | 72 | python3 scripts/fill_template.py -i ${model_repo}/token2wav/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} 73 | python3 scripts/fill_template.py -i ${model_repo}/audio_tokenizer/config.pbtxt model_dir:${MODEL_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} 74 | python3 scripts/fill_template.py -i ${model_repo}/${cosyvoice2_dir}/config.pbtxt model_dir:${MODEL_DIR},bls_instance_num:${BLS_INSTANCE_NUM},llm_tokenizer_dir:${LLM_TOKENIZER_DIR},triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS} 75 | python3 scripts/fill_template.py -i ${model_repo}/tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,triton_max_batch_size:${TRITON_MAX_BATCH_SIZE},decoupled_mode:${DECOUPLED_MODE},max_beam_width:1,engine_dir:${ENGINE_PATH},max_tokens_in_paged_kv_cache:2560,max_attention_window_size:2560,kv_cache_free_gpu_mem_fraction:0.5,exclude_input_in_output:True,enable_kv_cache_reuse:False,batching_strategy:inflight_fused_batching,max_queue_delay_microseconds:${MAX_QUEUE_DELAY_MICROSECONDS},encoder_input_features_data_type:TYPE_FP16,logits_datatype:TYPE_FP32 76 | 77 | fi 78 | 79 | if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then 80 | echo "Starting Triton server" 81 | tritonserver --model-repository $model_repo 82 | fi 83 | 84 | if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then 85 | echo "Single request test http" 86 | python3 client_http.py \ 87 | --reference-audio ./assets/prompt_audio.wav \ 88 | --reference-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \ 89 | --target-text "身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。" \ 90 | --model-name cosyvoice2 91 | fi 92 | 93 | if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then 94 | echo "Running benchmark client grpc" 95 | num_task=4 96 | # set mode=streaming, when decoupled=True 97 | # set mode=offline, when decoupled=False 98 | mode=offline 99 | python3 client_grpc.py \ 100 | --server-addr localhost \ 101 | --model-name cosyvoice2 \ 102 | --num-tasks $num_task \ 103 | --mode $mode \ 104 | --huggingface-dataset yuekai/seed_tts_cosy2 \ 105 | --log-dir 
./log_concurrent_tasks_${num_task}_${mode}_bls_4_${trt_dtype}
106 | fi
--------------------------------------------------------------------------------
/cosyvoice/transformer/decoder_layer.py:
--------------------------------------------------------------------------------
 1 | # Copyright (c) 2019 Shigeki Karita
 2 | #               2020 Mobvoi Inc (Binbin Zhang)
 3 | #
 4 | # Licensed under the Apache License, Version 2.0 (the "License");
 5 | # you may not use this file except in compliance with the License.
 6 | # You may obtain a copy of the License at
 7 | #
 8 | #   http://www.apache.org/licenses/LICENSE-2.0
 9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """Decoder self-attention layer definition."""
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 | from torch import nn
20 |
21 |
22 | class DecoderLayer(nn.Module):
23 |     """Single decoder layer module.
24 |
25 |     Args:
26 |         size (int): Input dimension.
27 |         self_attn (torch.nn.Module): Self-attention module instance.
28 |             `MultiHeadedAttention` instance can be used as the argument.
29 |         src_attn (torch.nn.Module): Inter-attention module instance.
30 |             `MultiHeadedAttention` instance can be used as the argument.
31 |             If `None` is passed, inter-attention is not used, as in
32 |             CIF, GPT, and other decoder-only models.
33 |         feed_forward (torch.nn.Module): Feed-forward module instance.
34 |             `PositionwiseFeedForward` instance can be used as the argument.
35 |         dropout_rate (float): Dropout rate.
36 |         normalize_before (bool):
37 |             True: use layer_norm before each sub-block.
38 |             False: use layer_norm after each sub-block.
39 |     """
40 |
41 |     def __init__(
42 |         self,
43 |         size: int,
44 |         self_attn: nn.Module,
45 |         src_attn: Optional[nn.Module],
46 |         feed_forward: nn.Module,
47 |         dropout_rate: float,
48 |         normalize_before: bool = True,
49 |     ):
50 |         """Construct a DecoderLayer object."""
51 |         super().__init__()
52 |         self.size = size
53 |         self.self_attn = self_attn
54 |         self.src_attn = src_attn
55 |         self.feed_forward = feed_forward
56 |         self.norm1 = nn.LayerNorm(size, eps=1e-5)
57 |         self.norm2 = nn.LayerNorm(size, eps=1e-5)
58 |         self.norm3 = nn.LayerNorm(size, eps=1e-5)
59 |         self.dropout = nn.Dropout(dropout_rate)
60 |         self.normalize_before = normalize_before
61 |
62 |     def forward(
63 |         self,
64 |         tgt: torch.Tensor,
65 |         tgt_mask: torch.Tensor,
66 |         memory: torch.Tensor,
67 |         memory_mask: torch.Tensor,
68 |         cache: Optional[torch.Tensor] = None
69 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
70 |         """Compute decoded features.
71 |
72 |         Args:
73 |             tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
74 |             tgt_mask (torch.Tensor): Mask for input tensor
75 |                 (#batch, maxlen_out).
76 |             memory (torch.Tensor): Encoded memory
77 |                 (#batch, maxlen_in, size).
78 |             memory_mask (torch.Tensor): Encoded memory mask
79 |                 (#batch, maxlen_in).
80 |             cache (torch.Tensor): cached tensors.
81 |                 (#batch, maxlen_out - 1, size).
82 |
83 |         Returns:
84 |             torch.Tensor: Output tensor (#batch, maxlen_out, size).
85 |             torch.Tensor: Mask for output tensor (#batch, maxlen_out).
86 |             torch.Tensor: Encoded memory (#batch, maxlen_in, size).
87 |             torch.Tensor: Encoded memory mask (#batch, maxlen_in).
 88 |
 89 |         """
 90 |         residual = tgt
 91 |         if self.normalize_before:
 92 |             tgt = self.norm1(tgt)
 93 |
 94 |         if cache is None:
 95 |             tgt_q = tgt
 96 |             tgt_q_mask = tgt_mask
 97 |         else:
 98 |             # compute only the last frame query keeping dim: max_time_out -> 1
 99 |             assert cache.shape == (
100 |                 tgt.shape[0],
101 |                 tgt.shape[1] - 1,
102 |                 self.size,
103 |             ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
104 |             tgt_q = tgt[:, -1:, :]
105 |             residual = residual[:, -1:, :]
106 |             tgt_q_mask = tgt_mask[:, -1:, :]
107 |
108 |         x = residual + self.dropout(
109 |             self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
110 |         if not self.normalize_before:
111 |             x = self.norm1(x)
112 |
113 |         if self.src_attn is not None:
114 |             residual = x
115 |             if self.normalize_before:
116 |                 x = self.norm2(x)
117 |             x = residual + self.dropout(
118 |                 self.src_attn(x, memory, memory, memory_mask)[0])
119 |             if not self.normalize_before:
120 |                 x = self.norm2(x)
121 |
122 |         residual = x
123 |         if self.normalize_before:
124 |             x = self.norm3(x)
125 |         x = residual + self.dropout(self.feed_forward(x))
126 |         if not self.normalize_before:
127 |             x = self.norm3(x)
128 |
129 |         if cache is not None:
130 |             x = torch.cat([cache, x], dim=1)
131 |
132 |         return x, tgt_mask, memory, memory_mask
133 |
--------------------------------------------------------------------------------
/cosyvoice/dataset/dataset.py:
--------------------------------------------------------------------------------
  1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
  2 | #               2024 Alibaba Inc (authors: Xiang Lyu)
  3 | #
  4 | # Licensed under the Apache License, Version 2.0 (the "License");
  5 | # you may not use this file except in compliance with the License.
  6 | # You may obtain a copy of the License at
  7 | #
  8 | #   http://www.apache.org/licenses/LICENSE-2.0
  9 | #
 10 | # Unless required by applicable law or agreed to in writing, software
 11 | # distributed under the License is distributed on an "AS IS" BASIS,
 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | # See the License for the specific language governing permissions and
 14 | # limitations under the License.
 15 |
 16 | import random
 17 | import math
 18 | from functools import partial
 19 |
 20 | import torch
 21 | import torch.distributed as dist
 22 | from torch.utils.data import IterableDataset
 23 | from cosyvoice.utils.file_utils import read_lists
 24 |
 25 |
 26 | class Processor(IterableDataset):
 27 |
 28 |     def __init__(self, source, f, *args, **kw):
 29 |         assert callable(f)
 30 |         self.source = source
 31 |         self.f = f
 32 |         self.args = args
 33 |         self.kw = kw
 34 |
 35 |     def set_epoch(self, epoch):
 36 |         self.source.set_epoch(epoch)
 37 |
 38 |     def __iter__(self):
 39 |         """ Return an iterator over the source dataset processed by the
 40 |             given processor.
 41 |         """
 42 |         assert self.source is not None
 43 |         assert callable(self.f)
 44 |         return self.f(iter(self.source), *self.args, **self.kw)
 45 |
 46 |     def apply(self, f):
 47 |         assert callable(f)
 48 |         return Processor(self, f, *self.args, **self.kw)
 49 |
 50 |
 51 | class DistributedSampler:
 52 |
 53 |     def __init__(self, shuffle=True, partition=True):
 54 |         self.epoch = -1
 55 |         self.update()
 56 |         self.shuffle = shuffle
 57 |         self.partition = partition
 58 |
 59 |     def update(self):
 60 |         assert dist.is_available()
 61 |         if dist.is_initialized():
 62 |             self.rank = dist.get_rank()
 63 |             self.world_size = dist.get_world_size()
 64 |         else:
 65 |             self.rank = 0
 66 |             self.world_size = 1
 67 |         worker_info = torch.utils.data.get_worker_info()
 68 |         if worker_info is None:
 69 |             self.worker_id = 0
 70 |             self.num_workers = 1
 71 |         else:
 72 |             self.worker_id = worker_info.id
 73 |             self.num_workers = worker_info.num_workers
 74 |         return dict(rank=self.rank,
 75 |                     world_size=self.world_size,
 76 |                     worker_id=self.worker_id,
 77 |                     num_workers=self.num_workers)
 78 |
 79 |     def set_epoch(self, epoch):
 80 |         self.epoch = epoch
 81 |
 82 |     def sample(self, data):
 83 |         """ Sample data according to rank/world_size/num_workers
 84 |
 85 |         Args:
 86 |             data(List): input data list
 87 |
 88 |         Returns:
 89 |             List: data list after sample
 90 |         """
 91 |         data = list(range(len(data)))
 92 |         # pad the data list so it divides evenly across ranks and workers
 93 |         if self.partition:
 94 |             if self.shuffle:
 95 |                 random.Random(self.epoch).shuffle(data)
 96 |             if len(data) < self.world_size:
 97 |                 data = data * math.ceil(self.world_size / len(data))
 98 |                 data = data[:self.world_size]
 99 |             data = data[self.rank::self.world_size]
100 |         if len(data) < self.num_workers:
101 |             data = data * math.ceil(self.num_workers / len(data))
102 |             data = data[:self.num_workers]
103 |         data = data[self.worker_id::self.num_workers]
104 |         return data
105 |
106 |
107 | class DataList(IterableDataset):
108 |
109 |     def __init__(self, lists, shuffle=True, partition=True):
110 |         self.lists = lists
111 |         self.sampler = DistributedSampler(shuffle, partition)
112 |
113 |     def set_epoch(self, epoch):
114 |         self.sampler.set_epoch(epoch)
115 |
116 |     def __iter__(self):
117 |         sampler_info = self.sampler.update()
118 |         indexes = self.sampler.sample(self.lists)
119 |         for index in indexes:
120 |             data = dict(src=self.lists[index])
121 |             data.update(sampler_info)
122 |             yield data
123 |
124 |
125 | def Dataset(data_list_file,
126 |             data_pipeline,
127 |             mode='train',
128 |             gan=False,
129 |             dpo=False,
130 |             shuffle=True,
131 |             partition=True):
132 |     """ Construct dataset from arguments
133 |
134 |     We have two shuffle stages in the Dataset. The first is a global
135 |     shuffle at the shard (tar/raw file) level. The second is a shuffle
136 |     at the training-sample level.
137 |
138 |     Args:
139 |         data_list_file (str): file listing the data shards
140 |         data_pipeline (list): processing functions applied in order
141 |         mode (str): train/inference mode
142 |         gan (bool): whether to use gan-specific padding
143 |         dpo (bool): whether to include dpo reject samples in padding
144 |         shuffle (bool): whether to shuffle the data
145 |         partition (bool): whether to do data partition in terms of rank
146 |     """
147 |     lists = read_lists(data_list_file)
148 |     dataset = DataList(lists,
149 |                        shuffle=shuffle,
150 |                        partition=partition)
151 |     # map partial args (gan/dpo) onto the final padding func
152 |     data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo)
153 |     for func in data_pipeline:
154 |         dataset = Processor(dataset, func, mode=mode)
155 |     return dataset
156 |
--------------------------------------------------------------------------------
/examples/libritts/cosyvoice2/run_dpo.sh:
--------------------------------------------------------------------------------
  1 | #!/bin/bash
  2 | # Copyright 2024 Alibaba Inc. All Rights Reserved.
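# Overview: this recipe mirrors run.sh, with three DPO-specific additions:
# reject samples are synthesized with the pretrained CosyVoice2-0.5B (which
# also serves as the reference model), speech tokens are extracted for both
# accepted and rejected audio, and parquet packing plus llm training run
# with the --dpo flag (training additionally passes --ref_model).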
  3 | . ./path.sh || exit 1;
  4 |
  5 | stage=-1
  6 | stop_stage=3
  7 |
  8 | data_url=www.openslr.org/resources/60
  9 | data_dir=/mnt/lyuxiang.lx/data/tts/openslr/libritts
 10 | pretrained_model_dir=../../../pretrained_models/CosyVoice2-0.5B
 11 |
 12 | if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
 13 |   echo "Data Download"
 14 |   for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
 15 |     local/download_and_untar.sh ${data_dir} ${data_url} ${part}
 16 |   done
 17 | fi
 18 |
 19 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
 20 |   echo "Data preparation, prepare wav.scp/text/utt2spk/spk2utt"
 21 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 22 |     mkdir -p data/$x
 23 |     python local/prepare_data.py --src_dir $data_dir/LibriTTS/$x --des_dir data/$x
 24 |   done
 25 | fi
 26 |
 27 | if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
 28 |   echo "Prepare negative samples using CosyVoice2-0.5B, which is also our reference model.
 29 |   Here we use CosyVoice2-0.5B generated audio as reject samples for simplicity; you can also use metrics like WER/similarity."
 30 |   for x in train-clean-100 train-clean-360 train-other-500; do
 31 |     mkdir -p data/${x}_reject
 32 |     python local/prepare_reject_sample.py --src_dir data/$x --des_dir data/${x}_reject --ref_model $pretrained_model_dir
 33 |   done
 34 | fi
 35 |
 36 | if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
 37 |   echo "Extract campplus speaker embedding, you will get spk2embedding.pt and utt2embedding.pt in data/$x dir"
 38 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 39 |     tools/extract_embedding.py --dir data/$x \
 40 |       --onnx_path $pretrained_model_dir/campplus.onnx
 41 |   done
 42 | fi
 43 |
 44 | if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
 45 |   echo "Extract discrete speech token, you will get utt2speech_token.pt in data/$x dir"
 46 |   for x in train-clean-100 train-clean-360 train-other-500 train-clean-100_reject train-clean-360_reject dev-clean dev-other test-clean test-other; do
 47 |     tools/extract_speech_token.py --dir data/$x \
 48 |       --onnx_path $pretrained_model_dir/speech_tokenizer_v2.onnx
 49 |   done
 50 | fi
 51 |
 52 | if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
 53 |   echo "Prepare required parquet format data, you should have prepared wav.scp/text/utt2spk/spk2utt/utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt"
 54 |   for x in train-clean-100 train-clean-360 train-other-500 dev-clean dev-other test-clean test-other; do
 55 |     mkdir -p data/$x/parquet
 56 |     tools/make_parquet_list.py --num_utts_per_parquet 1000 \
 57 |       --num_processes 10 \
 58 |       --dpo \
 59 |       --src_dir data/$x \
 60 |       --des_dir data/$x/parquet
 61 |   done
 62 | fi
 63 |
 64 | # train llm
 65 | export CUDA_VISIBLE_DEVICES="0,1,2,3"
 66 | num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
 67 | job_id=1986
 68 | dist_backend="nccl"
 69 | num_workers=2
 70 | prefetch=100
 71 | train_engine=torch_ddp
 72 | if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 73 |   echo "Run train. We only support llm training for now"
 74 |   if [ $train_engine == 'deepspeed' ]; then
 75 |     echo "Notice deepspeed has its own optimizer config. Modify conf/ds_stage2.json if necessary"
 76 |   fi
 77 |   cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
 78 |   cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
 79 |   # NOTE only llm supports dpo
 80 |   for model in llm; do
 81 |     torchrun --nnodes=1 --nproc_per_node=$num_gpus \
 82 |       --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:1234" \
 83 |       cosyvoice/bin/train.py \
 84 |       --train_engine $train_engine \
 85 |       --config conf/cosyvoice2.yaml \
 86 |       --train_data data/train.data.list \
 87 |       --cv_data data/dev.data.list \
 88 |       --qwen_pretrain_path $pretrained_model_dir/CosyVoice-BlankEN \
 89 |       --model $model \
 90 |       --checkpoint $pretrained_model_dir/$model.pt \
 91 |       --ref_model $pretrained_model_dir/llm.pt \
 92 |       --model_dir `pwd`/exp/cosyvoice2/$model/$train_engine \
 93 |       --tensorboard_dir `pwd`/tensorboard/cosyvoice2/$model/$train_engine \
 94 |       --ddp.dist_backend $dist_backend \
 95 |       --num_workers ${num_workers} \
 96 |       --prefetch ${prefetch} \
 97 |       --pin_memory \
 98 |       --use_amp \
 99 |       --dpo \
100 |       --deepspeed_config ./conf/ds_stage2.json \
101 |       --deepspeed.save_states model+optimizer
102 |   done
103 | fi
104 |
105 | # average model
106 | average_num=5
107 | if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
108 |   for model in llm flow hifigan; do
109 |     decode_checkpoint=`pwd`/exp/cosyvoice/$model/$train_engine/${model}.pt
110 |     echo "do model average and final checkpoint is $decode_checkpoint"
111 |     python cosyvoice/bin/average_model.py \
112 |       --dst_model $decode_checkpoint \
113 |       --src_path `pwd`/exp/cosyvoice/$model/$train_engine \
114 |       --num ${average_num} \
115 |       --val_best
116 |   done
117 | fi
118 |
119 | if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
120 |   echo "Export your model for inference speedup. Remember to copy your llm or flow model to model_dir"
121 |   python cosyvoice/bin/export_jit.py --model_dir $pretrained_model_dir
122 |   python cosyvoice/bin/export_onnx.py --model_dir $pretrained_model_dir
123 | fi
--------------------------------------------------------------------------------
/tools/make_parquet_list.py:
--------------------------------------------------------------------------------
  1 | #!/usr/bin/env python3
  2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
  3 | #
  4 | # Licensed under the Apache License, Version 2.0 (the "License");
  5 | # you may not use this file except in compliance with the License.
  6 | # You may obtain a copy of the License at
  7 | #
  8 | #   http://www.apache.org/licenses/LICENSE-2.0
  9 | #
 10 | # Unless required by applicable law or agreed to in writing, software
 11 | # distributed under the License is distributed on an "AS IS" BASIS,
 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 13 | # See the License for the specific language governing permissions and
 14 | # limitations under the License.
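# This tool packs prepared kaldi-style inputs (wav.scp/text/utt2spk plus the
# utt2embedding.pt/spk2embedding.pt/utt2speech_token.pt side files) into
# parquet shards of --num_utts_per_parquet utterances each, then writes the
# data.list/utt2data.list/spk2data.list indexes under --des_dir.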
 15 | import argparse
 16 | import logging
 17 | import os
 18 | import json
 19 | from tqdm import tqdm
 20 | import pandas as pd
 21 | import multiprocessing
 22 | import time
 23 | import torch
 24 |
 25 |
 26 | def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
 27 |     start_time = time.time()
 28 |     data_list = []
 29 |     for utt in tqdm(utt_list):
 30 |         data = open(utt2wav[utt], 'rb').read()
 31 |         data_list.append(data)
 32 |     wav_list = [utt2wav[utt] for utt in utt_list]
 33 |     text_list = [utt2text[utt] for utt in utt_list]
 34 |     spk_list = [utt2spk[utt] for utt in utt_list]
 35 |     uttembedding_list = [utt2embedding[utt] for utt in utt_list]
 36 |     spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list]
 37 |     speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list]
 38 |     if args.dpo:
 39 |         reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list]
 40 |
 41 |     # save the parquet shard, then the utt2parquet_file/spk2parquet_file json mappings
 42 |     df = pd.DataFrame()
 43 |     df['utt'] = utt_list
 44 |     df['wav'] = wav_list
 45 |     df['audio_data'] = data_list
 46 |     df['text'] = text_list
 47 |     df['spk'] = spk_list
 48 |     df['utt_embedding'] = uttembedding_list
 49 |     df['spk_embedding'] = spkembedding_list
 50 |     df['speech_token'] = speech_token_list
 51 |     if args.dpo:
 52 |         df['reject_speech_token'] = reject_speech_token_list
 53 |     df.to_parquet(parquet_file)
 54 |     with open(utt2parquet_file, 'w') as f:
 55 |         json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2)
 56 |     with open(spk2parquet_file, 'w') as f:
 57 |         json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
 58 |     logging.info('spend time {}'.format(time.time() - start_time))
 59 |
 60 |
 61 | if __name__ == "__main__":
 62 |     parser = argparse.ArgumentParser()
 63 |     parser.add_argument('--num_utts_per_parquet',
 64 |                         type=int,
 65 |                         default=1000,
 66 |                         help='num utts per parquet')
 67 |     parser.add_argument('--num_processes',
 68 |                         type=int,
 69 |                         default=1,
 70 |                         help='num processes for make parquets')
 71 |     parser.add_argument('--src_dir',
 72 |                         type=str)
 73 |     parser.add_argument('--des_dir',
 74 |                         type=str)
 75 |     parser.add_argument('--dpo',
 76 |                         action='store_true',
 77 |                         default=False,
 78 |                         help='Use Direct Preference Optimization')
 79 |     args = parser.parse_args()
 80 |
 81 |     utt2wav, utt2text, utt2spk = {}, {}, {}
 82 |     with open('{}/wav.scp'.format(args.src_dir)) as f:
 83 |         for l in f:
 84 |             l = l.replace('\n', '').split()
 85 |             utt2wav[l[0]] = l[1]
 86 |     with open('{}/text'.format(args.src_dir)) as f:
 87 |         for l in f:
 88 |             l = l.replace('\n', '').split()
 89 |             utt2text[l[0]] = ' '.join(l[1:])
 90 |     with open('{}/utt2spk'.format(args.src_dir)) as f:
 91 |         for l in f:
 92 |             l = l.replace('\n', '').split()
 93 |             utt2spk[l[0]] = l[1]
 94 |     utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir))
 95 |     spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir))
 96 |     utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir))
 97 |     if args.dpo:
 98 |         utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir))
 99 |     utts = list(utt2wav.keys())
100 |
101 |     # use a process pool to speed up packing
102 |     pool = multiprocessing.Pool(processes=args.num_processes)
103 |     parquet_list, utt2parquet_list, spk2parquet_list = [], [], []
104 |     for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)):
105 |         parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i))
106 |         utt2parquet_file = os.path.join(args.des_dir,
'utt2parquet_{:09d}.json'.format(i)) 107 | spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i)) 108 | parquet_list.append(parquet_file) 109 | utt2parquet_list.append(utt2parquet_file) 110 | spk2parquet_list.append(spk2parquet_file) 111 | pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file)) 112 | pool.close() 113 | pool.join() 114 | 115 | with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \ 116 | open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \ 117 | open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3: 118 | for name in parquet_list: 119 | f1.write(name + '\n') 120 | for name in utt2parquet_list: 121 | f2.write(name + '\n') 122 | for name in spk2parquet_list: 123 | f3.write(name + '\n') 124 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/scripts/test_llm.py: -------------------------------------------------------------------------------- 1 | 2 | # SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
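# Sanity check for a built TensorRT-LLM engine of the CosyVoice2 LLM: each
# input text is wrapped in the "<|sos|>{input_text}<|task_id|>" prompt
# template, decoded until the "<|eos1|>" token, and the decoded input and
# output text are printed per request.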
16 | 17 | import argparse 18 | import ast 19 | import csv 20 | import os 21 | from pathlib import Path 22 | from typing import List, Optional 23 | 24 | import numpy as np 25 | import torch 26 | 27 | import tensorrt_llm 28 | from tensorrt_llm.logger import logger 29 | 30 | from tensorrt_llm.runtime import ModelRunnerCpp 31 | from transformers import AutoTokenizer 32 | 33 | 34 | def parse_arguments(args=None): 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument( 37 | '--input_text', 38 | type=str, 39 | nargs='+', 40 | default=["Born in north-east France, Soyer trained as a"]) 41 | parser.add_argument('--tokenizer_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") 42 | parser.add_argument('--engine_dir', type=str, default="meta-llama/Meta-Llama-3-8B-Instruct") 43 | parser.add_argument('--log_level', type=str, default="debug") 44 | parser.add_argument('--kv_cache_free_gpu_memory_fraction', type=float, default=0.6) 45 | parser.add_argument('--temperature', type=float, default=0.8) 46 | parser.add_argument('--top_k', type=int, default=50) 47 | parser.add_argument('--top_p', type=float, default=0.95) 48 | 49 | return parser.parse_args(args=args) 50 | 51 | 52 | def parse_input(tokenizer, 53 | input_text=None, 54 | prompt_template=None): 55 | batch_input_ids = [] 56 | for curr_text in input_text: 57 | if prompt_template is not None: 58 | curr_text = prompt_template.format(input_text=curr_text) 59 | input_ids = tokenizer.encode( 60 | curr_text) 61 | batch_input_ids.append(input_ids) 62 | 63 | batch_input_ids = [ 64 | torch.tensor(x, dtype=torch.int32) for x in batch_input_ids 65 | ] 66 | 67 | logger.debug(f"Input token ids (batch_size = {len(batch_input_ids)}):") 68 | for i, input_ids in enumerate(batch_input_ids): 69 | logger.debug(f"Request {i}: {input_ids.tolist()}") 70 | 71 | return batch_input_ids 72 | 73 | 74 | def main(args): 75 | runtime_rank = tensorrt_llm.mpi_rank() 76 | logger.set_level(args.log_level) 77 | 78 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_dir) 79 | prompt_template = "<|sos|>{input_text}<|task_id|>" 80 | end_id = tokenizer.convert_tokens_to_ids("<|eos1|>") 81 | 82 | batch_input_ids = parse_input(tokenizer=tokenizer, 83 | input_text=args.input_text, 84 | prompt_template=prompt_template) 85 | 86 | input_lengths = [x.size(0) for x in batch_input_ids] 87 | 88 | runner_kwargs = dict( 89 | engine_dir=args.engine_dir, 90 | rank=runtime_rank, 91 | max_output_len=1024, 92 | enable_context_fmha_fp32_acc=False, 93 | max_batch_size=len(batch_input_ids), 94 | max_input_len=max(input_lengths), 95 | kv_cache_free_gpu_memory_fraction=args.kv_cache_free_gpu_memory_fraction, 96 | cuda_graph_mode=False, 97 | gather_generation_logits=False, 98 | ) 99 | 100 | runner = ModelRunnerCpp.from_dir(**runner_kwargs) 101 | 102 | with torch.no_grad(): 103 | outputs = runner.generate( 104 | batch_input_ids=batch_input_ids, 105 | max_new_tokens=1024, 106 | end_id=end_id, 107 | pad_id=end_id, 108 | temperature=args.temperature, 109 | top_k=args.top_k, 110 | top_p=args.top_p, 111 | num_return_sequences=1, 112 | repetition_penalty=1.1, 113 | random_seed=42, 114 | streaming=False, 115 | output_sequence_lengths=True, 116 | output_generation_logits=False, 117 | return_dict=True, 118 | return_all_generated_tokens=False) 119 | torch.cuda.synchronize() 120 | output_ids, sequence_lengths = outputs["output_ids"], outputs["sequence_lengths"] 121 | num_output_sents, num_beams, _ = output_ids.size() 122 | assert num_beams == 1 123 | beam = 0 124 | batch_size = len(input_lengths) 
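    # output_ids is flattened over (batch, num_return_sequences); row i maps
    # back to batch index i // num_return_sequences and sequence index
    # i % num_return_sequences, with beam 0 as the only beam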
125 | num_return_sequences = num_output_sents // batch_size 126 | assert num_return_sequences == 1 127 | for i in range(batch_size * num_return_sequences): 128 | batch_idx = i // num_return_sequences 129 | seq_idx = i % num_return_sequences 130 | inputs = output_ids[i][0][:input_lengths[batch_idx]].tolist() 131 | input_text = tokenizer.decode(inputs) 132 | print(f'Input [Text {batch_idx}]: \"{input_text}\"') 133 | output_begin = input_lengths[batch_idx] 134 | output_end = sequence_lengths[i][beam] 135 | outputs = output_ids[i][beam][output_begin:output_end].tolist() 136 | output_text = tokenizer.decode(outputs) 137 | print(f'Output [Text {batch_idx}]: \"{output_text}\"') 138 | logger.debug(str(outputs)) 139 | 140 | 141 | if __name__ == '__main__': 142 | args = parse_arguments() 143 | main(args) 144 | -------------------------------------------------------------------------------- /cosyvoice/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ConvolutionModule(nn.Module): 25 | """ConvolutionModule in Conformer model.""" 26 | 27 | def __init__(self, 28 | channels: int, 29 | kernel_size: int = 15, 30 | activation: nn.Module = nn.ReLU(), 31 | norm: str = "batch_norm", 32 | causal: bool = False, 33 | bias: bool = True): 34 | """Construct a ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (bool): Whether to use causal convolution or not 39 | """ 40 | super().__init__() 41 | 42 | self.pointwise_conv1 = nn.Conv1d( 43 | channels, 44 | 2 * channels, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=bias, 49 | ) 50 | # self.lorder is used to distinguish if it's a causal convolution, 51 | # if self.lorder > 0: it's a causal convolution, the input will be 52 | # padded with self.lorder frames on the left in forward.
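# (e.g. kernel_size=15 with causal=True gives lorder=14: the first chunk is left-padded
# with 14 zero frames, and each forward() hands back its last 14 frames as new_cache
# to be prepended to the next streaming chunk.)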
53 | # else: it's a symmetrical convolution 54 | if causal: 55 | padding = 0 56 | self.lorder = kernel_size - 1 57 | else: 58 | # kernel_size should be an odd number for non-causal convolution 59 | assert (kernel_size - 1) % 2 == 0 60 | padding = (kernel_size - 1) // 2 61 | self.lorder = 0 62 | self.depthwise_conv = nn.Conv1d( 63 | channels, 64 | channels, 65 | kernel_size, 66 | stride=1, 67 | padding=padding, 68 | groups=channels, 69 | bias=bias, 70 | ) 71 | 72 | assert norm in ['batch_norm', 'layer_norm'] 73 | if norm == "batch_norm": 74 | self.use_layer_norm = False 75 | self.norm = nn.BatchNorm1d(channels) 76 | else: 77 | self.use_layer_norm = True 78 | self.norm = nn.LayerNorm(channels) 79 | 80 | self.pointwise_conv2 = nn.Conv1d( 81 | channels, 82 | channels, 83 | kernel_size=1, 84 | stride=1, 85 | padding=0, 86 | bias=bias, 87 | ) 88 | self.activation = activation 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 94 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | """Compute convolution module. 97 | Args: 98 | x (torch.Tensor): Input tensor (#batch, time, channels). 99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 100 | (0, 0, 0) means fake mask. 101 | cache (torch.Tensor): left context cache, it is only 102 | used in causal convolution (#batch, channels, cache_t), 103 | (0, 0, 0) means fake cache. 104 | Returns: 105 | torch.Tensor: Output tensor (#batch, time, channels). 106 | """ 107 | # exchange the temporal dimension and the feature dimension 108 | x = x.transpose(1, 2) # (#batch, channels, time) 109 | 110 | # mask batch padding 111 | if mask_pad.size(2) > 0: # time > 0 112 | x.masked_fill_(~mask_pad, 0.0) 113 | 114 | if self.lorder > 0: 115 | if cache.size(2) == 0: # cache_t == 0 116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 117 | else: 118 | assert cache.size(0) == x.size(0) # equal batch 119 | assert cache.size(1) == x.size(1) # equal channel 120 | x = torch.cat((cache, x), dim=2) 121 | assert (x.size(2) > self.lorder) 122 | new_cache = x[:, :, -self.lorder:] 123 | else: 124 | # Ideally we would just return None if no cache is required; 125 | # however, for JIT export, here we just fake one tensor instead of 126 | # None. 127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 128 | 129 | # GLU mechanism 130 | x = self.pointwise_conv1(x) # (batch, 2*channel, time) 131 | x = nn.functional.glu(x, dim=1) # (batch, channel, time) 132 | 133 | # 1D Depthwise Conv 134 | x = self.depthwise_conv(x) 135 | if self.use_layer_norm: 136 | x = x.transpose(1, 2) 137 | x = self.activation(self.norm(x)) 138 | if self.use_layer_norm: 139 | x = x.transpose(1, 2) 140 | x = self.pointwise_conv2(x) 141 | # mask batch padding 142 | if mask_pad.size(2) > 0: # time > 0 143 | x.masked_fill_(~mask_pad, 0.0) 144 | 145 | return x.transpose(1, 2), new_cache 146 | -------------------------------------------------------------------------------- /cosyvoice/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu) 3 | # 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import os 18 | import json 19 | import torch 20 | import torchaudio 21 | import logging 22 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 23 | logging.basicConfig(level=logging.DEBUG, 24 | format='%(asctime)s %(levelname)s %(message)s') 25 | 26 | 27 | def read_lists(list_file): 28 | lists = [] 29 | with open(list_file, 'r', encoding='utf8') as fin: 30 | for line in fin: 31 | lists.append(line.strip()) 32 | return lists 33 | 34 | 35 | def read_json_lists(list_file): 36 | lists = read_lists(list_file) 37 | results = {} 38 | for fn in lists: 39 | with open(fn, 'r', encoding='utf8') as fin: 40 | results.update(json.load(fin)) 41 | return results 42 | 43 | 44 | def load_wav(wav, target_sr): 45 | speech, sample_rate = torchaudio.load(wav, backend='soundfile') 46 | speech = speech.mean(dim=0, keepdim=True) 47 | if sample_rate != target_sr: 48 | assert sample_rate > target_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr) 49 | speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech) 50 | return speech 51 | 52 | 53 | def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16): 54 | import tensorrt as trt 55 | logging.info("Converting onnx to trt...") 56 | network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) 57 | logger = trt.Logger(trt.Logger.INFO) 58 | builder = trt.Builder(logger) 59 | network = builder.create_network(network_flags) 60 | parser = trt.OnnxParser(network, logger) 61 | config = builder.create_builder_config() 62 | config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB 63 | if fp16: 64 | config.set_flag(trt.BuilderFlag.FP16) 65 | profile = builder.create_optimization_profile() 66 | # load onnx model 67 | with open(onnx_model, "rb") as f: 68 | if not parser.parse(f.read()): 69 | for error in range(parser.num_errors): 70 | print(parser.get_error(error)) 71 | raise ValueError('failed to parse {}'.format(onnx_model)) 72 | # set input shapes 73 | for i in range(len(trt_kwargs['input_names'])): 74 | profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i]) 75 | tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT 76 | # set input and output data type 77 | for i in range(network.num_inputs): 78 | input_tensor = network.get_input(i) 79 | input_tensor.dtype = tensor_dtype 80 | for i in range(network.num_outputs): 81 | output_tensor = network.get_output(i) 82 | output_tensor.dtype = tensor_dtype 83 | config.add_optimization_profile(profile) 84 | engine_bytes = builder.build_serialized_network(network, config) 85 | # save trt engine 86 | with open(trt_model, "wb") as f: 87 | f.write(engine_bytes) 88 | logging.info("Successfully converted onnx to trt") 89 | 90 | 91 | def export_cosyvoice2_vllm(model, model_path, device): 92 | if os.path.exists(model_path): 93 | return 94 | pad_to = DEFAULT_VOCAB_PADDING_SIZE = 64 95 | vocab_size = model.speech_embedding.num_embeddings 96 | feature_size = model.speech_embedding.embedding_dim 97 |
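# Round the speech-token vocab up to the next multiple of pad_to, mirroring vLLM's
# DEFAULT_VOCAB_PADDING_SIZE convention; e.g. a hypothetical vocab_size of 6564
# would pad up to pad_vocab_size = 6592.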
pad_vocab_size = ((vocab_size + pad_to - 1) // pad_to) * pad_to 98 | 99 | dtype = torch.bfloat16 100 | # lm_head 101 | new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size, bias=True) 102 | with torch.no_grad(): 103 | new_lm_head.weight[:vocab_size] = model.llm_decoder.weight 104 | new_lm_head.bias[:vocab_size] = model.llm_decoder.bias 105 | new_lm_head.weight[vocab_size:] = 0 106 | new_lm_head.bias[vocab_size:] = 0 107 | model.llm.model.lm_head = new_lm_head 108 | new_codec_embed = torch.nn.Linear(in_features=feature_size, out_features=pad_vocab_size) 109 | # embed_tokens 110 | embed_tokens = model.llm.model.model.embed_tokens 111 | with torch.no_grad(): 112 | new_codec_embed.weight[:vocab_size] = model.speech_embedding.weight 113 | new_codec_embed.weight[vocab_size:] = 0 114 | model.llm.model.set_input_embeddings(new_codec_embed) 115 | model.llm.model.to(device) 116 | model.llm.model.to(dtype) 117 | tmp_vocab_size = model.llm.model.config.vocab_size 118 | tmp_tie_embedding = model.llm.model.config.tie_word_embeddings 119 | del model.llm.model.generation_config.eos_token_id 120 | del model.llm.model.config.bos_token_id 121 | del model.llm.model.config.eos_token_id 122 | model.llm.model.config.vocab_size = pad_vocab_size 123 | model.llm.model.config.tie_word_embeddings = False 124 | model.llm.model.config.use_bias = True 125 | model.llm.model.save_pretrained(model_path) 126 | os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path))) 127 | model.llm.model.config.vocab_size = tmp_vocab_size 128 | model.llm.model.config.tie_word_embeddings = tmp_tie_embedding 129 | model.llm.model.set_input_embeddings(embed_tokens) 130 | -------------------------------------------------------------------------------- /runtime/triton_trtllm/client_http.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # Redistribution and use in source and binary forms, with or without 4 | # modification, are permitted provided that the following conditions 5 | # are met: 6 | # * Redistributions of source code must retain the above copyright 7 | # notice, this list of conditions and the following disclaimer. 8 | # * Redistributions in binary form must reproduce the above copyright 9 | # notice, this list of conditions and the following disclaimer in the 10 | # documentation and/or other materials provided with the distribution. 11 | # * Neither the name of NVIDIA CORPORATION nor the names of its 12 | # contributors may be used to endorse or promote products derived 13 | # from this software without specific prior written permission. 14 | # 15 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY 16 | # EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 17 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 18 | # PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR 19 | # CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 20 | # EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 21 | # PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR 22 | # PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY 23 | # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 24 | # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 26 | import requests 27 | import soundfile as sf 28 | import json 29 | import numpy as np 30 | import argparse 31 | 32 | 33 | def get_args(): 34 | parser = argparse.ArgumentParser( 35 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 36 | ) 37 | 38 | parser.add_argument( 39 | "--server-url", 40 | type=str, 41 | default="localhost:8000", 42 | help="Address of the server", 43 | ) 44 | 45 | parser.add_argument( 46 | "--reference-audio", 47 | type=str, 48 | default="../../example/prompt_audio.wav", 49 | help="Path to a single audio file. It can't be specified together with --manifest-dir", 50 | ) 51 | 52 | parser.add_argument( 53 | "--reference-text", 54 | type=str, 55 | default="吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。", 56 | help="", 57 | ) 58 | 59 | parser.add_argument( 60 | "--target-text", 61 | type=str, 62 | default="身临其境,换新体验。塑造开源语音合成新范式,让智能语音更自然。", 63 | help="", 64 | ) 65 | 66 | parser.add_argument( 67 | "--model-name", 68 | type=str, 69 | default="spark_tts", 70 | choices=[ 71 | "f5_tts", 72 | "spark_tts", 73 | "cosyvoice2"], 74 | help="triton model_repo module name to request", 75 | ) 76 | 77 | parser.add_argument( 78 | "--output-audio", 79 | type=str, 80 | default="output.wav", 81 | help="Path to save the output audio", 82 | ) 83 | return parser.parse_args() 84 | 85 | 86 | def prepare_request( 87 | waveform, 88 | reference_text, 89 | target_text, 90 | sample_rate=16000, 91 | padding_duration: int = None, 92 | audio_save_dir: str = "./", 93 | ): 94 | assert len(waveform.shape) == 1, "waveform should be 1D" 95 | lengths = np.array([[len(waveform)]], dtype=np.int32) 96 | if padding_duration: 97 | # padding to the nearest multiple of padding_duration seconds 98 | samples = np.zeros( 99 | ( 100 | 1, 101 | padding_duration 102 | * sample_rate 103 | * ((int(len(waveform) / sample_rate) // padding_duration) + 1), 104 | ), 105 | dtype=np.float32, 106 | ) 107 | 108 | samples[0, : len(waveform)] = waveform 109 | else: 110 | samples = waveform 111 | 112 | samples = samples.reshape(1, -1).astype(np.float32) 113 | 114 | data = { 115 | "inputs": [ 116 | { 117 | "name": "reference_wav", 118 | "shape": samples.shape, 119 | "datatype": "FP32", 120 | "data": samples.tolist() 121 | }, 122 | { 123 | "name": "reference_wav_len", 124 | "shape": lengths.shape, 125 | "datatype": "INT32", 126 | "data": lengths.tolist(), 127 | }, 128 | { 129 | "name": "reference_text", 130 | "shape": [1, 1], 131 | "datatype": "BYTES", 132 | "data": [reference_text] 133 | }, 134 | { 135 | "name": "target_text", 136 | "shape": [1, 1], 137 | "datatype": "BYTES", 138 | "data": [target_text] 139 | } 140 | ] 141 | } 142 | 143 | return data 144 | 145 | 146 | if __name__ == "__main__": 147 | args = get_args() 148 | server_url = args.server_url 149 | if not server_url.startswith(("http://", "https://")): 150 | server_url = f"http://{server_url}" 151 | 152 | url = f"{server_url}/v2/models/{args.model_name}/infer" 153 | waveform, sr = sf.read(args.reference_audio) 154 | assert sr
== 16000, "sample rate hardcoded in server" 155 | 156 | samples = np.array(waveform, dtype=np.float32) 157 | data = prepare_request(samples, args.reference_text, args.target_text) 158 | 159 | rsp = requests.post( 160 | url, 161 | headers={"Content-Type": "application/json"}, 162 | json=data, 163 | verify=False, 164 | params={"request_id": '0'} 165 | ) 166 | result = rsp.json() 167 | audio = result["outputs"][0]["data"] 168 | audio = np.array(audio, dtype=np.float32) 169 | if args.model_name == "spark_tts": 170 | sample_rate = 16000 171 | else: 172 | sample_rate = 24000 173 | sf.write(args.output_audio, audio, sample_rate, "PCM_16") 174 | -------------------------------------------------------------------------------- /cosyvoice/bin/inference_deprecated.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import torch 22 | from torch.utils.data import DataLoader 23 | import torchaudio 24 | from hyperpyyaml import load_hyperpyyaml 25 | from tqdm import tqdm 26 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model 27 | from cosyvoice.dataset.dataset import Dataset 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='inference with your model') 32 | parser.add_argument('--config', required=True, help='config file') 33 | parser.add_argument('--prompt_data', required=True, help='prompt data file') 34 | parser.add_argument('--prompt_utt2data', required=True, help='prompt data file') 35 | parser.add_argument('--tts_text', required=True, help='tts input file') 36 | parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path') 37 | parser.add_argument('--llm_model', required=True, help='llm model file') 38 | parser.add_argument('--flow_model', required=True, help='flow model file') 39 | parser.add_argument('--hifigan_model', required=True, help='hifigan model file') 40 | parser.add_argument('--gpu', 41 | type=int, 42 | default=-1, 43 | help='gpu id for this rank, -1 for cpu') 44 | parser.add_argument('--mode', 45 | default='sft', 46 | choices=['sft', 'zero_shot'], 47 | help='inference mode') 48 | parser.add_argument('--result_dir', required=True, help='asr result file') 49 | args = parser.parse_args() 50 | print(args) 51 | return args 52 | 53 | 54 | def main(): 55 | args = get_args() 56 | logging.basicConfig(level=logging.DEBUG, 57 | format='%(asctime)s %(levelname)s %(message)s') 58 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) 59 | 60 | # Init cosyvoice models from configs 61 | use_cuda = args.gpu >= 0 and torch.cuda.is_available() 62 | device = torch.device('cuda' if use_cuda else 'cpu') 63 | try: 64 | with open(args.config, 'r') as f: 65 | configs = load_hyperpyyaml(f, 
overrides={'qwen_pretrain_path': args.qwen_pretrain_path}) 66 | model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift']) 67 | except Exception: 68 | try: 69 | with open(args.config, 'r') as f: 70 | configs = load_hyperpyyaml(f) 71 | model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift']) 72 | except Exception: 73 | raise TypeError('no valid model_type!') 74 | 75 | model.load(args.llm_model, args.flow_model, args.hifigan_model) 76 | 77 | test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, 78 | tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data) 79 | test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0) 80 | 81 | sample_rate = configs['sample_rate'] 82 | del configs 83 | os.makedirs(args.result_dir, exist_ok=True) 84 | fn = os.path.join(args.result_dir, 'wav.scp') 85 | f = open(fn, 'w') 86 | with torch.no_grad(): 87 | for _, batch in tqdm(enumerate(test_data_loader)): 88 | utts = batch["utts"] 89 | assert len(utts) == 1, "inference mode only supports batch size 1" 90 | text_token = batch["text_token"].to(device) 91 | text_token_len = batch["text_token_len"].to(device) 92 | tts_index = batch["tts_index"] 93 | tts_text_token = batch["tts_text_token"].to(device) 94 | tts_text_token_len = batch["tts_text_token_len"].to(device) 95 | speech_token = batch["speech_token"].to(device) 96 | speech_token_len = batch["speech_token_len"].to(device) 97 | speech_feat = batch["speech_feat"].to(device) 98 | speech_feat_len = batch["speech_feat_len"].to(device) 99 | utt_embedding = batch["utt_embedding"].to(device) 100 | spk_embedding = batch["spk_embedding"].to(device) 101 | if args.mode == 'sft': 102 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 103 | 'llm_embedding': spk_embedding, 'flow_embedding': spk_embedding} 104 | else: 105 | model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 106 | 'prompt_text': text_token, 'prompt_text_len': text_token_len, 107 | 'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len, 108 | 'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len, 109 | 'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len, 110 | 'llm_embedding': utt_embedding, 'flow_embedding': utt_embedding} 111 | tts_speeches = [] 112 | for model_output in model.tts(**model_input): 113 | tts_speeches.append(model_output['tts_speech']) 114 | tts_speeches = torch.concat(tts_speeches, dim=1) 115 | tts_key = '{}_{}'.format(utts[0], tts_index[0]) 116 | tts_fn = os.path.join(args.result_dir, '{}.wav'.format(tts_key)) 117 | torchaudio.save(tts_fn, tts_speeches, sample_rate=sample_rate, backend='soundfile') 118 | f.write('{} {}\n'.format(tts_key, tts_fn)) 119 | f.flush() 120 | f.close() 121 | logging.info('Result wav.scp saved in {}'.format(fn)) 122 | 123 | 124 | if __name__ == '__main__': 125 | logging.warning('this code has been deprecated, please refer to README for CosyVoice inference usage!') 126 | main() 127 | -------------------------------------------------------------------------------- /cosyvoice/utils/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou) 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not
use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # Modified from ESPnet(https://github.com/espnet/espnet) 17 | """Utility functions for Transformer.""" 18 | 19 | import queue 20 | import random 21 | from typing import List 22 | 23 | import numpy as np 24 | import torch 25 | 26 | IGNORE_ID = -1 27 | 28 | 29 | def pad_list(xs: List[torch.Tensor], pad_value: int): 30 | """Perform padding for the list of tensors. 31 | 32 | Args: 33 | xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. 34 | pad_value (float): Value for padding. 35 | 36 | Returns: 37 | Tensor: Padded tensor (B, Tmax, `*`). 38 | 39 | Examples: 40 | >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] 41 | >>> x 42 | [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] 43 | >>> pad_list(x, 0) 44 | tensor([[1., 1., 1., 1.], 45 | [1., 1., 0., 0.], 46 | [1., 0., 0., 0.]]) 47 | 48 | """ 49 | max_len = max([len(item) for item in xs]) 50 | batchs = len(xs) 51 | ndim = xs[0].ndim 52 | if ndim == 1: 53 | pad_res = torch.zeros(batchs, 54 | max_len, 55 | dtype=xs[0].dtype, 56 | device=xs[0].device) 57 | elif ndim == 2: 58 | pad_res = torch.zeros(batchs, 59 | max_len, 60 | xs[0].shape[1], 61 | dtype=xs[0].dtype, 62 | device=xs[0].device) 63 | elif ndim == 3: 64 | pad_res = torch.zeros(batchs, 65 | max_len, 66 | xs[0].shape[1], 67 | xs[0].shape[2], 68 | dtype=xs[0].dtype, 69 | device=xs[0].device) 70 | else: 71 | raise ValueError(f"Unsupported ndim: {ndim}") 72 | pad_res.fill_(pad_value) 73 | for i in range(batchs): 74 | pad_res[i, :len(xs[i])] = xs[i] 75 | return pad_res 76 | 77 | 78 | def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor, 79 | ignore_label: int) -> torch.Tensor: 80 | """Calculate accuracy. 81 | 82 | Args: 83 | pad_outputs (Tensor): Prediction tensors (B * Lmax, D). 84 | pad_targets (LongTensor): Target label tensors (B, Lmax). 85 | ignore_label (int): Ignore label id. 86 | 87 | Returns: 88 | torch.Tensor: Accuracy value (0.0 - 1.0).
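Example (illustrative): with pad_targets [[1, 2, -1]], ignore_label -1, and predictions whose per-position argmax is [1, 0, 7], only the first two positions are scored, giving accuracy 0.5.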
89 | 90 | """ 91 | pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1), 92 | pad_outputs.size(1)).argmax(2) 93 | mask = pad_targets != ignore_label 94 | numerator = torch.sum( 95 | pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) 96 | denominator = torch.sum(mask) 97 | return (numerator / denominator).detach() 98 | 99 | 100 | def get_padding(kernel_size, dilation=1): 101 | return int((kernel_size * dilation - dilation) / 2) 102 | 103 | 104 | def init_weights(m, mean=0.0, std=0.01): 105 | classname = m.__class__.__name__ 106 | if classname.find("Conv") != -1: 107 | m.weight.data.normal_(mean, std) 108 | 109 | 110 | # Repetition Aware Sampling in VALL-E 2 111 | def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1): 112 | top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k) 113 | rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item() 114 | if rep_num >= win_size * tau_r: 115 | top_ids = random_sampling(weighted_scores, decoded_tokens, sampling) 116 | return top_ids 117 | 118 | 119 | def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25): 120 | prob, indices = [], [] 121 | cum_prob = 0.0 122 | sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True) 123 | for i in range(len(sorted_idx)): 124 | # keep adding candidates until either the top-p mass or the top-k limit is reached 125 | if cum_prob < top_p and len(prob) < top_k: 126 | cum_prob += sorted_value[i] 127 | prob.append(sorted_value[i]) 128 | indices.append(sorted_idx[i]) 129 | else: 130 | break 131 | prob = torch.tensor(prob).to(weighted_scores) 132 | indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device) 133 | top_ids = indices[prob.multinomial(1, replacement=True)] 134 | return top_ids 135 | 136 | 137 | def random_sampling(weighted_scores, decoded_tokens, sampling): 138 | top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True) 139 | return top_ids 140 | 141 | 142 | def fade_in_out(fade_in_mel, fade_out_mel, window): 143 | device = fade_in_mel.device 144 | fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu() 145 | mel_overlap_len = int(window.shape[0] / 2) 146 | if fade_in_mel.device == torch.device('cpu'): 147 | fade_in_mel = fade_in_mel.clone() 148 | fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \ 149 | fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:] 150 | return fade_in_mel.to(device) 151 | 152 | 153 | def set_all_random_seed(seed): 154 | random.seed(seed) 155 | np.random.seed(seed) 156 | torch.manual_seed(seed) 157 | torch.cuda.manual_seed_all(seed) 158 | 159 | 160 | def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: 161 | assert mask.dtype == torch.bool 162 | assert dtype in [torch.float32, torch.bfloat16, torch.float16] 163 | mask = mask.to(dtype) 164 | # attention mask bias 165 | # NOTE(Mddct): torch.finfo jit issues 166 | # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min 167 | mask = (1.0 - mask) * -1.0e+10 168 | return mask 169 | 170 | 171 | class TrtContextWrapper: 172 | def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'): 173 | self.trt_context_pool = queue.Queue(maxsize=trt_concurrent) 174 | self.trt_engine = trt_engine 175 | for _ in range(trt_concurrent): 176 | trt_context = trt_engine.create_execution_context() 177 | trt_stream = torch.cuda.stream(torch.cuda.Stream(device)) 178 | assert trt_context is not None,
'failed to create trt context, maybe not enough CUDA memory, try reducing current trt concurrent {}'.format(trt_concurrent) 179 | self.trt_context_pool.put([trt_context, trt_stream]) 180 | assert self.trt_context_pool.empty() is False, 'no available estimator context' 181 | 182 | def acquire_estimator(self): 183 | return self.trt_context_pool.get(), self.trt_engine 184 | 185 | def release_estimator(self, context, stream): 186 | self.trt_context_pool.put([context, stream]) 187 | -------------------------------------------------------------------------------- /examples/libritts/cosyvoice2/conf/cosyvoice2.yaml: -------------------------------------------------------------------------------- 1 | # set random seed, so that you may reproduce your result. 2 | __set_seed1: !apply:random.seed [1986] 3 | __set_seed2: !apply:numpy.random.seed [1986] 4 | __set_seed3: !apply:torch.manual_seed [1986] 5 | __set_seed4: !apply:torch.cuda.manual_seed_all [1986] 6 | 7 | # fixed params 8 | sample_rate: 24000 9 | llm_input_size: 896 10 | llm_output_size: 896 11 | spk_embed_dim: 192 12 | qwen_pretrain_path: '' 13 | token_frame_rate: 25 14 | token_mel_ratio: 2 15 | 16 | # stream related params 17 | chunk_size: 25 # streaming inference chunk size, in token 18 | num_decoding_left_chunks: -1 # streaming inference flow decoder left chunk size, <0 means use all left chunks 19 | 20 | # model params 21 | # for all class/function included in this repo, we use !<name> or !<new> for initialization, so that users may find all corresponding class/function according to one single yaml. 22 | # for system/third_party class/function, we do not require this. 23 | llm: !new:cosyvoice.llm.llm.Qwen2LM 24 | llm_input_size: !ref <llm_input_size> 25 | llm_output_size: !ref <llm_output_size> 26 | speech_token_size: 6561 27 | length_normalized_loss: True 28 | lsm_weight: 0 29 | mix_ratio: [5, 15] 30 | llm: !new:cosyvoice.llm.llm.Qwen2Encoder 31 | pretrain_path: !ref <qwen_pretrain_path> 32 | sampling: !name:cosyvoice.utils.common.ras_sampling 33 | top_p: 0.8 34 | top_k: 25 35 | win_size: 10 36 | tau_r: 0.1 37 | 38 | flow: !new:cosyvoice.flow.flow.CausalMaskedDiffWithXvec 39 | input_size: 512 40 | output_size: 80 41 | spk_embed_dim: !ref <spk_embed_dim> 42 | output_type: 'mel' 43 | vocab_size: 6561 44 | input_frame_rate: !ref <token_frame_rate> 45 | only_mask_loss: True 46 | token_mel_ratio: !ref <token_mel_ratio> 47 | pre_lookahead_len: 3 48 | encoder: !new:cosyvoice.transformer.upsample_encoder.UpsampleConformerEncoder 49 | output_size: 512 50 | attention_heads: 8 51 | linear_units: 2048 52 | num_blocks: 6 53 | dropout_rate: 0.1 54 | positional_dropout_rate: 0.1 55 | attention_dropout_rate: 0.1 56 | normalize_before: True 57 | input_layer: 'linear' 58 | pos_enc_layer_type: 'rel_pos_espnet' 59 | selfattention_layer_type: 'rel_selfattn' 60 | input_size: 512 61 | use_cnn_module: False 62 | macaron_style: False 63 | static_chunk_size: !ref <chunk_size> 64 | decoder: !new:cosyvoice.flow.flow_matching.CausalConditionalCFM 65 | in_channels: 240 66 | n_spks: 1 67 | spk_emb_dim: 80 68 | cfm_params: !new:omegaconf.DictConfig 69 | content: 70 | sigma_min: 1e-06 71 | solver: 'euler' 72 | t_scheduler: 'cosine' 73 | training_cfg_rate: 0.2 74 | inference_cfg_rate: 0.7 75 | reg_loss_type: 'l1' 76 | estimator: !new:cosyvoice.flow.decoder.CausalConditionalDecoder 77 | in_channels: 320 78 | out_channels: 80 79 | channels: [256] 80 | dropout: 0.0 81 | attention_head_dim: 64 82 | n_blocks: 4 83 | num_mid_blocks: 12 84 | num_heads: 8 85 | act_fn: 'gelu' 86 | static_chunk_size: !ref <chunk_size> * <token_mel_ratio> 87 | num_decoding_left_chunks: !ref <num_decoding_left_chunks> 88 | 89 | hift: !new:cosyvoice.hifigan.generator.HiFTGenerator 90 |
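# NSF + iSTFT vocoder: turns the flow-predicted 80-bin mel back into 24 kHz waveform (see sampling_rate and istft_params below)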
in_channels: 80 91 | base_channels: 512 92 | nb_harmonics: 8 93 | sampling_rate: !ref <sample_rate> 94 | nsf_alpha: 0.1 95 | nsf_sigma: 0.003 96 | nsf_voiced_threshold: 10 97 | upsample_rates: [8, 5, 3] 98 | upsample_kernel_sizes: [16, 11, 7] 99 | istft_params: 100 | n_fft: 16 101 | hop_len: 4 102 | resblock_kernel_sizes: [3, 7, 11] 103 | resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 104 | source_resblock_kernel_sizes: [7, 7, 11] 105 | source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]] 106 | lrelu_slope: 0.1 107 | audio_limit: 0.99 108 | f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor 109 | num_class: 1 110 | in_channels: 80 111 | cond_channels: 512 112 | 113 | # gan related module 114 | mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram 115 | n_fft: 1920 116 | num_mels: 80 117 | sampling_rate: !ref <sample_rate> 118 | hop_size: 480 119 | win_size: 1920 120 | fmin: 0 121 | fmax: null 122 | center: False 123 | hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan 124 | generator: !ref <hift> 125 | discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator 126 | mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator 127 | mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator 128 | mel_spec_transform: [ 129 | !ref <mel_spec_transform1> 130 | ] 131 | 132 | # processor functions 133 | parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener 134 | get_tokenizer: !name:cosyvoice.tokenizer.tokenizer.get_qwen_tokenizer 135 | token_path: !ref <qwen_pretrain_path> 136 | skip_special_tokens: True 137 | allowed_special: 'all' 138 | tokenize: !name:cosyvoice.dataset.processor.tokenize 139 | get_tokenizer: !ref <get_tokenizer> 140 | allowed_special: !ref <allowed_special> 141 | filter: !name:cosyvoice.dataset.processor.filter 142 | max_length: 40960 143 | min_length: 100 144 | token_max_length: 200 145 | token_min_length: 1 146 | resample: !name:cosyvoice.dataset.processor.resample 147 | resample_rate: !ref <sample_rate> 148 | truncate: !name:cosyvoice.dataset.processor.truncate 149 | truncate_length: 24480 # must be a multiple of hop_size 150 | feat_extractor: !name:matcha.utils.audio.mel_spectrogram 151 | n_fft: 1920 152 | num_mels: 80 153 | sampling_rate: !ref <sample_rate> 154 | hop_size: 480 155 | win_size: 1920 156 | fmin: 0 157 | fmax: 8000 158 | center: False 159 | compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank 160 | feat_extractor: !ref <feat_extractor> 161 | token_mel_ratio: 2 162 | compute_f0: !name:cosyvoice.dataset.processor.compute_f0 163 | sample_rate: !ref <sample_rate> 164 | hop_size: 480 165 | parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding 166 | normalize: True 167 | shuffle: !name:cosyvoice.dataset.processor.shuffle 168 | shuffle_size: 1000 169 | sort: !name:cosyvoice.dataset.processor.sort 170 | sort_size: 500 # sort_size should be less than shuffle_size 171 | batch: !name:cosyvoice.dataset.processor.batch 172 | batch_type: 'dynamic' 173 | max_frames_in_batch: 2000 174 | padding: !name:cosyvoice.dataset.processor.padding 175 | use_spk_embedding: False # change to True during sft 176 | 177 | 178 | # dataset processor pipeline 179 | data_pipeline: [ 180 | !ref <parquet_opener>, 181 | !ref <tokenize>, 182 | !ref <filter>, 183 | !ref <resample>, 184 | !ref <compute_fbank>, 185 | !ref <parse_embedding>, 186 | !ref <shuffle>, 187 | !ref <sort>, 188 | !ref <batch>, 189 | !ref <padding>, 190 | ] 191 | data_pipeline_gan: [ 192 | !ref <parquet_opener>, 193 | !ref <tokenize>, 194 | !ref <filter>, 195 | !ref <resample>, 196 | !ref <truncate>, 197 | !ref <compute_fbank>, 198 | !ref <compute_f0>, 199 | !ref <parse_embedding>, 200 | !ref <shuffle>, 201 | !ref <sort>, 202 | !ref <batch>, 203 | !ref <padding>, 204 | ] 205 | 206 | # llm flow train conf 207 | train_conf: 208 | optim: adam 209 | optim_conf: 210 | lr: 1e-5 # change to 1e-5 during sft 211 | scheduler:
constantlr # change to constantlr during sft 212 | scheduler_conf: 213 | warmup_steps: 2500 214 | max_epoch: 200 215 | grad_clip: 5 216 | accum_grad: 2 217 | log_interval: 100 218 | save_per_step: -1 219 | 220 | # gan train conf 221 | train_conf_gan: 222 | optim: adam 223 | optim_conf: 224 | lr: 0.0002 # use small lr for gan training 225 | scheduler: constantlr 226 | optim_d: adam 227 | optim_conf_d: 228 | lr: 0.0002 # use small lr for gan training 229 | scheduler_d: constantlr 230 | max_epoch: 200 231 | grad_clip: 5 232 | accum_grad: 1 # in gan training, accum_grad must be 1 233 | log_interval: 100 234 | save_per_step: -1 -------------------------------------------------------------------------------- /cosyvoice/tokenizer/tokenizer.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | from functools import lru_cache 4 | from typing import Optional 5 | import torch 6 | from transformers import AutoTokenizer 7 | from whisper.tokenizer import Tokenizer 8 | 9 | import tiktoken 10 | 11 | LANGUAGES = { 12 | "en": "english", 13 | "zh": "chinese", 14 | "de": "german", 15 | "es": "spanish", 16 | "ru": "russian", 17 | "ko": "korean", 18 | "fr": "french", 19 | "ja": "japanese", 20 | "pt": "portuguese", 21 | "tr": "turkish", 22 | "pl": "polish", 23 | "ca": "catalan", 24 | "nl": "dutch", 25 | "ar": "arabic", 26 | "sv": "swedish", 27 | "it": "italian", 28 | "id": "indonesian", 29 | "hi": "hindi", 30 | "fi": "finnish", 31 | "vi": "vietnamese", 32 | "he": "hebrew", 33 | "uk": "ukrainian", 34 | "el": "greek", 35 | "ms": "malay", 36 | "cs": "czech", 37 | "ro": "romanian", 38 | "da": "danish", 39 | "hu": "hungarian", 40 | "ta": "tamil", 41 | "no": "norwegian", 42 | "th": "thai", 43 | "ur": "urdu", 44 | "hr": "croatian", 45 | "bg": "bulgarian", 46 | "lt": "lithuanian", 47 | "la": "latin", 48 | "mi": "maori", 49 | "ml": "malayalam", 50 | "cy": "welsh", 51 | "sk": "slovak", 52 | "te": "telugu", 53 | "fa": "persian", 54 | "lv": "latvian", 55 | "bn": "bengali", 56 | "sr": "serbian", 57 | "az": "azerbaijani", 58 | "sl": "slovenian", 59 | "kn": "kannada", 60 | "et": "estonian", 61 | "mk": "macedonian", 62 | "br": "breton", 63 | "eu": "basque", 64 | "is": "icelandic", 65 | "hy": "armenian", 66 | "ne": "nepali", 67 | "mn": "mongolian", 68 | "bs": "bosnian", 69 | "kk": "kazakh", 70 | "sq": "albanian", 71 | "sw": "swahili", 72 | "gl": "galician", 73 | "mr": "marathi", 74 | "pa": "punjabi", 75 | "si": "sinhala", 76 | "km": "khmer", 77 | "sn": "shona", 78 | "yo": "yoruba", 79 | "so": "somali", 80 | "af": "afrikaans", 81 | "oc": "occitan", 82 | "ka": "georgian", 83 | "be": "belarusian", 84 | "tg": "tajik", 85 | "sd": "sindhi", 86 | "gu": "gujarati", 87 | "am": "amharic", 88 | "yi": "yiddish", 89 | "lo": "lao", 90 | "uz": "uzbek", 91 | "fo": "faroese", 92 | "ht": "haitian creole", 93 | "ps": "pashto", 94 | "tk": "turkmen", 95 | "nn": "nynorsk", 96 | "mt": "maltese", 97 | "sa": "sanskrit", 98 | "lb": "luxembourgish", 99 | "my": "myanmar", 100 | "bo": "tibetan", 101 | "tl": "tagalog", 102 | "mg": "malagasy", 103 | "as": "assamese", 104 | "tt": "tatar", 105 | "haw": "hawaiian", 106 | "ln": "lingala", 107 | "ha": "hausa", 108 | "ba": "bashkir", 109 | "jw": "javanese", 110 | "su": "sundanese", 111 | "yue": "cantonese", 112 | "minnan": "minnan", 113 | "wuyu": "wuyu", 114 | "dialect": "dialect", 115 | "zh/en": "zh/en", 116 | "en/zh": "en/zh", 117 | } 118 | 119 | # language code lookup by name, with a few language aliases 120 | TO_LANGUAGE_CODE = { 121 | 
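# invert LANGUAGES (full name -> code) first, then append the aliases listed below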
**{language: code for code, language in LANGUAGES.items()}, 122 | "burmese": "my", 123 | "valencian": "ca", 124 | "flemish": "nl", 125 | "haitian": "ht", 126 | "letzeburgesch": "lb", 127 | "pushto": "ps", 128 | "panjabi": "pa", 129 | "moldavian": "ro", 130 | "moldovan": "ro", 131 | "sinhalese": "si", 132 | "castilian": "es", 133 | "mandarin": "zh", 134 | } 135 | 136 | AUDIO_EVENT = { 137 | "ASR": "ASR", 138 | "AED": "AED", 139 | "SER": "SER", 140 | "Speech": "Speech", 141 | "/Speech": "/Speech", 142 | "BGM": "BGM", 143 | "/BGM": "/BGM", 144 | "Laughter": "Laughter", 145 | "/Laughter": "/Laughter", 146 | "Applause": "Applause", 147 | "/Applause": "/Applause", 148 | } 149 | 150 | EMOTION = { 151 | "HAPPY": "HAPPY", 152 | "SAD": "SAD", 153 | "ANGRY": "ANGRY", 154 | "NEUTRAL": "NEUTRAL", 155 | } 156 | 157 | TTS_Vocal_Token = { 158 | "TTS/B": "TTS/B", 159 | "TTS/O": "TTS/O", 160 | "TTS/Q": "TTS/Q", 161 | "TTS/A": "TTS/A", 162 | "TTS/CO": "TTS/CO", 163 | "TTS/CL": "TTS/CL", 164 | "TTS/H": "TTS/H", 165 | **{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)} 166 | } 167 | 168 | 169 | @lru_cache(maxsize=None) 170 | def get_encoding(name: str = "gpt2", num_languages: int = 99): 171 | vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken") 172 | ranks = { 173 | base64.b64decode(token): int(rank) 174 | for token, rank in (line.split() for line in open(vocab_path) if line) 175 | } 176 | n_vocab = len(ranks) 177 | special_tokens = {} 178 | 179 | specials = [ 180 | "<|endoftext|>", 181 | "<|startoftranscript|>", 182 | *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], 183 | *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())], 184 | *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())], 185 | "<|translate|>", 186 | "<|transcribe|>", 187 | "<|startoflm|>", 188 | "<|startofprev|>", 189 | "<|nospeech|>", 190 | "<|notimestamps|>", 191 | *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR 192 | *[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS 193 | *[f"<|{i * 0.02:.2f}|>" for i in range(1501)], 194 | ] 195 | 196 | for token in specials: 197 | special_tokens[token] = n_vocab 198 | n_vocab += 1 199 | 200 | return tiktoken.Encoding( 201 | name=os.path.basename(vocab_path), 202 | explicit_n_vocab=n_vocab, 203 | pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", 204 | mergeable_ranks=ranks, 205 | special_tokens=special_tokens, 206 | ) 207 | 208 | 209 | @lru_cache(maxsize=None) 210 | def get_tokenizer( 211 | multilingual: bool, 212 | *, 213 | num_languages: int = 99, 214 | language: Optional[str] = None, 215 | task: Optional[str] = None, # Literal["transcribe", "translate", None] 216 | ) -> Tokenizer: 217 | if language is not None: 218 | language = language.lower() 219 | if language not in LANGUAGES: 220 | if language in TO_LANGUAGE_CODE: 221 | language = TO_LANGUAGE_CODE[language] 222 | else: 223 | raise ValueError(f"Unsupported language: {language}") 224 | 225 | if multilingual: 226 | encoding_name = "multilingual_zh_ja_yue_char_del" 227 | language = language or "en" 228 | task = task or "transcribe" 229 | else: 230 | encoding_name = "gpt2" 231 | language = None 232 | task = None 233 | 234 | encoding = get_encoding(name=encoding_name, num_languages=num_languages) 235 | 236 | return Tokenizer( 237 | encoding=encoding, num_languages=num_languages, language=language, task=task 238 | ) 239 | 240 | 241 | class QwenTokenizer(): 242 | def 
__init__(self, token_path, skip_special_tokens=True): 243 | super().__init__() 244 | # NOTE: non-chat model, all these special tokens remain randomly initialized. 245 | special_tokens = { 246 | 'eos_token': '<|endoftext|>', 247 | 'pad_token': '<|endoftext|>', 248 | 'additional_special_tokens': [ 249 | '<|im_start|>', '<|im_end|>', '<|endofprompt|>', 250 | '[breath]', '<strong>', '</strong>', '[noise]', 251 | '[laughter]', '[cough]', '[clucking]', '[accent]', 252 | '[quick_breath]', 253 | "<laughter>", "</laughter>", 254 | "[hissing]", "[sigh]", "[vocalized-noise]", 255 | "[lipsmack]", "[mn]" 256 | ] 257 | } 258 | self.special_tokens = special_tokens 259 | self.tokenizer = AutoTokenizer.from_pretrained(token_path) 260 | self.tokenizer.add_special_tokens(special_tokens) 261 | self.skip_special_tokens = skip_special_tokens 262 | 263 | def encode(self, text, **kwargs): 264 | tokens = self.tokenizer([text], return_tensors="pt") 265 | tokens = tokens["input_ids"][0].cpu().tolist() 266 | return tokens 267 | 268 | def decode(self, tokens): 269 | tokens = torch.tensor(tokens, dtype=torch.int64) 270 | text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0] 271 | return text 272 | 273 | 274 | @lru_cache(maxsize=None) 275 | def get_qwen_tokenizer( 276 | token_path: str, 277 | skip_special_tokens: bool 278 | ) -> QwenTokenizer: 279 | return QwenTokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens) 280 | --------------------------------------------------------------------------------
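As a quick orientation to the tokenizer module above, a minimal round-trip with the QwenTokenizer wrapper might look like the sketch below; the token_path value is a placeholder for a local Qwen tokenizer checkpoint, not a path shipped with this listing.

```python
# Minimal round-trip sketch for QwenTokenizer/get_qwen_tokenizer as defined above.
# 'path/to/qwen_tokenizer' is a placeholder; point it at a local Qwen checkpoint.
from cosyvoice.tokenizer.tokenizer import get_qwen_tokenizer

tokenizer = get_qwen_tokenizer(token_path='path/to/qwen_tokenizer',
                               skip_special_tokens=True)

ids = tokenizer.encode('hello [laughter] world')  # '[laughter]' encodes as one added special token
print(ids)
print(tokenizer.decode(ids))  # special tokens are dropped because skip_special_tokens=True
```

Because get_qwen_tokenizer is wrapped in lru_cache, repeated calls with the same arguments reuse one tokenizer instance rather than reloading it from disk.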