├── cosyvoice ├── __init__.py ├── cli │ └── __init__.py ├── utils │ ├── __init__.py │ ├── losses.py │ ├── class_utils.py │ └── frontend_utils.py ├── dataset │ ├── __init__.py │ ├── my_processor.py │ └── dataset.py ├── transformer │ ├── __init__.py │ ├── activation.py │ ├── label_smoothing_loss.py │ ├── positionwise_feed_forward.py │ ├── decoder_layer.py │ └── convolution.py ├── hifigan │ ├── f0_predictor.py │ └── hifigan.py ├── flow │ └── length_regulator.py ├── bin │ ├── average_model.py │ ├── export_jit.py │ └── export_onnx.py ├── llm │ └── reward_tts.py └── vllm │ └── cosyvoice2.py ├── third_party └── Matcha-TTS │ ├── matcha │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ └── components │ │ │ ├── __init__.py │ │ │ └── flow_matching.py │ ├── onnx │ │ └── __init__.py │ ├── VERSION │ ├── hifigan │ │ ├── __init__.py │ │ ├── env.py │ │ ├── config.py │ │ ├── LICENSE │ │ ├── xutils.py │ │ ├── denoiser.py │ │ └── README.md │ ├── utils │ │ ├── monotonic_align │ │ │ ├── setup.py │ │ │ ├── __init__.py │ │ │ └── core.pyx │ │ ├── __init__.py │ │ ├── pylogger.py │ │ ├── logging_utils.py │ │ ├── instantiators.py │ │ ├── audio.py │ │ ├── model.py │ │ ├── rich_utils.py │ │ └── generate_data_statistics.py │ ├── text │ │ ├── symbols.py │ │ ├── __init__.py │ │ ├── numbers.py │ │ └── cleaners.py │ └── train.py │ ├── notebooks │ └── .gitkeep │ ├── configs │ ├── local │ │ └── .gitkeep │ ├── callbacks │ │ ├── none.yaml │ │ ├── default.yaml │ │ ├── rich_progress_bar.yaml │ │ ├── model_summary.yaml │ │ └── model_checkpoint.yaml │ ├── model │ │ ├── cfm │ │ │ └── default.yaml │ │ ├── optimizer │ │ │ └── adam.yaml │ │ ├── decoder │ │ │ └── default.yaml │ │ ├── matcha.yaml │ │ └── encoder │ │ │ └── default.yaml │ ├── trainer │ │ ├── cpu.yaml │ │ ├── gpu.yaml │ │ ├── mps.yaml │ │ ├── ddp.yaml │ │ ├── ddp_sim.yaml │ │ └── default.yaml │ ├── __init__.py │ ├── debug │ │ ├── fdr.yaml │ │ ├── overfit.yaml │ │ ├── limit.yaml │ │ ├── profiler.yaml │ │ └── default.yaml │ ├── logger │ │ ├── many_loggers.yaml │ │ ├── csv.yaml │ │ ├── tensorboard.yaml │ │ ├── neptune.yaml │ │ ├── mlflow.yaml │ │ ├── comet.yaml │ │ ├── wandb.yaml │ │ └── aim.yaml │ ├── extras │ │ └── default.yaml │ ├── experiment │ │ ├── ljspeech.yaml │ │ ├── multispeaker.yaml │ │ ├── ljspeech_min_memory.yaml │ │ └── hifi_dataset_piper_phonemizer.yaml │ ├── eval.yaml │ ├── hydra │ │ └── default.yaml │ ├── paths │ │ └── default.yaml │ ├── train.yaml │ └── hparams_search │ │ └── mnist_optuna.yaml │ ├── .project-root │ ├── scripts │ └── schedule.sh │ ├── .env.example │ ├── MANIFEST.in │ ├── .github │ ├── codecov.yml │ ├── dependabot.yml │ ├── PULL_REQUEST_TEMPLATE.md │ └── release-drafter.yml │ ├── requirements.txt │ ├── LICENSE │ ├── pyproject.toml │ ├── Makefile │ ├── setup.py │ ├── .pre-commit-config.yaml │ └── .gitignore ├── asset └── dingding.png ├── download_pre_models.sh ├── .github └── ISSUE_TEMPLATE │ ├── feature_request.md │ └── bug_report.md ├── FAQ.md ├── runtime └── python │ ├── Dockerfile │ ├── grpc │ ├── cosyvoice.proto │ ├── server.py │ └── client.py │ └── fastapi │ ├── client.py │ └── server.py ├── .gitignore ├── vllm_example.py ├── README.md ├── requirements.txt ├── docker └── Dockerfile ├── tools ├── extract_speech_token.py ├── collect_spk_embedding_fast.py ├── extract_embedding.py ├── parse_options.sh └── make_parquet_list.py └── CODE_OF_CONDUCT.md /cosyvoice/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/cosyvoice/cli/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosyvoice/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosyvoice/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cosyvoice/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/notebooks/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/local/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/onnx/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/none.yaml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/VERSION: -------------------------------------------------------------------------------- 1 | 0.0.5.1 2 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/models/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /asset/dingding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ryuclc/CosyVoice2-GRPO/HEAD/asset/dingding.png -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/cfm/default.yaml: -------------------------------------------------------------------------------- 1 | name: CFM 2 | solver: euler 3 | sigma_min: 1e-4 4 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/cpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: cpu 5 | devices: 1 6 | 
-------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/gpu.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: gpu 5 | devices: 1 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/mps.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | accelerator: mps 5 | devices: 1 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # this file is needed here to include configs when building project as a package 2 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | _partial_: true 3 | lr: 1e-4 4 | weight_decay: 0.0 5 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model_checkpoint.yaml 3 | - model_summary.yaml 4 | - rich_progress_bar.yaml 5 | - _self_ 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/ddp.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | strategy: ddp 5 | 6 | accelerator: gpu 7 | devices: [0,1] 8 | num_nodes: 1 9 | sync_batchnorm: True 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/ddp_sim.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | 4 | # simulate DDP on CPU, useful for debugging 5 | accelerator: cpu 6 | devices: 2 7 | strategy: ddp_spawn 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/decoder/default.yaml: -------------------------------------------------------------------------------- 1 | channels: [256, 256] 2 | dropout: 0.05 3 | attention_head_dim: 64 4 | n_blocks: 1 5 | num_mid_blocks: 2 6 | num_heads: 2 7 | act_fn: snakebeta 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/fdr.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs 1 train, 1 validation and 1 test step 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | fast_dev_run: true 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/many_loggers.yaml: -------------------------------------------------------------------------------- 1 | # train with many loggers at once 2 | 3 | defaults: 4 | # - comet 5 | - csv 6 | 
# - mlflow 7 | # - neptune 8 | - tensorboard 9 | - wandb 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/csv.yaml: -------------------------------------------------------------------------------- 1 | # csv logger built in lightning 2 | 3 | csv: 4 | _target_: lightning.pytorch.loggers.csv_logs.CSVLogger 5 | save_dir: "${paths.output_dir}" 6 | name: "csv/" 7 | prefix: "" 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/rich_progress_bar.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/latest/api/lightning.pytorch.callbacks.RichProgressBar.html 2 | 3 | rich_progress_bar: 4 | _target_: lightning.pytorch.callbacks.RichProgressBar 5 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/scripts/schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Schedule execution of many runs 3 | # Run from root folder with: bash scripts/schedule.sh 4 | 5 | python src/train.py trainer.max_epochs=5 logger=csv 6 | 7 | python src/train.py trainer.max_epochs=10 logger=csv 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/monotonic_align/setup.py: -------------------------------------------------------------------------------- 1 | # from distutils.core import setup 2 | # from Cython.Build import cythonize 3 | # import numpy 4 | 5 | # setup(name='monotonic_align', 6 | # ext_modules=cythonize("core.pyx"), 7 | # include_dirs=[numpy.get_include()]) 8 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/extras/default.yaml: -------------------------------------------------------------------------------- 1 | # disable python warnings if they annoy you 2 | ignore_warnings: False 3 | 4 | # ask user for tags if none are provided in the config 5 | enforce_tags: True 6 | 7 | # pretty print config tree at the start of the run using Rich library 8 | print_config: True 9 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # overfits to 3 batches 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 20 10 | overfit_batches: 3 11 | 12 | # model ckpt and early stopping need to be disabled during overfitting 13 | callbacks: null 14 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/limit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # uses only 1% of the training data and 5% of validation/test data 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 3 10 | limit_train_batches: 0.01 11 | limit_val_batches: 0.05 12 | limit_test_batches: 0.05 13 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/model_summary.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html 2 | 3 | model_summary: 4 | _target_: 
lightning.pytorch.callbacks.RichModelSummary 5 | max_depth: 3 # the maximum depth of layer nesting that the summary will include 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/profiler.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # runs with execution time profiling 4 | 5 | defaults: 6 | - default 7 | 8 | trainer: 9 | max_epochs: 1 10 | # profiler: "simple" 11 | profiler: "advanced" 12 | # profiler: "pytorch" 13 | accelerator: gpu 14 | 15 | limit_train_batches: 0.02 16 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/tensorboard.yaml: -------------------------------------------------------------------------------- 1 | # https://www.tensorflow.org/tensorboard/ 2 | 3 | tensorboard: 4 | _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger 5 | save_dir: "${paths.output_dir}/tensorboard/" 6 | name: null 7 | log_graph: False 8 | default_hp_metric: True 9 | prefix: "" 10 | # version: "" 11 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/neptune.yaml: -------------------------------------------------------------------------------- 1 | # https://neptune.ai 2 | 3 | neptune: 4 | _target_: lightning.pytorch.loggers.neptune.NeptuneLogger 5 | api_key: ${oc.env:NEPTUNE_API_TOKEN} # api key is loaded from environment variable 6 | project: username/lightning-hydra-template 7 | # name: "" 8 | log_model_checkpoints: True 9 | prefix: "" 10 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from matcha.utils.instantiators import instantiate_callbacks, instantiate_loggers 2 | from matcha.utils.logging_utils import log_hyperparameters 3 | from matcha.utils.pylogger import get_pylogger 4 | from matcha.utils.rich_utils import enforce_tags, print_config_tree 5 | from matcha.utils.utils import extras, get_metric_value, task_wrapper 6 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.env.example: -------------------------------------------------------------------------------- 1 | # example of file for storing private and user specific environment variables, like keys or system paths 2 | # rename it to ".env" (excluded from version control by default) 3 | # .env is loaded by train.py automatically 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR} 5 | 6 | MY_VAR="/home/user/my/system/path" 7 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/mlflow.yaml: -------------------------------------------------------------------------------- 1 | # https://mlflow.org 2 | 3 | mlflow: 4 | _target_: lightning.pytorch.loggers.mlflow.MLFlowLogger 5 | # experiment_name: "" 6 | # run_name: "" 7 | tracking_uri: ${paths.log_dir}/mlflow/mlruns # run `mlflow ui` command inside the `logs/mlflow/` dir to open the UI 8 | tags: null 9 | # save_dir: "./mlruns" 10 | prefix: "" 11 | artifact_location: null 12 | # run_id: "" 13 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/matcha.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - encoder: default.yaml 4 | - decoder: default.yaml 5 | - cfm: default.yaml 6 | - optimizer: adam.yaml 7 | 8 | _target_: matcha.models.matcha_tts.MatchaTTS 9 | n_vocab: 178 10 | n_spks: ${data.n_spks} 11 | spk_emb_dim: 64 12 | n_feats: 80 13 | data_statistics: ${data.data_statistics} 14 | out_size: null # Must be divisible by 4 15 | prior_loss: true 16 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE.txt 3 | include requirements.*.txt 4 | include *.cff 5 | include requirements.txt 6 | include matcha/VERSION 7 | recursive-include matcha *.json 8 | recursive-include matcha *.html 9 | recursive-include matcha *.png 10 | recursive-include matcha *.md 11 | recursive-include matcha *.py 12 | recursive-include matcha *.pyx 13 | recursive-exclude tests * 14 | prune tests* 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/ljspeech.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - data: mnist # choose datamodule with `test_dataloader()` for evaluation 6 | - model: mnist 7 | - logger: null 8 | - trainer: default 9 | - paths: default 10 | - extras: default 11 | - hydra: default 12 | 13 | task_name: "eval" 14 | 15 | tags: ["dev"] 16 | 17 | # passing checkpoint path is necessary for evaluation 18 | ckpt_path: ??? 
19 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/multispeaker.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: vctk.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["multispeaker"] 13 | 14 | run_name: multispeaker 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/comet.yaml: -------------------------------------------------------------------------------- 1 | # https://www.comet.ml 2 | 3 | comet: 4 | _target_: lightning.pytorch.loggers.comet.CometLogger 5 | api_key: ${oc.env:COMET_API_TOKEN} # api key is loaded from environment variable 6 | save_dir: "${paths.output_dir}" 7 | project_name: "lightning-hydra-template" 8 | rest_api_key: null 9 | # experiment_name: "" 10 | experiment_key: null # set to resume experiment 11 | offline: False 12 | prefix: "" 13 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.github/codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | # measures overall project coverage 4 | project: 5 | default: 6 | threshold: 100% # how much decrease in coverage is needed to not consider success 7 | 8 | # measures PR or single commit coverage 9 | patch: 10 | default: 11 | threshold: 100% # how much decrease in coverage is needed to not consider success 12 | 13 | 14 | # project: off 15 | # patch: off 16 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/ljspeech_min_memory.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: ljspeech.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["ljspeech"] 13 | 14 | run_name: ljspeech_min 15 | 16 | 17 | model: 18 | out_size: 172 19 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/env.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import os 4 | import shutil 5 | 6 | 7 | class AttrDict(dict): 8 | def __init__(self, *args, **kwargs): 9 | super().__init__(*args, **kwargs) 10 | self.__dict__ = self 11 | 12 | 13 | def build_env(config, config_name, path): 14 | t_path = os.path.join(path, config_name) 15 | if config != t_path: 16 | os.makedirs(path, exist_ok=True) 17 | shutil.copyfile(config, os.path.join(path, config_name)) 18 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/model/encoder/default.yaml: -------------------------------------------------------------------------------- 1 | encoder_type: RoPE Encoder 2 | encoder_params: 3 | n_feats: ${model.n_feats} 4 | n_channels: 192 5 | filter_channels: 768 6 | filter_channels_dp: 
256 7 | n_heads: 2 8 | n_layers: 6 9 | kernel_size: 3 10 | p_dropout: 0.1 11 | spk_emb_dim: 64 12 | n_spks: 1 13 | prenet: true 14 | 15 | duration_predictor_params: 16 | filter_channels_dp: ${model.encoder.encoder_params.filter_channels_dp} 17 | kernel_size: 3 18 | p_dropout: ${model.encoder.encoder_params.p_dropout} 19 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/experiment/hifi_dataset_piper_phonemizer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # to execute this experiment run: 4 | # python train.py experiment=multispeaker 5 | 6 | defaults: 7 | - override /data: hi-fi_en-US_female.yaml 8 | 9 | # all parameters below will be merged with parameters from default configurations set above 10 | # this allows you to overwrite only specified parameters 11 | 12 | tags: ["hi-fi", "single_speaker", "piper_phonemizer", "en_US", "female"] 13 | 14 | run_name: hi-fi_en-US_female_piper_phonemizer 15 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | _target_: lightning.pytorch.trainer.Trainer 2 | 3 | default_root_dir: ${paths.output_dir} 4 | 5 | max_epochs: -1 6 | 7 | accelerator: gpu 8 | devices: [0] 9 | 10 | # mixed precision for extra speed-up 11 | precision: 16-mixed 12 | 13 | # perform a validation loop every N training epochs 14 | check_val_every_n_epoch: 1 15 | 16 | # set True to to ensure deterministic results 17 | # makes training slower but gives more reproducibility than just setting seeds 18 | deterministic: False 19 | 20 | gradient_clip_val: 5.0 21 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/symbols.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Defines the set of symbols used in text input to the model. 4 | """ 5 | _pad = "_" 6 | _punctuation = ';:,.!?¡¿—…"«»“” ' 7 | _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" 8 | _letters_ipa = ( 9 | "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ" 10 | ) 11 | 12 | 13 | # Export all symbols: 14 | symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa) 15 | 16 | # Special symbol ids 17 | SPACE_ID = symbols.index(" ") 18 | -------------------------------------------------------------------------------- /download_pre_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define the base URL 4 | BASE_URL="https://www.modelscope.cn/iic/" 5 | 6 | # List of model names 7 | model_names=( 8 | # "CosyVoice2-0.5B" 9 | # "CosyVoice-300M" 10 | # "CosyVoice-300M-25Hz" 11 | # "CosyVoice-300M-SFT" 12 | # "CosyVoice-300M-Instruct" 13 | "CosyVoice-ttsfrd" 14 | ) 15 | 16 | mkdir -p pretrained_models 17 | git lfs install 18 | # Loop through each model name and clone the repository 19 | for model_name in "${model_names[@]}"; do 20 | echo "Cloning ${model_name}..." 
21 | git clone "${BASE_URL}${model_name}.git" pretrained_models/${model_name} 22 | done 23 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | # https://wandb.ai 2 | 3 | wandb: 4 | _target_: lightning.pytorch.loggers.wandb.WandbLogger 5 | # name: "" # name of the run (normally generated by wandb) 6 | save_dir: "${paths.output_dir}" 7 | offline: False 8 | id: null # pass correct id to resume experiment! 9 | anonymous: null # enable anonymous logging 10 | project: "lightning-hydra-template" 11 | log_model: False # upload lightning ckpts 12 | prefix: "" # a string to put at the beginning of metric keys 13 | # entity: "" # set to name of your wandb team 14 | group: "" 15 | tags: [] 16 | job_type: "" 17 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/hydra/default.yaml: -------------------------------------------------------------------------------- 1 | # https://hydra.cc/docs/configure_hydra/intro/ 2 | 3 | # enable color logging 4 | defaults: 5 | - override hydra_logging: colorlog 6 | - override job_logging: colorlog 7 | 8 | # output directory, generated dynamically on each run 9 | run: 10 | dir: ${paths.log_dir}/${task_name}/${run_name}/runs/${now:%Y-%m-%d}_${now:%H-%M-%S} 11 | sweep: 12 | dir: ${paths.log_dir}/${task_name}/${run_name}/multiruns/${now:%Y-%m-%d}_${now:%H-%M-%S} 13 | subdir: ${hydra.job.num} 14 | 15 | job_logging: 16 | handlers: 17 | file: 18 | # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242 19 | filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log 20 | -------------------------------------------------------------------------------- /FAQ.md: -------------------------------------------------------------------------------- 1 | ## ModuleNotFoundError: No module named 'matcha' 2 | 3 | Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`. 4 | 5 | run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script. 6 | 7 | ## cannot find resource.zip or cannot unzip resource.zip 8 | 9 | Please make sure you have git-lfs installed. Execute 10 | 11 | ```sh 12 | git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd 13 | cd pretrained_models/CosyVoice-ttsfrd/ 14 | unzip resource.zip -d . 
15 | pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl 16 | ``` 17 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/paths/default.yaml: -------------------------------------------------------------------------------- 1 | # path to root directory 2 | # this requires PROJECT_ROOT environment variable to exist 3 | # you can replace it with "." if you want the root to be the current working directory 4 | root_dir: ${oc.env:PROJECT_ROOT} 5 | 6 | # path to data directory 7 | data_dir: ${paths.root_dir}/data/ 8 | 9 | # path to logging directory 10 | log_dir: ${paths.root_dir}/logs/ 11 | 12 | # path to output directory, created dynamically by hydra 13 | # path generation pattern is specified in `configs/hydra/default.yaml` 14 | # use it to store all files generated during the run, like ckpts and metrics 15 | output_dir: ${hydra:runtime.output_dir} 16 | 17 | # path to working directory 18 | work_dir: ${hydra:runtime.cwd} 19 | -------------------------------------------------------------------------------- /runtime/python/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | 4 | WORKDIR /opt/CosyVoice 5 | 6 | RUN sed -i s@/archive.ubuntu.com/@/mirrors.aliyun.com/@g /etc/apt/sources.list 7 | RUN apt-get update -y 8 | RUN apt-get -y install git unzip git-lfs g++ 9 | RUN git lfs install 10 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git 11 | # here we use python==3.10 because we cannot find an image which have both python3.8 and torch2.0.1-cu118 installed 12 | RUN cd CosyVoice && pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com 13 | RUN cd CosyVoice/runtime/python/grpc && python3 -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. cosyvoice.proto -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/monotonic_align/__init__.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from matcha.utils.monotonic_align.core import maximum_path_c 5 | 6 | 7 | def maximum_path(value, mask): 8 | """Cython optimised version. 
9 | value: [b, t_x, t_y] 10 | mask: [b, t_x, t_y] 11 | """ 12 | value = value * mask 13 | device = value.device 14 | dtype = value.dtype 15 | value = value.data.cpu().numpy().astype(np.float32) 16 | path = np.zeros_like(value).astype(np.int32) 17 | mask = mask.data.cpu().numpy() 18 | 19 | t_x_max = mask.sum(1)[:, 0].astype(np.int32) 20 | t_y_max = mask.sum(2)[:, 0].astype(np.int32) 21 | maximum_path_c(path, value, t_x_max, t_y_max) 22 | return torch.from_numpy(path).to(device=device, dtype=dtype) 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Visual Studio Code files 7 | .vscode 8 | .vs 9 | 10 | # PyCharm files 11 | .idea 12 | 13 | # Eclipse Project settings 14 | *.*project 15 | .settings 16 | 17 | # Sublime Text settings 18 | *.sublime-workspace 19 | *.sublime-project 20 | 21 | # Editor temporaries 22 | *.swn 23 | *.swo 24 | *.swp 25 | *.swm 26 | *~ 27 | 28 | # IPython notebook checkpoints 29 | .ipynb_checkpoints 30 | 31 | # macOS dir files 32 | .DS_Store 33 | 34 | exp 35 | data 36 | raw_wav 37 | tensorboard 38 | **/*build* 39 | 40 | # Clangd files 41 | .cache 42 | compile_commands.json 43 | 44 | # train/inference files 45 | *.wav 46 | *.m4a 47 | *.aac 48 | *.pt 49 | pretrained_models/* 50 | *_pb2_grpc.py 51 | *_pb2.py 52 | *.tar -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | target-branch: "dev" 11 | schedule: 12 | interval: "daily" 13 | ignore: 14 | - dependency-name: "pytorch-lightning" 15 | update-types: ["version-update:semver-patch"] 16 | - dependency-name: "torchmetrics" 17 | update-types: ["version-update:semver-patch"] 18 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/pylogger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | 5 | 6 | def get_pylogger(name: str = __name__) -> logging.Logger: 7 | """Initializes a multi-GPU-friendly python command line logger. 8 | 9 | :param name: The name of the logger, defaults to ``__name__``. 10 | 11 | :return: A logger object. 
12 | """ 13 | logger = logging.getLogger(name) 14 | 15 | # this ensures all logging levels get marked with the rank zero decorator 16 | # otherwise logs would get multiplied for each GPU process in multi-GPU setup 17 | logging_levels = ("debug", "info", "warning", "error", "exception", "fatal", "critical") 18 | for level in logging_levels: 19 | setattr(logger, level, rank_zero_only(getattr(logger, level))) 20 | 21 | return logger 22 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## What does this PR do? 2 | 3 | 9 | 10 | Fixes #\ 11 | 12 | ## Before submitting 13 | 14 | - [ ] Did you make sure **title is self-explanatory** and **the description concisely explains the PR**? 15 | - [ ] Did you make sure your **PR does only one thing**, instead of bundling different changes together? 16 | - [ ] Did you list all the **breaking changes** introduced by this pull request? 17 | - [ ] Did you **test your PR locally** with `pytest` command? 18 | - [ ] Did you **run pre-commit hooks** with `pre-commit run -a` command? 19 | 20 | ## Did you have fun? 21 | 22 | Make sure you had fun coding 🙃 23 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/config.py: -------------------------------------------------------------------------------- 1 | v1 = { 2 | "resblock": "1", 3 | "num_gpus": 0, 4 | "batch_size": 16, 5 | "learning_rate": 0.0004, 6 | "adam_b1": 0.8, 7 | "adam_b2": 0.99, 8 | "lr_decay": 0.999, 9 | "seed": 1234, 10 | "upsample_rates": [8, 8, 2, 2], 11 | "upsample_kernel_sizes": [16, 16, 4, 4], 12 | "upsample_initial_channel": 512, 13 | "resblock_kernel_sizes": [3, 7, 11], 14 | "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], 15 | "resblock_initial_channel": 256, 16 | "segment_size": 8192, 17 | "num_mels": 80, 18 | "num_freq": 1025, 19 | "n_fft": 1024, 20 | "hop_size": 256, 21 | "win_size": 1024, 22 | "sampling_rate": 22050, 23 | "fmin": 0, 24 | "fmax": 8000, 25 | "fmax_loss": None, 26 | "num_workers": 4, 27 | "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1}, 28 | } 29 | -------------------------------------------------------------------------------- /runtime/python/grpc/cosyvoice.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | package cosyvoice; 4 | option go_package = "protos/"; 5 | 6 | service CosyVoice{ 7 | rpc Inference(Request) returns (stream Response) {} 8 | } 9 | 10 | message Request{ 11 | oneof RequestPayload { 12 | sftRequest sft_request = 1; 13 | zeroshotRequest zero_shot_request = 2; 14 | crosslingualRequest cross_lingual_request = 3; 15 | instructRequest instruct_request = 4; 16 | } 17 | } 18 | 19 | message sftRequest{ 20 | string spk_id = 1; 21 | string tts_text = 2; 22 | } 23 | 24 | message zeroshotRequest{ 25 | string tts_text = 1; 26 | string prompt_text = 2; 27 | bytes prompt_audio = 3; 28 | } 29 | 30 | message crosslingualRequest{ 31 | string tts_text = 1; 32 | bytes prompt_audio = 2; 33 | } 34 | 35 | message instructRequest{ 36 | string tts_text = 1; 37 | string spk_id = 2; 38 | string instruct_text = 3; 39 | } 40 | 41 | message Response{ 42 | bytes tts_audio = 1; 43 | } -------------------------------------------------------------------------------- /vllm_example.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('third_party/Matcha-TTS') 3 | from vllm import ModelRegistry 4 | from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM 5 | ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM) 6 | 7 | from cosyvoice.cli.cosyvoice import CosyVoice2 8 | from cosyvoice.utils.file_utils import load_wav 9 | from cosyvoice.utils.common import set_all_random_seed 10 | from tqdm import tqdm 11 | 12 | 13 | def main(): 14 | cosyvoice = CosyVoice2('pretrained_models/CosyVoice2-0.5B', load_jit=True, load_trt=True, load_vllm=True, fp16=True) 15 | prompt_speech_16k = load_wav('./asset/zero_shot_prompt.wav', 16000) 16 | for i in tqdm(range(100)): 17 | set_all_random_seed(i) 18 | for _, _ in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)): 19 | continue 20 | 21 | 22 | if __name__ == '__main__': 23 | main() 24 | -------------------------------------------------------------------------------- README.md: -------------------------------------------------------------------------------- 1 | ## CosyVoice2-GRPO 2 | 3 | [Demos](https://ryuclc.github.io/LLM-TTS-GRPO/); [Paper](https://arxiv.org/abs/2509.18798) 4 | 5 | A simple implementation that improves CosyVoice2 with the GRPO method. 6 | 7 | This is the code for the paper “Group Relative Policy Optimization for Text-to-Speech with Large Language Models”. 8 | 9 | We modify the official CosyVoice2 code with trl to enable GRPO fine-tuning. Only the CosyVoice2 environment is required; no additional modules or frameworks need to be installed (a minimal sketch of the trl wiring is shown below). 10 | 11 | Most of the code is borrowed from the official [CosyVoice2](https://github.com/FunAudioLLM/CosyVoice), https://github.com/channel-io/ch-tts-llasa-rl-grpo, https://github.com/SebastianBodza/blog_projects/tree/main/00_GRPO_LLaSa, and https://github.com/huggingface/trl. 12 | 13 | 14 | ### Note 15 | 16 | The current implementation is rough; we plan to refine the code or migrate to another framework in the future.
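The sketch below only illustrates how trl's `GRPOTrainer` is generally wired up with a custom reward function; it is not this repository's training script (which is still on the to-do list). The model name, dataset, and reward are placeholders: in practice the policy would be the CosyVoice2 LLM backbone, and the reward would score synthesized speech (e.g. ASR accuracy, speaker similarity) rather than completion length.

```python
# Illustrative sketch only: generic trl GRPOTrainer wiring with placeholder names.
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer


def dummy_tts_reward(prompts, completions, **kwargs):
    # Placeholder reward. A real TTS reward would turn the generated speech
    # tokens into audio and score it (e.g. ASR character accuracy,
    # speaker similarity), not measure completion length.
    return [float(len(c)) for c in completions]


# Placeholder prompts; a real dataset would hold the texts to be synthesized.
train_dataset = Dataset.from_dict({"prompt": ["收到好友从远方寄来的生日礼物。", "希望你以后能够做的比我还好呦。"]})

args = GRPOConfig(
    output_dir="exp/grpo_demo",
    num_generations=8,              # group size used for the relative advantage
    per_device_train_batch_size=8,  # must be compatible with num_generations
    learning_rate=1e-6,
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B",        # placeholder; in practice the CosyVoice2 LLM
    reward_funcs=dummy_tts_reward,
    args=args,
    train_dataset=train_dataset,
)
trainer.train()
```

The file tree above suggests the repository's actual reward logic lives in `cosyvoice/llm/reward_tts.py`.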
17 | 18 | 19 | ### To-Do 20 | 21 | - [x] paper 22 | - [x] code 23 | - [ ] inference with ckpt 24 | - [ ] training script 25 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | 4 | categories: 5 | - title: "🚀 Features" 6 | labels: 7 | - "feature" 8 | - "enhancement" 9 | - title: "🐛 Bug Fixes" 10 | labels: 11 | - "fix" 12 | - "bugfix" 13 | - "bug" 14 | - title: "🧹 Maintenance" 15 | labels: 16 | - "maintenance" 17 | - "dependencies" 18 | - "refactoring" 19 | - "cosmetic" 20 | - "chore" 21 | - title: "📝️ Documentation" 22 | labels: 23 | - "documentation" 24 | - "docs" 25 | 26 | change-template: "- $TITLE @$AUTHOR (#$NUMBER)" 27 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions 28 | 29 | version-resolver: 30 | major: 31 | labels: 32 | - "major" 33 | minor: 34 | labels: 35 | - "minor" 36 | patch: 37 | labels: 38 | - "patch" 39 | default: patch 40 | 41 | template: | 42 | ## Changes 43 | 44 | $CHANGES 45 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/debug/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # default debugging setup, runs 1 full epoch 4 | # other debugging configs can inherit from this one 5 | 6 | # overwrite task name so debugging logs are stored in separate folder 7 | task_name: "debug" 8 | 9 | # disable callbacks and loggers during debugging 10 | # callbacks: null 11 | # logger: null 12 | 13 | extras: 14 | ignore_warnings: False 15 | enforce_tags: False 16 | 17 | # sets level of all command line loggers to 'DEBUG' 18 | # https://hydra.cc/docs/tutorials/basic/running_your_app/logging/ 19 | hydra: 20 | job_logging: 21 | root: 22 | level: DEBUG 23 | 24 | # use this to also set hydra loggers to 'DEBUG' 25 | # verbose: True 26 | 27 | trainer: 28 | max_epochs: 1 29 | accelerator: cpu # debuggers don't like gpus 30 | devices: 1 # debuggers don't like multiprocessing 31 | detect_anomaly: true # raise exception if NaN or +/-inf is detected in any tensor 32 | 33 | data: 34 | num_workers: 0 # debuggers don't like multiprocessing 35 | pin_memory: False # disable gpu memory pin 36 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | torchvision>=0.15.0 4 | lightning>=2.0.0 5 | torchmetrics>=0.11.4 6 | 7 | # --------- hydra --------- # 8 | hydra-core==1.3.2 9 | hydra-colorlog==1.2.0 10 | hydra-optuna-sweeper==1.2.0 11 | 12 | # --------- loggers --------- # 13 | # wandb 14 | # neptune-client 15 | # mlflow 16 | # comet-ml 17 | # aim>=3.16.2 # no lower than 3.16.2, see https://github.com/aimhubio/aim/issues/2550 18 | 19 | # --------- others --------- # 20 | rootutils # standardizing the project root setup 21 | pre-commit # hooks for applying linters on commit 22 | rich # beautiful text formatting in terminal 23 | pytest # tests 24 | # sh # for running bash commands in some tests (linux/macos only) 25 | phonemizer # phonemization of text 26 | tensorboard 27 | librosa 28 | Cython 29 | numpy 30 | einops 31 | inflect 32 | Unidecode 33 | scipy 34 | torchaudio 35 | matplotlib 36 | pandas 37 | conformer==0.3.2 38 | diffusers==0.25.0 39 | notebook 40 | ipywidgets 41 | gradio==3.43.2 42 | gdown 43 | wget 44 | seaborn 45 | piper_phonemize 46 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shivam Mehta 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jungil Kong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel", "cython==0.29.35", "numpy==1.24.3", "packaging"] 3 | 4 | [tool.black] 5 | line-length = 120 6 | target-version = ['py310'] 7 | exclude = ''' 8 | 9 | ( 10 | /( 11 | \.eggs # exclude a few common directories in the 12 | | \.git # root of the project 13 | | \.hg 14 | | \.mypy_cache 15 | | \.tox 16 | | \.venv 17 | | _build 18 | | buck-out 19 | | build 20 | | dist 21 | )/ 22 | | foo.py # also separately exclude a file named foo.py in 23 | # the root of the project 24 | ) 25 | ''' 26 | 27 | [tool.pytest.ini_options] 28 | addopts = [ 29 | "--color=yes", 30 | "--durations=0", 31 | "--strict-markers", 32 | "--doctest-modules", 33 | ] 34 | filterwarnings = [ 35 | "ignore::DeprecationWarning", 36 | "ignore::UserWarning", 37 | ] 38 | log_cli = "True" 39 | markers = [ 40 | "slow: slow tests", 41 | ] 42 | minversion = "6.0" 43 | testpaths = "tests/" 44 | 45 | [tool.coverage.report] 46 | exclude_lines = [ 47 | "pragma: nocover", 48 | "raise NotImplementedError", 49 | "raise NotImplementedError()", 50 | "if __name__ == .__main__.:", 51 | ] 52 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu121 2 | --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ # https://github.com/microsoft/onnxruntime/issues/21684 3 | conformer==0.3.2 4 | deepspeed==0.15.1; sys_platform == 'linux' 5 | diffusers==0.29.0 6 | fastapi==0.115.6 7 | fastapi-cli==0.0.4 8 | gdown==5.1.0 9 | gradio==5.4.0 10 | grpcio==1.57.0 11 | grpcio-tools==1.57.0 12 | hydra-core==1.3.2 13 | HyperPyYAML==1.2.2 14 | inflect==7.3.1 15 | librosa==0.10.2 16 | lightning==2.2.4 17 | matplotlib==3.7.5 18 | modelscope==1.20.0 19 | networkx==3.1 20 | omegaconf==2.3.0 21 | onnx==1.16.0 22 | onnxruntime-gpu==1.18.0; sys_platform == 'linux' 23 | onnxruntime==1.18.0; sys_platform == 'darwin' or sys_platform == 'win32' 24 | openai-whisper==20231117 25 | protobuf==4.25 26 | pyarrow==18.1.0 27 | pydantic==2.7.0 28 | pyworld==0.3.4 29 | rich==13.7.1 30 | soundfile==0.12.1 31 | tensorboard==2.14.0 32 | tensorrt-cu12==10.0.1; sys_platform == 'linux' 33 | tensorrt-cu12-bindings==10.0.1; sys_platform == 'linux' 34 | tensorrt-cu12-libs==10.0.1; sys_platform == 'linux' 35 | torch==2.3.1 36 | torchaudio==2.3.1 37 | transformers==4.40.1 38 | uvicorn==0.30.0 39 | wetext==0.0.4 40 | wget==3.2 41 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/callbacks/model_checkpoint.yaml: -------------------------------------------------------------------------------- 1 | # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html 2 | 3 | model_checkpoint: 4 | _target_: lightning.pytorch.callbacks.ModelCheckpoint 5 | dirpath: ${paths.output_dir}/checkpoints # directory to save the model file 6 | filename: checkpoint_{epoch:03d} # checkpoint filename 7 | monitor: epoch # name of the logged metric which determines when model is improving 8 | verbose: False # verbosity mode 9 | save_last: true # additionally always save an exact copy of the last checkpoint to a file last.ckpt 10 | 
save_top_k: 10 # save k best models (determined by above metric) 11 | mode: "max" # "max" means higher metric value is better, can be also "min" 12 | auto_insert_metric_name: True # when True, the checkpoints filenames will contain the metric name 13 | save_weights_only: False # if True, then only the model’s weights will be saved 14 | every_n_train_steps: null # number of training steps between checkpoints 15 | train_time_interval: null # checkpoints are monitored at the specified time interval 16 | every_n_epochs: 100 # number of epochs between checkpoints 17 | save_on_train_epoch_end: null # whether to run checkpointing at the end of the training epoch or the end of validation 18 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/Makefile: -------------------------------------------------------------------------------- 1 | 2 | help: ## Show help 3 | @grep -E '^[.a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 4 | 5 | clean: ## Clean autogenerated files 6 | rm -rf dist 7 | find . -type f -name "*.DS_Store" -ls -delete 8 | find . | grep -E "(__pycache__|\.pyc|\.pyo)" | xargs rm -rf 9 | find . | grep -E ".pytest_cache" | xargs rm -rf 10 | find . | grep -E ".ipynb_checkpoints" | xargs rm -rf 11 | rm -f .coverage 12 | 13 | clean-logs: ## Clean logs 14 | rm -rf logs/** 15 | 16 | create-package: ## Create wheel and tar gz 17 | rm -rf dist/ 18 | python setup.py bdist_wheel --plat-name=manylinux1_x86_64 19 | python setup.py sdist 20 | python -m twine upload dist/* --verbose --skip-existing 21 | 22 | format: ## Run pre-commit hooks 23 | pre-commit run -a 24 | 25 | sync: ## Merge changes from main branch to your current branch 26 | git pull 27 | git pull origin main 28 | 29 | test: ## Run not slow tests 30 | pytest -k "not slow" 31 | 32 | test-full: ## Run all tests 33 | pytest 34 | 35 | train-ljspeech: ## Train the model 36 | python matcha/train.py experiment=ljspeech 37 | 38 | train-ljspeech-min: ## Train the model with minimum memory 39 | python matcha/train.py experiment=ljspeech_min_memory 40 | 41 | start_app: ## Start the app 42 | python matcha/app.py 43 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/logger/aim.yaml: -------------------------------------------------------------------------------- 1 | # https://aimstack.io/ 2 | 3 | # example usage in lightning module: 4 | # https://github.com/aimhubio/aim/blob/main/examples/pytorch_lightning_track.py 5 | 6 | # open the Aim UI with the following command (run in the folder containing the `.aim` folder): 7 | # `aim up` 8 | 9 | aim: 10 | _target_: aim.pytorch_lightning.AimLogger 11 | repo: ${paths.root_dir} # .aim folder will be created here 12 | # repo: "aim://ip_address:port" # can instead provide IP address pointing to Aim remote tracking server which manages the repo, see https://aimstack.readthedocs.io/en/latest/using/remote_tracking.html# 13 | 14 | # aim allows to group runs under experiment name 15 | experiment: null # any string, set to "default" if not specified 16 | 17 | train_metric_prefix: "train/" 18 | val_metric_prefix: "val/" 19 | test_metric_prefix: "test/" 20 | 21 | # sets the tracking interval in seconds for system usage metrics (CPU, GPU, memory, etc.) 22 | system_tracking_interval: 10 # set to null to disable system metrics tracking 23 | 24 | # enable/disable logging of system params such as installed packages, git info, env vars, etc. 
25 | log_system_params: true 26 | 27 | # enable/disable tracking console logs (default value is true) 28 | capture_terminal_logs: false # set to false to avoid infinite console log loop issue https://github.com/aimhubio/aim/issues/2550 29 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/monotonic_align/core.pyx: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | cimport cython 4 | cimport numpy as np 5 | 6 | from cython.parallel import prange 7 | 8 | 9 | @cython.boundscheck(False) 10 | @cython.wraparound(False) 11 | cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: 12 | cdef int x 13 | cdef int y 14 | cdef float v_prev 15 | cdef float v_cur 16 | cdef float tmp 17 | cdef int index = t_x - 1 18 | 19 | for y in range(t_y): 20 | for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): 21 | if x == y: 22 | v_cur = max_neg_val 23 | else: 24 | v_cur = value[x, y-1] 25 | if x == 0: 26 | if y == 0: 27 | v_prev = 0. 28 | else: 29 | v_prev = max_neg_val 30 | else: 31 | v_prev = value[x-1, y-1] 32 | value[x, y] = max(v_cur, v_prev) + value[x, y] 33 | 34 | for y in range(t_y - 1, -1, -1): 35 | path[index, y] = 1 36 | if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): 37 | index = index - 1 38 | 39 | 40 | @cython.boundscheck(False) 41 | @cython.wraparound(False) 42 | cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: 43 | cdef int b = values.shape[0] 44 | 45 | cdef int i 46 | for i in prange(b, nogil=True): 47 | maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) 48 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/xutils.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jik876/hifi-gan """ 2 | 3 | import glob 4 | import os 5 | 6 | import matplotlib 7 | import torch 8 | from torch.nn.utils import weight_norm 9 | 10 | matplotlib.use("Agg") 11 | import matplotlib.pylab as plt 12 | 13 | 14 | def plot_spectrogram(spectrogram): 15 | fig, ax = plt.subplots(figsize=(10, 2)) 16 | im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none") 17 | plt.colorbar(im, ax=ax) 18 | 19 | fig.canvas.draw() 20 | plt.close() 21 | 22 | return fig 23 | 24 | 25 | def init_weights(m, mean=0.0, std=0.01): 26 | classname = m.__class__.__name__ 27 | if classname.find("Conv") != -1: 28 | m.weight.data.normal_(mean, std) 29 | 30 | 31 | def apply_weight_norm(m): 32 | classname = m.__class__.__name__ 33 | if classname.find("Conv") != -1: 34 | weight_norm(m) 35 | 36 | 37 | def get_padding(kernel_size, dilation=1): 38 | return int((kernel_size * dilation - dilation) / 2) 39 | 40 | 41 | def load_checkpoint(filepath, device): 42 | assert os.path.isfile(filepath) 43 | print(f"Loading '{filepath}'") 44 | checkpoint_dict = torch.load(filepath, map_location=device) 45 | print("Complete.") 46 | return checkpoint_dict 47 | 48 | 49 | def save_checkpoint(filepath, obj): 50 | print(f"Saving checkpoint to {filepath}") 51 | torch.save(obj, filepath) 52 | print("Complete.") 53 | 54 | 55 | def scan_checkpoint(cp_dir, prefix): 56 | pattern = os.path.join(cp_dir, prefix + "????????") 57 | cp_list = glob.glob(pattern) 58 | if len(cp_list) == 0: 59 | return None 60 | return sorted(cp_list)[-1] 
61 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | import numpy 5 | from Cython.Build import cythonize 6 | from setuptools import Extension, find_packages, setup 7 | 8 | exts = [ 9 | Extension( 10 | name="matcha.utils.monotonic_align.core", 11 | sources=["matcha/utils/monotonic_align/core.pyx"], 12 | ) 13 | ] 14 | 15 | with open("README.md", encoding="utf-8") as readme_file: 16 | README = readme_file.read() 17 | 18 | cwd = os.path.dirname(os.path.abspath(__file__)) 19 | with open(os.path.join(cwd, "matcha", "VERSION")) as fin: 20 | version = fin.read().strip() 21 | 22 | setup( 23 | name="matcha-tts", 24 | version=version, 25 | description="🍵 Matcha-TTS: A fast TTS architecture with conditional flow matching", 26 | long_description=README, 27 | long_description_content_type="text/markdown", 28 | author="Shivam Mehta", 29 | author_email="shivam.mehta25@gmail.com", 30 | url="https://shivammehta25.github.io/Matcha-TTS", 31 | install_requires=[str(r) for r in open(os.path.join(os.path.dirname(__file__), "requirements.txt"))], 32 | include_dirs=[numpy.get_include()], 33 | include_package_data=True, 34 | packages=find_packages(exclude=["tests", "tests/*", "examples", "examples/*"]), 35 | # use this to customize global commands available in the terminal after installing the package 36 | entry_points={ 37 | "console_scripts": [ 38 | "matcha-data-stats=matcha.utils.generate_data_statistics:main", 39 | "matcha-tts=matcha.cli:cli", 40 | "matcha-tts-app=matcha.app:main", 41 | ] 42 | }, 43 | ext_modules=cythonize(exts, language_level=3), 44 | python_requires=">=3.9.0", 45 | ) 46 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.10 3 | 4 | repos: 5 | - repo: https://github.com/pre-commit/pre-commit-hooks 6 | rev: v4.5.0 7 | hooks: 8 | # list of supported hooks: https://pre-commit.com/hooks.html 9 | - id: trailing-whitespace 10 | - id: end-of-file-fixer 11 | # - id: check-docstring-first 12 | - id: check-yaml 13 | - id: debug-statements 14 | - id: detect-private-key 15 | - id: check-toml 16 | - id: check-case-conflict 17 | - id: check-added-large-files 18 | 19 | # python code formatting 20 | - repo: https://github.com/psf/black 21 | rev: 23.12.1 22 | hooks: 23 | - id: black 24 | args: [--line-length, "120"] 25 | 26 | # python import sorting 27 | - repo: https://github.com/PyCQA/isort 28 | rev: 5.13.2 29 | hooks: 30 | - id: isort 31 | args: ["--profile", "black", "--filter-files"] 32 | 33 | # python upgrading syntax to newer version 34 | - repo: https://github.com/asottile/pyupgrade 35 | rev: v3.15.0 36 | hooks: 37 | - id: pyupgrade 38 | args: [--py38-plus] 39 | 40 | # python check (PEP8), programming errors and code complexity 41 | - repo: https://github.com/PyCQA/flake8 42 | rev: 7.0.0 43 | hooks: 44 | - id: flake8 45 | args: 46 | [ 47 | "--max-line-length", "120", 48 | "--extend-ignore", 49 | "E203,E402,E501,F401,F841,RST2,RST301", 50 | "--exclude", 51 | "logs/*,data/*,matcha/hifigan/*", 52 | ] 53 | additional_dependencies: [flake8-rst-docstrings==0.3.0] 54 | 55 | # pylint 56 | - repo: https://github.com/pycqa/pylint 57 | rev: v3.0.3 58 | hooks: 59 | - id: pylint 60 | 
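# The hooks above only run once pre-commit is set up locally; a typical workflow is
# `pip install pre-commit`, then `pre-commit install` inside the repository to register
# the git hook, and `pre-commit run --all-files` for a one-off pass over every file.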
-------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/train.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # specify here default configuration 4 | # order of defaults determines the order in which configs override each other 5 | defaults: 6 | - _self_ 7 | - data: ljspeech 8 | - model: matcha 9 | - callbacks: default 10 | - logger: tensorboard # set logger here or use command line (e.g. `python train.py logger=tensorboard`) 11 | - trainer: default 12 | - paths: default 13 | - extras: default 14 | - hydra: default 15 | 16 | # experiment configs allow for version control of specific hyperparameters 17 | # e.g. best hyperparameters for given model and datamodule 18 | - experiment: null 19 | 20 | # config for hyperparameter optimization 21 | - hparams_search: null 22 | 23 | # optional local config for machine/user specific settings 24 | # it's optional since it doesn't need to exist and is excluded from version control 25 | - optional local: default 26 | 27 | # debugging config (enable through command line, e.g. `python train.py debug=default) 28 | - debug: null 29 | 30 | # task name, determines output directory path 31 | task_name: "train" 32 | 33 | run_name: ??? 34 | 35 | # tags to help you identify your experiments 36 | # you can overwrite this in experiment configs 37 | # overwrite from command line with `python train.py tags="[first_tag, second_tag]"` 38 | tags: ["dev"] 39 | 40 | # set False to skip model training 41 | train: True 42 | 43 | # evaluate on test set, using best model weights achieved during training 44 | # lightning chooses best weights based on the metric specified in checkpoint callback 45 | test: True 46 | 47 | # simply provide checkpoint path to resume training 48 | ckpt_path: null 49 | 50 | # seed for random number generators in pytorch, numpy and python.random 51 | seed: 1234 52 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/__init__.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | from matcha.text import cleaners 3 | from matcha.text.symbols import symbols 4 | 5 | # Mappings from symbol to numeric ID and vice versa: 6 | _symbol_to_id = {s: i for i, s in enumerate(symbols)} 7 | _id_to_symbol = {i: s for i, s in enumerate(symbols)} # pylint: disable=unnecessary-comprehension 8 | 9 | 10 | def text_to_sequence(text, cleaner_names): 11 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 12 | Args: 13 | text: string to convert to a sequence 14 | cleaner_names: names of the cleaner functions to run the text through 15 | Returns: 16 | List of integers corresponding to the symbols in the text 17 | """ 18 | sequence = [] 19 | 20 | clean_text = _clean_text(text, cleaner_names) 21 | for symbol in clean_text: 22 | symbol_id = _symbol_to_id[symbol] 23 | sequence += [symbol_id] 24 | return sequence 25 | 26 | 27 | def cleaned_text_to_sequence(cleaned_text): 28 | """Converts a string of text to a sequence of IDs corresponding to the symbols in the text. 
29 | Args: 30 | text: string to convert to a sequence 31 | Returns: 32 | List of integers corresponding to the symbols in the text 33 | """ 34 | sequence = [_symbol_to_id[symbol] for symbol in cleaned_text] 35 | return sequence 36 | 37 | 38 | def sequence_to_text(sequence): 39 | """Converts a sequence of IDs back to a string""" 40 | result = "" 41 | for symbol_id in sequence: 42 | s = _id_to_symbol[symbol_id] 43 | result += s 44 | return result 45 | 46 | 47 | def _clean_text(text, cleaner_names): 48 | for name in cleaner_names: 49 | cleaner = getattr(cleaners, name) 50 | if not cleaner: 51 | raise Exception("Unknown cleaner: %s" % name) 52 | text = cleaner(text) 53 | return text 54 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | from lightning.pytorch.utilities import rank_zero_only 4 | from omegaconf import OmegaConf 5 | 6 | from matcha.utils import pylogger 7 | 8 | log = pylogger.get_pylogger(__name__) 9 | 10 | 11 | @rank_zero_only 12 | def log_hyperparameters(object_dict: Dict[str, Any]) -> None: 13 | """Controls which config parts are saved by Lightning loggers. 14 | 15 | Additionally saves: 16 | - Number of model parameters 17 | 18 | :param object_dict: A dictionary containing the following objects: 19 | - `"cfg"`: A DictConfig object containing the main config. 20 | - `"model"`: The Lightning model. 21 | - `"trainer"`: The Lightning trainer. 22 | """ 23 | hparams = {} 24 | 25 | cfg = OmegaConf.to_container(object_dict["cfg"]) 26 | model = object_dict["model"] 27 | trainer = object_dict["trainer"] 28 | 29 | if not trainer.logger: 30 | log.warning("Logger not found! Skipping hyperparameter logging...") 31 | return 32 | 33 | hparams["model"] = cfg["model"] 34 | 35 | # save number of model parameters 36 | hparams["model/params/total"] = sum(p.numel() for p in model.parameters()) 37 | hparams["model/params/trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) 38 | hparams["model/params/non_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) 39 | 40 | hparams["data"] = cfg["data"] 41 | hparams["trainer"] = cfg["trainer"] 42 | 43 | hparams["callbacks"] = cfg.get("callbacks") 44 | hparams["extras"] = cfg.get("extras") 45 | 46 | hparams["task_name"] = cfg.get("task_name") 47 | hparams["tags"] = cfg.get("tags") 48 | hparams["ckpt_path"] = cfg.get("ckpt_path") 49 | hparams["seed"] = cfg.get("seed") 50 | 51 | # send hparams to all loggers 52 | for logger in trainer.loggers: 53 | logger.log_hyperparams(hparams) 54 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/configs/hparams_search/mnist_optuna.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # example hyperparameter optimization of some experiment with Optuna: 4 | # python train.py -m hparams_search=mnist_optuna experiment=example 5 | 6 | defaults: 7 | - override /hydra/sweeper: optuna 8 | 9 | # choose metric which will be optimized by Optuna 10 | # make sure this is the correct name of some metric logged in lightning module! 
11 | optimized_metric: "val/acc_best" 12 | 13 | # here we define Optuna hyperparameter search 14 | # it optimizes for value returned from function with @hydra.main decorator 15 | # docs: https://hydra.cc/docs/next/plugins/optuna_sweeper 16 | hydra: 17 | mode: "MULTIRUN" # set hydra to multirun by default if this config is attached 18 | 19 | sweeper: 20 | _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper 21 | 22 | # storage URL to persist optimization results 23 | # for example, you can use SQLite if you set 'sqlite:///example.db' 24 | storage: null 25 | 26 | # name of the study to persist optimization results 27 | study_name: null 28 | 29 | # number of parallel workers 30 | n_jobs: 1 31 | 32 | # 'minimize' or 'maximize' the objective 33 | direction: maximize 34 | 35 | # total number of runs that will be executed 36 | n_trials: 20 37 | 38 | # choose Optuna hyperparameter sampler 39 | # you can choose bayesian sampler (tpe), random search (without optimization), grid sampler, and others 40 | # docs: https://optuna.readthedocs.io/en/stable/reference/samplers.html 41 | sampler: 42 | _target_: optuna.samplers.TPESampler 43 | seed: 1234 44 | n_startup_trials: 10 # number of random sampling runs before optimization starts 45 | 46 | # define hyperparameter search space 47 | params: 48 | model.optimizer.lr: interval(0.0001, 0.1) 49 | data.batch_size: choice(32, 64, 128, 256) 50 | model.net.lin1_size: choice(64, 128, 256) 51 | model.net.lin2_size: choice(64, 128, 256) 52 | model.net.lin3_size: choice(32, 64, 128, 256) 53 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/instantiators.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import hydra 4 | from lightning import Callback 5 | from lightning.pytorch.loggers import Logger 6 | from omegaconf import DictConfig 7 | 8 | from matcha.utils import pylogger 9 | 10 | log = pylogger.get_pylogger(__name__) 11 | 12 | 13 | def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]: 14 | """Instantiates callbacks from config. 15 | 16 | :param callbacks_cfg: A DictConfig object containing callback configurations. 17 | :return: A list of instantiated callbacks. 18 | """ 19 | callbacks: List[Callback] = [] 20 | 21 | if not callbacks_cfg: 22 | log.warning("No callback configs found! Skipping..") 23 | return callbacks 24 | 25 | if not isinstance(callbacks_cfg, DictConfig): 26 | raise TypeError("Callbacks config must be a DictConfig!") 27 | 28 | for _, cb_conf in callbacks_cfg.items(): 29 | if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf: 30 | log.info(f"Instantiating callback <{cb_conf._target_}>") # pylint: disable=protected-access 31 | callbacks.append(hydra.utils.instantiate(cb_conf)) 32 | 33 | return callbacks 34 | 35 | 36 | def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]: 37 | """Instantiates loggers from config. 38 | 39 | :param logger_cfg: A DictConfig object containing logger configurations. 40 | :return: A list of instantiated loggers. 41 | """ 42 | logger: List[Logger] = [] 43 | 44 | if not logger_cfg: 45 | log.warning("No logger configs found! 
Skipping...") 46 | return logger 47 | 48 | if not isinstance(logger_cfg, DictConfig): 49 | raise TypeError("Logger config must be a DictConfig!") 50 | 51 | for _, lg_conf in logger_cfg.items(): 52 | if isinstance(lg_conf, DictConfig) and "_target_" in lg_conf: 53 | log.info(f"Instantiating logger <{lg_conf._target_}>") # pylint: disable=protected-access 54 | logger.append(hydra.utils.instantiate(lg_conf)) 55 | 56 | return logger 57 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04 2 | 3 | ARG VENV_NAME="cosyvoice" 4 | ENV VENV=$VENV_NAME 5 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 6 | 7 | ENV DEBIAN_FRONTEN=noninteractive 8 | ENV PYTHONUNBUFFERED=1 9 | SHELL ["/bin/bash", "--login", "-c"] 10 | 11 | RUN apt-get update -y --fix-missing 12 | RUN apt-get install -y git build-essential curl wget ffmpeg unzip git git-lfs sox libsox-dev && \ 13 | apt-get clean && \ 14 | git lfs install 15 | 16 | # ================================================================== 17 | # conda install and conda forge channel as default 18 | # ------------------------------------------------------------------ 19 | # Install miniforge 20 | RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \ 21 | /bin/bash ~/miniforge.sh -b -p /opt/conda && \ 22 | rm ~/miniforge.sh && \ 23 | ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \ 24 | echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \ 25 | echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \ 26 | echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \ 27 | echo "conda activate ${VENV}" >> $HOME/.bashrc 28 | 29 | ENV PATH /opt/conda/bin:$PATH 30 | 31 | RUN conda config --add channels conda-forge && \ 32 | conda config --set channel_priority strict 33 | # ------------------------------------------------------------------ 34 | # ~conda 35 | # ================================================================== 36 | 37 | RUN conda create -y -n ${VENV} python=3.10 38 | ENV CONDA_DEFAULT_ENV=${VENV} 39 | ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH 40 | 41 | WORKDIR /workspace 42 | 43 | ENV PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_party/Matcha-TTS" 44 | 45 | RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git 46 | 47 | RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5 48 | RUN conda activate ${VENV} && cd CosyVoice && \ 49 | pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com 50 | 51 | WORKDIR /workspace/CosyVoice 52 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/f0_predictor.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import torch.nn as nn 16 | try: 17 | from torch.nn.utils.parametrizations import weight_norm 18 | except ImportError: 19 | from torch.nn.utils import weight_norm 20 | 21 | 22 | class ConvRNNF0Predictor(nn.Module): 23 | def __init__(self, 24 | num_class: int = 1, 25 | in_channels: int = 80, 26 | cond_channels: int = 512 27 | ): 28 | super().__init__() 29 | 30 | self.num_class = num_class 31 | self.condnet = nn.Sequential( 32 | weight_norm( 33 | nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1) 34 | ), 35 | nn.ELU(), 36 | weight_norm( 37 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 38 | ), 39 | nn.ELU(), 40 | weight_norm( 41 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 42 | ), 43 | nn.ELU(), 44 | weight_norm( 45 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 46 | ), 47 | nn.ELU(), 48 | weight_norm( 49 | nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1) 50 | ), 51 | nn.ELU(), 52 | ) 53 | self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class) 54 | 55 | def forward(self, x: torch.Tensor) -> torch.Tensor: 56 | x = self.condnet(x) 57 | x = x.transpose(1, 2) 58 | return torch.abs(self.classifier(x).squeeze(-1)) 59 | -------------------------------------------------------------------------------- /cosyvoice/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Tuple 4 | 5 | 6 | def tpr_loss(disc_real_outputs, disc_generated_outputs, tau): 7 | loss = 0 8 | for dr, dg in zip(disc_real_outputs, disc_generated_outputs): 9 | m_DG = torch.median((dr - dg)) 10 | L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG]) 11 | loss += tau - F.relu(tau - L_rel) 12 | return loss 13 | 14 | 15 | def mel_loss(real_speech, generated_speech, mel_transforms): 16 | loss = 0 17 | for transform in mel_transforms: 18 | mel_r = transform(real_speech) 19 | mel_g = transform(generated_speech) 20 | loss += F.l1_loss(mel_g, mel_r) 21 | return loss 22 | 23 | 24 | class DPOLoss(torch.nn.Module): 25 | """ 26 | DPO Loss 27 | """ 28 | 29 | def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None: 30 | super().__init__() 31 | self.beta = beta 32 | self.label_smoothing = label_smoothing 33 | self.ipo = ipo 34 | 35 | def forward( 36 | self, 37 | policy_chosen_logps: torch.Tensor, 38 | policy_rejected_logps: torch.Tensor, 39 | reference_chosen_logps: torch.Tensor, 40 | reference_rejected_logps: torch.Tensor, 41 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 42 | pi_logratios = policy_chosen_logps - policy_rejected_logps 43 | ref_logratios = reference_chosen_logps - reference_rejected_logps 44 | logits = pi_logratios - ref_logratios 45 | if self.ipo: 46 | losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf 47 | else: 48 | # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 
7 of https://arxiv.org/pdf/2305.18290.pdf) 49 | losses = ( 50 | -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) 51 | - F.logsigmoid(-self.beta * logits) * self.label_smoothing 52 | ) 53 | loss = losses.mean() 54 | chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach() 55 | rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach() 56 | 57 | return loss, chosen_rewards, rejected_rewards 58 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/numbers.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron """ 2 | 3 | import re 4 | 5 | import inflect 6 | 7 | _inflect = inflect.engine() 8 | _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])") 9 | _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)") 10 | _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)") 11 | _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)") 12 | _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)") 13 | _number_re = re.compile(r"[0-9]+") 14 | 15 | 16 | def _remove_commas(m): 17 | return m.group(1).replace(",", "") 18 | 19 | 20 | def _expand_decimal_point(m): 21 | return m.group(1).replace(".", " point ") 22 | 23 | 24 | def _expand_dollars(m): 25 | match = m.group(1) 26 | parts = match.split(".") 27 | if len(parts) > 2: 28 | return match + " dollars" 29 | dollars = int(parts[0]) if parts[0] else 0 30 | cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 31 | if dollars and cents: 32 | dollar_unit = "dollar" if dollars == 1 else "dollars" 33 | cent_unit = "cent" if cents == 1 else "cents" 34 | return f"{dollars} {dollar_unit}, {cents} {cent_unit}" 35 | elif dollars: 36 | dollar_unit = "dollar" if dollars == 1 else "dollars" 37 | return f"{dollars} {dollar_unit}" 38 | elif cents: 39 | cent_unit = "cent" if cents == 1 else "cents" 40 | return f"{cents} {cent_unit}" 41 | else: 42 | return "zero dollars" 43 | 44 | 45 | def _expand_ordinal(m): 46 | return _inflect.number_to_words(m.group(0)) 47 | 48 | 49 | def _expand_number(m): 50 | num = int(m.group(0)) 51 | if num > 1000 and num < 3000: 52 | if num == 2000: 53 | return "two thousand" 54 | elif num > 2000 and num < 2010: 55 | return "two thousand " + _inflect.number_to_words(num % 100) 56 | elif num % 100 == 0: 57 | return _inflect.number_to_words(num // 100) + " hundred" 58 | else: 59 | return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ") 60 | else: 61 | return _inflect.number_to_words(num, andword="") 62 | 63 | 64 | def normalize_numbers(text): 65 | text = re.sub(_comma_number_re, _remove_commas, text) 66 | text = re.sub(_pounds_re, r"\1 pounds", text) 67 | text = re.sub(_dollars_re, _expand_dollars, text) 68 | text = re.sub(_decimal_number_re, _expand_decimal_point, text) 69 | text = re.sub(_ordinal_re, _expand_ordinal, text) 70 | text = re.sub(_number_re, _expand_number, text) 71 | return text 72 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/audio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data 4 | from librosa.filters import mel as librosa_mel_fn 5 | from scipy.io.wavfile import read 6 | 7 | MAX_WAV_VALUE = 32768.0 8 | 9 | 10 | def load_wav(full_path): 11 | sampling_rate, data = read(full_path) 12 | return data, sampling_rate 13 | 14 | 15 | def 
dynamic_range_compression(x, C=1, clip_val=1e-5): 16 | return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) 17 | 18 | 19 | def dynamic_range_decompression(x, C=1): 20 | return np.exp(x) / C 21 | 22 | 23 | def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): 24 | return torch.log(torch.clamp(x, min=clip_val) * C) 25 | 26 | 27 | def dynamic_range_decompression_torch(x, C=1): 28 | return torch.exp(x) / C 29 | 30 | 31 | def spectral_normalize_torch(magnitudes): 32 | output = dynamic_range_compression_torch(magnitudes) 33 | return output 34 | 35 | 36 | def spectral_de_normalize_torch(magnitudes): 37 | output = dynamic_range_decompression_torch(magnitudes) 38 | return output 39 | 40 | 41 | mel_basis = {} 42 | hann_window = {} 43 | 44 | 45 | def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): 46 | if torch.min(y) < -1.0: 47 | print("min value is ", torch.min(y)) 48 | if torch.max(y) > 1.0: 49 | print("max value is ", torch.max(y)) 50 | 51 | global mel_basis, hann_window # pylint: disable=global-statement 52 | if f"{str(fmax)}_{str(y.device)}" not in mel_basis: 53 | mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax) 54 | mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device) 55 | hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) 56 | 57 | y = torch.nn.functional.pad( 58 | y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect" 59 | ) 60 | y = y.squeeze(1) 61 | 62 | spec = torch.view_as_real( 63 | torch.stft( 64 | y, 65 | n_fft, 66 | hop_length=hop_size, 67 | win_length=win_size, 68 | window=hann_window[str(y.device)], 69 | center=center, 70 | pad_mode="reflect", 71 | normalized=False, 72 | onesided=True, 73 | return_complex=True, 74 | ) 75 | ) 76 | 77 | spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) 78 | 79 | spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec) 80 | spec = spectral_normalize_torch(spec) 81 | 82 | return spec 83 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/denoiser.py: -------------------------------------------------------------------------------- 1 | # Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py 2 | 3 | """Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio.""" 4 | import torch 5 | 6 | 7 | class Denoiser(torch.nn.Module): 8 | """Removes model bias from audio produced with waveglow""" 9 | 10 | def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"): 11 | super().__init__() 12 | self.filter_length = filter_length 13 | self.hop_length = int(filter_length / n_overlap) 14 | self.win_length = win_length 15 | 16 | dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device 17 | self.device = device 18 | if mode == "zeros": 19 | mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device) 20 | elif mode == "normal": 21 | mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device) 22 | else: 23 | raise Exception(f"Mode {mode} if not supported") 24 | 25 | def stft_fn(audio, n_fft, hop_length, win_length, window): 26 | spec = torch.stft( 27 | audio, 28 | n_fft=n_fft, 29 | hop_length=hop_length, 30 | win_length=win_length, 31 | window=window, 32 | return_complex=True, 33 | ) 34 | spec = 
torch.view_as_real(spec) 35 | return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0]) 36 | 37 | self.stft = lambda x: stft_fn( 38 | audio=x, 39 | n_fft=self.filter_length, 40 | hop_length=self.hop_length, 41 | win_length=self.win_length, 42 | window=torch.hann_window(self.win_length, device=device), 43 | ) 44 | self.istft = lambda x, y: torch.istft( 45 | torch.complex(x * torch.cos(y), x * torch.sin(y)), 46 | n_fft=self.filter_length, 47 | hop_length=self.hop_length, 48 | win_length=self.win_length, 49 | window=torch.hann_window(self.win_length, device=device), 50 | ) 51 | 52 | with torch.no_grad(): 53 | bias_audio = vocoder(mel_input).float().squeeze(0) 54 | bias_spec, _ = self.stft(bias_audio) 55 | 56 | self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None]) 57 | 58 | @torch.inference_mode() 59 | def forward(self, audio, strength=0.0005): 60 | audio_spec, audio_angles = self.stft(audio) 61 | audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength 62 | audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) 63 | audio_denoised = self.istft(audio_spec_denoised, audio_angles) 64 | return audio_denoised 65 | -------------------------------------------------------------------------------- /tools/extract_speech_token.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
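# Usage sketch (the paths below are placeholders, not fixed by this repository): the script
# reads {dir}/wav.scp, computes 128-bin Whisper log-mel features per utterance, runs the
# speech-tokenizer ONNX model on GPU, and saves the result to {dir}/utt2speech_token.pt.
# Utterances longer than 30 s are skipped with a warning and stored as empty token lists.
#
#   python tools/extract_speech_token.py \
#       --dir data/train \
#       --onnx_path pretrained_models/speech_tokenizer_v2.onnx \
#       --num_thread 8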
15 | import argparse 16 | from concurrent.futures import ThreadPoolExecutor, as_completed 17 | import logging 18 | import torch 19 | from tqdm import tqdm 20 | import onnxruntime 21 | import numpy as np 22 | import torchaudio 23 | import whisper 24 | 25 | 26 | def single_job(utt): 27 | audio, sample_rate = torchaudio.load(utt2wav[utt], backend='soundfile') 28 | if sample_rate != 16000: 29 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 30 | # Convert audio to mono 31 | if audio.shape[0] > 1: 32 | audio = audio.mean(dim=0, keepdim=True) 33 | if audio.shape[1] / 16000 > 30: 34 | logging.warning('do not support extract speech token for audio longer than 30s') 35 | speech_token = [] 36 | else: 37 | feat = whisper.log_mel_spectrogram(audio, n_mels=128) 38 | speech_token = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.detach().cpu().numpy(), 39 | ort_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist() 40 | return utt, speech_token 41 | 42 | 43 | def main(args): 44 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] 45 | utt2speech_token = {} 46 | for future in tqdm(as_completed(all_task)): 47 | utt, speech_token = future.result() 48 | utt2speech_token[utt] = speech_token 49 | torch.save(utt2speech_token, '{}/utt2speech_token.pt'.format(args.dir)) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--dir", type=str) 55 | parser.add_argument("--onnx_path", type=str) 56 | parser.add_argument("--num_thread", type=int, default=8) 57 | args = parser.parse_args() 58 | 59 | utt2wav = {} 60 | with open('{}/wav.scp'.format(args.dir)) as f: 61 | for l in f: 62 | l = l.replace('\n', '').split() 63 | utt2wav[l[0]] = l[1] 64 | 65 | option = onnxruntime.SessionOptions() 66 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 67 | option.intra_op_num_threads = 1 68 | providers = ["CUDAExecutionProvider"] 69 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 70 | executor = ThreadPoolExecutor(max_workers=args.num_thread) 71 | 72 | main(args) 73 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .venv 106 | env/ 107 | venv/ 108 | ENV/ 109 | env.bak/ 110 | venv.bak/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | .dmypy.json 125 | dmypy.json 126 | 127 | # Pyre type checker 128 | .pyre/ 129 | 130 | ### VisualStudioCode 131 | .vscode/* 132 | !.vscode/settings.json 133 | !.vscode/tasks.json 134 | !.vscode/launch.json 135 | !.vscode/extensions.json 136 | *.code-workspace 137 | **/.vscode 138 | 139 | # JetBrains 140 | .idea/ 141 | 142 | # Data & Models 143 | *.h5 144 | *.tar 145 | *.tar.gz 146 | 147 | # Lightning-Hydra-Template 148 | configs/local/default.yaml 149 | /data/ 150 | /logs/ 151 | .env 152 | 153 | # Aim logging 154 | .aim 155 | 156 | # Cython complied files 157 | matcha/utils/monotonic_align/core.c 158 | 159 | # Ignoring hifigan checkpoint 160 | generator_v1 161 | g_02500000 162 | gradio_cached_examples/ 163 | synth_output/ 164 | -------------------------------------------------------------------------------- /tools/collect_spk_embedding_fast.py: -------------------------------------------------------------------------------- 1 | # This script dumps waveforms to wav-copy format wav ark, including sample rate and int16 sequence. 2 | """ 3 | Author: Zhihao Du 4 | Date: 2023.03.29 5 | Description: Dump wav, flac and ark files to wav ark files with given sampling rate. 
6 | """ 7 | import logging 8 | import warnings 9 | warnings.filterwarnings("ignore") 10 | import os 11 | import time 12 | import argparse 13 | import numpy as np 14 | import kaldiio 15 | # from tqdm import tqdm 16 | 17 | 18 | def main(args): 19 | logger = logging.getLogger() 20 | logger.setLevel(logging.INFO) 21 | console_handler = logging.StreamHandler() 22 | console_handler.setLevel(logging.INFO) 23 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 24 | console_handler.setFormatter(formatter) 25 | logger.addHandler(console_handler) 26 | rank = int(os.environ['LOCAL_RANK']) 27 | threads_num = int(os.environ['WORLD_SIZE']) 28 | out_dir = os.path.join(args.dir, 'spk_embedding') 29 | logger.info("rank {}/{}: out_dir {}.".format( 30 | rank, threads_num, out_dir 31 | )) 32 | if out_dir is not None: 33 | if rank == 0: 34 | if not os.path.exists(out_dir): 35 | os.makedirs(out_dir) 36 | else: 37 | while not os.path.exists(out_dir): 38 | time.sleep(0.5) 39 | 40 | utt2wav, spk2utt = {}, {} 41 | with open('{}/embedding.scp'.format(args.dir)) as f: 42 | for l in f: 43 | l = l.replace('\n', '').split() 44 | utt2wav[l[0]] = l[1] 45 | with open('{}/spk2utt'.format(args.dir)) as f: 46 | for l in f: 47 | l = l.replace('\n', '').split() 48 | spk2utt[l[0]] = l[1:] 49 | 50 | all_recs = list(spk2utt.keys()) 51 | local_all_recs = all_recs[rank::threads_num] 52 | 53 | out_path = os.path.join(out_dir, f"spk_embedding.{rank:02d}") 54 | wav_writer = kaldiio.WriteHelper(f"ark,scp,f:{out_path}.ark,{out_path}.scp") 55 | out_path = os.path.join(out_dir, f"length.{rank:02d}.txt") 56 | length_writer = open(out_path, "wt") 57 | meeting_count = 0 58 | for i, spk in enumerate(local_all_recs): 59 | 60 | spk2embedding = [] 61 | for utt in spk2utt[spk]: 62 | if utt not in utt2wav: 63 | continue 64 | spk2embedding.append(kaldiio.load_mat(utt2wav[utt])) 65 | 66 | if len(spk2embedding)==0: 67 | continue 68 | 69 | mean_embedding = np.array(spk2embedding).mean(axis=0) 70 | wav_writer(spk, mean_embedding.astype(np.float32)) 71 | 72 | length_writer.write("{} {}\n".format(spk, len(spk2embedding))) 73 | 74 | if i % 100 == 0: 75 | logger.info("{}/{}: process {}.".format(rank, threads_num, spk)) 76 | 77 | meeting_count += 1 78 | 79 | wav_writer.close() 80 | length_writer.close() 81 | logger.info("{}/{}: Complete {} records.".format(rank, threads_num, meeting_count)) 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument("--dir", type=str) 87 | args = parser.parse_args() 88 | main(args) 89 | -------------------------------------------------------------------------------- /tools/extract_embedding.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import argparse 16 | from concurrent.futures import ThreadPoolExecutor, as_completed 17 | import onnxruntime 18 | import torch 19 | import torchaudio 20 | import torchaudio.compliance.kaldi as kaldi 21 | from tqdm import tqdm 22 | 23 | 24 | def single_job(utt): 25 | audio, sample_rate = torchaudio.load(utt2wav[utt]) 26 | if sample_rate != 16000: 27 | audio = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(audio) 28 | feat = kaldi.fbank(audio, 29 | num_mel_bins=80, 30 | dither=0, 31 | sample_frequency=16000) 32 | feat = feat - feat.mean(dim=0, keepdim=True) 33 | embedding = ort_session.run(None, {ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist() 34 | return utt, embedding 35 | 36 | 37 | def main(args): 38 | all_task = [executor.submit(single_job, utt) for utt in utt2wav.keys()] 39 | utt2embedding, spk2embedding = {}, {} 40 | for future in tqdm(as_completed(all_task)): 41 | utt, embedding = future.result() 42 | utt2embedding[utt] = embedding 43 | spk = utt2spk[utt] 44 | if spk not in spk2embedding: 45 | spk2embedding[spk] = [] 46 | spk2embedding[spk].append(embedding) 47 | for k, v in spk2embedding.items(): 48 | spk2embedding[k] = torch.tensor(v).mean(dim=0).tolist() 49 | torch.save(utt2embedding, "{}/utt2embedding.pt".format(args.dir)) 50 | torch.save(spk2embedding, "{}/spk2embedding.pt".format(args.dir)) 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--dir", type=str) 56 | parser.add_argument("--onnx_path", type=str) 57 | parser.add_argument("--num_thread", type=int, default=8) 58 | args = parser.parse_args() 59 | 60 | utt2wav, utt2spk = {}, {} 61 | with open('{}/wav.scp'.format(args.dir)) as f: 62 | for l in f: 63 | l = l.replace('\n', '').split() 64 | utt2wav[l[0]] = l[1] 65 | with open('{}/utt2spk'.format(args.dir)) as f: 66 | for l in f: 67 | l = l.replace('\n', '').split() 68 | utt2spk[l[0]] = l[1] 69 | 70 | option = onnxruntime.SessionOptions() 71 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 72 | option.intra_op_num_threads = 1 73 | providers = ["CPUExecutionProvider"] 74 | ort_session = onnxruntime.InferenceSession(args.onnx_path, sess_options=option, providers=providers) 75 | executor = ThreadPoolExecutor(max_workers=args.num_thread) 76 | 77 | main(args) 78 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/model.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/jaywalnut310/glow-tts """ 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def sequence_mask(length, max_length=None): 8 | if max_length is None: 9 | max_length = length.max() 10 | x = torch.arange(max_length, dtype=length.dtype, device=length.device) 11 | return x.unsqueeze(0) < length.unsqueeze(1) 12 | 13 | 14 | def fix_len_compatibility(length, num_downsamplings_in_unet=2): 15 | factor = torch.scalar_tensor(2).pow(num_downsamplings_in_unet) 16 | length = (length / factor).ceil() * factor 17 | if not torch.onnx.is_in_onnx_export(): 18 | return length.int().item() 19 | else: 20 | return length 21 | 22 | 23 | def convert_pad_shape(pad_shape): 24 | inverted_shape = pad_shape[::-1] 25 | pad_shape = [item for sublist in inverted_shape for item in sublist] 26 | return pad_shape 27 | 28 | 29 | def generate_path(duration, mask): 30 | device = duration.device 31 | 32 | b, t_x, t_y = mask.shape 33 | cum_duration = 
torch.cumsum(duration, 1) 34 | path = torch.zeros(b, t_x, t_y, dtype=mask.dtype).to(device=device) 35 | 36 | cum_duration_flat = cum_duration.view(b * t_x) 37 | path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) 38 | path = path.view(b, t_x, t_y) 39 | path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] 40 | path = path * mask 41 | return path 42 | 43 | 44 | def duration_loss(logw, logw_, lengths): 45 | loss = torch.sum((logw - logw_) ** 2) / torch.sum(lengths) 46 | return loss 47 | 48 | 49 | def normalize(data, mu, std): 50 | if not isinstance(mu, (float, int)): 51 | if isinstance(mu, list): 52 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 53 | elif isinstance(mu, torch.Tensor): 54 | mu = mu.to(data.device) 55 | elif isinstance(mu, np.ndarray): 56 | mu = torch.from_numpy(mu).to(data.device) 57 | mu = mu.unsqueeze(-1) 58 | 59 | if not isinstance(std, (float, int)): 60 | if isinstance(std, list): 61 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 62 | elif isinstance(std, torch.Tensor): 63 | std = std.to(data.device) 64 | elif isinstance(std, np.ndarray): 65 | std = torch.from_numpy(std).to(data.device) 66 | std = std.unsqueeze(-1) 67 | 68 | return (data - mu) / std 69 | 70 | 71 | def denormalize(data, mu, std): 72 | if not isinstance(mu, float): 73 | if isinstance(mu, list): 74 | mu = torch.tensor(mu, dtype=data.dtype, device=data.device) 75 | elif isinstance(mu, torch.Tensor): 76 | mu = mu.to(data.device) 77 | elif isinstance(mu, np.ndarray): 78 | mu = torch.from_numpy(mu).to(data.device) 79 | mu = mu.unsqueeze(-1) 80 | 81 | if not isinstance(std, float): 82 | if isinstance(std, list): 83 | std = torch.tensor(std, dtype=data.dtype, device=data.device) 84 | elif isinstance(std, torch.Tensor): 85 | std = std.to(data.device) 86 | elif isinstance(std, np.ndarray): 87 | std = torch.from_numpy(std).to(data.device) 88 | std = std.unsqueeze(-1) 89 | 90 | return data * std + mu 91 | -------------------------------------------------------------------------------- /cosyvoice/flow/length_regulator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
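# Shape sketch for the InterpolateRegulator defined below (illustrative values, not taken
# from a CosyVoice config): forward() linearly interpolates a (B, T, D) token-rate sequence
# along time to ylens.max() frames, applies a small Conv1d stack, and zeroes padded frames.
#
#   regulator = InterpolateRegulator(channels=80, sampling_ratios=[1, 1, 1, 1])
#   x = torch.randn(2, 50, 80)           # (B, T, D) token-rate features
#   ylens = torch.tensor([200, 180])     # target mel lengths per item
#   out, olens = regulator(x, ylens)     # out: (2, 200, 80); frames past 180 are masked for item 1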
14 | from typing import Tuple 15 | import torch.nn as nn 16 | import torch 17 | from torch.nn import functional as F 18 | from cosyvoice.utils.mask import make_pad_mask 19 | 20 | 21 | class InterpolateRegulator(nn.Module): 22 | def __init__( 23 | self, 24 | channels: int, 25 | sampling_ratios: Tuple, 26 | out_channels: int = None, 27 | groups: int = 1, 28 | ): 29 | super().__init__() 30 | self.sampling_ratios = sampling_ratios 31 | out_channels = out_channels or channels 32 | model = nn.ModuleList([]) 33 | if len(sampling_ratios) > 0: 34 | for _ in sampling_ratios: 35 | module = nn.Conv1d(channels, channels, 3, 1, 1) 36 | norm = nn.GroupNorm(groups, channels) 37 | act = nn.Mish() 38 | model.extend([module, norm, act]) 39 | model.append( 40 | nn.Conv1d(channels, out_channels, 1, 1) 41 | ) 42 | self.model = nn.Sequential(*model) 43 | 44 | def forward(self, x, ylens=None): 45 | # x in (B, T, D) 46 | mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1) 47 | x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear') 48 | out = self.model(x).transpose(1, 2).contiguous() 49 | olens = ylens 50 | return out * mask, olens 51 | 52 | def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50): 53 | # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel 54 | # NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py 55 | # x in (B, T, D) 56 | if x2.shape[1] > 40: 57 | x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 58 | x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2, 59 | mode='linear') 60 | x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear') 61 | x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2) 62 | else: 63 | x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear') 64 | if x1.shape[1] != 0: 65 | x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear') 66 | x = torch.concat([x1, x2], dim=2) 67 | else: 68 | x = x2 69 | out = self.model(x).transpose(1, 2).contiguous() 70 | return out, mel_len1 + mel_len2 71 | -------------------------------------------------------------------------------- /cosyvoice/transformer/activation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe) 2 | # 2020 Northwestern Polytechnical University (Pengcheng Guo) 3 | # 2020 Mobvoi Inc (Binbin Zhang) 4 | # 2024 Alibaba Inc (Xiang Lyu) 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
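# The Swish class below is the fallback for the "swish" entry in COSYVOICE_ACTIVATION_CLASSES
# (cosyvoice/utils/class_utils.py), which prefers torch.nn.SiLU when it is available. A quick
# shape check of Snake, x + (1/alpha) * sin(alpha * x)**2 (a sketch, not part of this file):
#
#   act = Snake(in_features=4)
#   x = torch.randn(2, 4, 10)            # (B, C, T)
#   y = act(x)                           # same shape; alpha is a learnable per-channel parameter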
17 | """Swish() activation function for Conformer.""" 18 | 19 | import torch 20 | from torch import nn, sin, pow 21 | from torch.nn import Parameter 22 | 23 | 24 | class Swish(torch.nn.Module): 25 | """Construct an Swish object.""" 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | """Return Swish activation function.""" 29 | return x * torch.sigmoid(x) 30 | 31 | 32 | # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license. 33 | # LICENSE is in incl_licenses directory. 34 | class Snake(nn.Module): 35 | ''' 36 | Implementation of a sine-based periodic activation function 37 | Shape: 38 | - Input: (B, C, T) 39 | - Output: (B, C, T), same shape as the input 40 | Parameters: 41 | - alpha - trainable parameter 42 | References: 43 | - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda: 44 | https://arxiv.org/abs/2006.08195 45 | Examples: 46 | >>> a1 = snake(256) 47 | >>> x = torch.randn(256) 48 | >>> x = a1(x) 49 | ''' 50 | def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False): 51 | ''' 52 | Initialization. 53 | INPUT: 54 | - in_features: shape of the input 55 | - alpha: trainable parameter 56 | alpha is initialized to 1 by default, higher values = higher-frequency. 57 | alpha will be trained along with the rest of your model. 58 | ''' 59 | super(Snake, self).__init__() 60 | self.in_features = in_features 61 | 62 | # initialize alpha 63 | self.alpha_logscale = alpha_logscale 64 | if self.alpha_logscale: # log scale alphas initialized to zeros 65 | self.alpha = Parameter(torch.zeros(in_features) * alpha) 66 | else: # linear scale alphas initialized to ones 67 | self.alpha = Parameter(torch.ones(in_features) * alpha) 68 | 69 | self.alpha.requires_grad = alpha_trainable 70 | 71 | self.no_div_by_zero = 0.000000001 72 | 73 | def forward(self, x): 74 | ''' 75 | Forward pass of the function. 76 | Applies the function to the input elementwise. 
77 | Snake ∶= x + 1/a * sin^2 (xa) 78 | ''' 79 | alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T] 80 | if self.alpha_logscale: 81 | alpha = torch.exp(alpha) 82 | x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2) 83 | 84 | return x 85 | -------------------------------------------------------------------------------- /cosyvoice/hifigan/hifigan.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss 6 | from cosyvoice.utils.losses import tpr_loss, mel_loss 7 | 8 | 9 | class HiFiGan(nn.Module): 10 | def __init__(self, generator, discriminator, mel_spec_transform, 11 | multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0, 12 | tpr_loss_weight=1.0, tpr_loss_tau=0.04): 13 | super(HiFiGan, self).__init__() 14 | self.generator = generator 15 | self.discriminator = discriminator 16 | self.mel_spec_transform = mel_spec_transform 17 | self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight 18 | self.feat_match_loss_weight = feat_match_loss_weight 19 | self.tpr_loss_weight = tpr_loss_weight 20 | self.tpr_loss_tau = tpr_loss_tau 21 | 22 | def forward( 23 | self, 24 | batch: dict, 25 | device: torch.device, 26 | ) -> Dict[str, Optional[torch.Tensor]]: 27 | if batch['turn'] == 'generator': 28 | return self.forward_generator(batch, device) 29 | else: 30 | return self.forward_discriminator(batch, device) 31 | 32 | def forward_generator(self, batch, device): 33 | real_speech = batch['speech'].to(device) 34 | pitch_feat = batch['pitch_feat'].to(device) 35 | # 1. calculate generator outputs 36 | generated_speech, generated_f0 = self.generator(batch, device) 37 | # 2. calculate discriminator outputs 38 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech) 39 | # 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional] 40 | loss_gen, _ = generator_loss(y_d_gs) 41 | loss_fm = feature_loss(fmap_rs, fmap_gs) 42 | loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform) 43 | if self.tpr_loss_weight != 0: 44 | loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau) 45 | else: 46 | loss_tpr = torch.zeros(1).to(device) 47 | loss_f0 = F.l1_loss(generated_f0, pitch_feat) 48 | loss = loss_gen + self.feat_match_loss_weight * loss_fm + \ 49 | self.multi_mel_spectral_recon_loss_weight * loss_mel + \ 50 | self.tpr_loss_weight * loss_tpr + loss_f0 51 | return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0} 52 | 53 | def forward_discriminator(self, batch, device): 54 | real_speech = batch['speech'].to(device) 55 | # 1. calculate generator outputs 56 | with torch.no_grad(): 57 | generated_speech, generated_f0 = self.generator(batch, device) 58 | # 2. calculate discriminator outputs 59 | y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach()) 60 | # 3. 
calculate discriminator losses, tpr losses [Optional]
61 |         loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
62 |         if self.tpr_loss_weight != 0:
63 |             loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
64 |         else:
65 |             loss_tpr = torch.zeros(1).to(device)
66 |         loss = loss_disc + self.tpr_loss_weight * loss_tpr
67 |         return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
68 |
--------------------------------------------------------------------------------
/cosyvoice/dataset/my_processor.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import logging
15 | import random
16 |
17 | import pyarrow.parquet as pq
18 | from io import BytesIO
19 | import torch
20 | import torchaudio
21 | from torch.nn.utils.rnn import pad_sequence
22 | import torch.nn.functional as F
23 | import pyworld as pw
24 |
25 |
26 | AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
27 |
28 |
29 | def is_mixed_language(s):
30 |     # Unicode ranges for each language
31 |     lang_ranges = {
32 |         'chinese': (0x4e00, 0x9fff),  # Chinese
33 |         'english': (0x0041, 0x005a),  # English uppercase
34 |         'english_lower': (0x0061, 0x007a),  # English lowercase
35 |         'japanese': (0x3040, 0x30ff),  # Japanese
36 |         'korean': (0xac00, 0xd7af),  # Korean
37 |         'numbers': (0x0030, 0x0039)  # digits (treated as neutral characters)
38 |     }
39 |
40 |     # record which languages appear in the string
41 |     present_langs = set()
42 |
43 |     for char in s:
44 |         code = ord(char)
45 |         lang_found = None
46 |
47 |         # check which language this character belongs to
48 |         for lang, (start, end) in lang_ranges.items():
49 |             if start <= code <= end:
50 |                 # merge upper- and lower-case English into one category
51 |                 if lang == 'english_lower':
52 |                     lang_found = 'english'
53 |                 elif lang == 'chinese' or lang == 'japanese':
54 |                     lang_found = 'chinese+japanese'
55 |                 else:
56 |                     lang_found = lang
57 |                 break
58 |
59 |         # if lang_found and lang_found != 'numbers':  # ignore the influence of digits
60 |         if lang_found:
61 |             present_langs.add(lang_found)
62 |
63 |         # return early as soon as two or more languages are found
64 |         if len(present_langs) >= 2:
65 |             return True
66 |
67 |     return len(present_langs) >= 2
68 |
69 |
70 | def filter_mix_lang(data, mode='train'):
71 |     """ Filter sample according to feature and label length
72 |         Inplace operation.
73 | 74 | Args:: 75 | data: Iterable[{key, wav, label, sample_rate}] 76 | max_length: drop utterance which is greater than max_length(10ms) 77 | min_length: drop utterance which is less than min_length(10ms) 78 | token_max_length: drop utterance which is greater than 79 | token_max_length, especially when use char unit for 80 | english modeling 81 | token_min_length: drop utterance which is 82 | less than token_max_length 83 | min_output_input_ratio: minimal ration of 84 | token_length / feats_length(10ms) 85 | max_output_input_ratio: maximum ration of 86 | token_length / feats_length(10ms) 87 | 88 | Returns: 89 | Iterable[{key, wav, label, sample_rate}] 90 | """ 91 | for sample in data: 92 | if is_mixed_language(sample['text']): 93 | continue 94 | 95 | yield sample 96 | 97 | -------------------------------------------------------------------------------- /cosyvoice/bin/average_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc (Di Wu) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import argparse 18 | import glob 19 | 20 | import yaml 21 | import torch 22 | 23 | 24 | def get_args(): 25 | parser = argparse.ArgumentParser(description='average model') 26 | parser.add_argument('--dst_model', required=True, help='averaged model') 27 | parser.add_argument('--src_path', 28 | required=True, 29 | help='src model path for average') 30 | parser.add_argument('--val_best', 31 | action="store_true", 32 | help='averaged model') 33 | parser.add_argument('--num', 34 | default=5, 35 | type=int, 36 | help='nums for averaged model') 37 | 38 | args = parser.parse_args() 39 | print(args) 40 | return args 41 | 42 | 43 | def main(): 44 | args = get_args() 45 | val_scores = [] 46 | if args.val_best: 47 | yamls = glob.glob('{}/*.yaml'.format(args.src_path)) 48 | yamls = [ 49 | f for f in yamls 50 | if not (os.path.basename(f).startswith('train') 51 | or os.path.basename(f).startswith('init')) 52 | ] 53 | for y in yamls: 54 | with open(y, 'r') as f: 55 | dic_yaml = yaml.load(f, Loader=yaml.BaseLoader) 56 | loss = float(dic_yaml['loss_dict']['loss']) 57 | epoch = int(dic_yaml['epoch']) 58 | step = int(dic_yaml['step']) 59 | tag = dic_yaml['tag'] 60 | val_scores += [[epoch, step, loss, tag]] 61 | sorted_val_scores = sorted(val_scores, 62 | key=lambda x: x[2], 63 | reverse=False) 64 | print("best val (epoch, step, loss, tag) = " + 65 | str(sorted_val_scores[:args.num])) 66 | path_list = [ 67 | args.src_path + '/epoch_{}_whole.pt'.format(score[0]) 68 | for score in sorted_val_scores[:args.num] 69 | ] 70 | print(path_list) 71 | avg = {} 72 | num = args.num 73 | assert num == len(path_list) 74 | for path in path_list: 75 | print('Processing {}'.format(path)) 76 | states = torch.load(path, map_location=torch.device('cpu')) 77 | for k in states.keys(): 78 | if k not in ['step', 'epoch']: 79 | if k not in avg.keys(): 80 | avg[k] = 
states[k].clone() 81 | else: 82 | avg[k] += states[k] 83 | # average 84 | for k in avg.keys(): 85 | if avg[k] is not None: 86 | # pytorch 1.6 use true_divide instead of /= 87 | avg[k] = torch.true_divide(avg[k], num) 88 | print('Saving to {}'.format(args.dst_model)) 89 | torch.save(avg, args.dst_model) 90 | 91 | 92 | if __name__ == '__main__': 93 | main() 94 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at mikelei@mobvoi.com. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. 
The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /cosyvoice/utils/class_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright [2023-11-28] 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import torch 16 | 17 | from cosyvoice.transformer.activation import Swish 18 | from cosyvoice.transformer.subsampling import ( 19 | LinearNoSubsampling, 20 | EmbedinigNoSubsampling, 21 | Conv1dSubsampling2, 22 | Conv2dSubsampling4, 23 | Conv2dSubsampling6, 24 | Conv2dSubsampling8, 25 | ) 26 | from cosyvoice.transformer.embedding import (PositionalEncoding, 27 | RelPositionalEncoding, 28 | WhisperPositionalEncoding, 29 | LearnablePositionalEncoding, 30 | NoPositionalEncoding) 31 | from cosyvoice.transformer.attention import (MultiHeadedAttention, 32 | RelPositionMultiHeadedAttention) 33 | from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding 34 | from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling 35 | from cosyvoice.llm.llm import TransformerLM, Qwen2LM 36 | from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec 37 | from cosyvoice.hifigan.generator import HiFTGenerator 38 | from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model 39 | 40 | 41 | COSYVOICE_ACTIVATION_CLASSES = { 42 | "hardtanh": torch.nn.Hardtanh, 43 | "tanh": torch.nn.Tanh, 44 | "relu": torch.nn.ReLU, 45 | "selu": torch.nn.SELU, 46 | "swish": getattr(torch.nn, "SiLU", Swish), 47 | "gelu": torch.nn.GELU, 48 | } 49 | 50 | COSYVOICE_SUBSAMPLE_CLASSES = { 51 | "linear": LinearNoSubsampling, 52 | "linear_legacy": LegacyLinearNoSubsampling, 53 | "embed": EmbedinigNoSubsampling, 54 | "conv1d2": Conv1dSubsampling2, 55 | "conv2d": Conv2dSubsampling4, 56 | "conv2d6": Conv2dSubsampling6, 57 | "conv2d8": Conv2dSubsampling8, 58 | 'paraformer_dummy': torch.nn.Identity 59 | } 60 | 61 | COSYVOICE_EMB_CLASSES = { 62 | "embed": PositionalEncoding, 63 | "abs_pos": PositionalEncoding, 64 | "rel_pos": RelPositionalEncoding, 65 | "rel_pos_espnet": EspnetRelPositionalEncoding, 66 | "no_pos": NoPositionalEncoding, 67 | "abs_pos_whisper": 
WhisperPositionalEncoding, 68 | "embed_learnable_pe": LearnablePositionalEncoding, 69 | } 70 | 71 | COSYVOICE_ATTENTION_CLASSES = { 72 | "selfattn": MultiHeadedAttention, 73 | "rel_selfattn": RelPositionMultiHeadedAttention, 74 | } 75 | 76 | 77 | def get_model_type(configs): 78 | # NOTE CosyVoice2Model inherits CosyVoiceModel 79 | if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 80 | return CosyVoiceModel 81 | if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator): 82 | return CosyVoice2Model 83 | raise TypeError('No valid model type found!') 84 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/rich_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Sequence 3 | 4 | import rich 5 | import rich.syntax 6 | import rich.tree 7 | from hydra.core.hydra_config import HydraConfig 8 | from lightning.pytorch.utilities import rank_zero_only 9 | from omegaconf import DictConfig, OmegaConf, open_dict 10 | from rich.prompt import Prompt 11 | 12 | from matcha.utils import pylogger 13 | 14 | log = pylogger.get_pylogger(__name__) 15 | 16 | 17 | @rank_zero_only 18 | def print_config_tree( 19 | cfg: DictConfig, 20 | print_order: Sequence[str] = ( 21 | "data", 22 | "model", 23 | "callbacks", 24 | "logger", 25 | "trainer", 26 | "paths", 27 | "extras", 28 | ), 29 | resolve: bool = False, 30 | save_to_file: bool = False, 31 | ) -> None: 32 | """Prints the contents of a DictConfig as a tree structure using the Rich library. 33 | 34 | :param cfg: A DictConfig composed by Hydra. 35 | :param print_order: Determines in what order config components are printed. Default is ``("data", "model", 36 | "callbacks", "logger", "trainer", "paths", "extras")``. 37 | :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. 38 | :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. 39 | """ 40 | style = "dim" 41 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 42 | 43 | queue = [] 44 | 45 | # add fields from `print_order` to queue 46 | for field in print_order: 47 | _ = ( 48 | queue.append(field) 49 | if field in cfg 50 | else log.warning(f"Field '{field}' not found in config. Skipping '{field}' config printing...") 51 | ) 52 | 53 | # add all the other fields to queue (not specified in `print_order`) 54 | for field in cfg: 55 | if field not in queue: 56 | queue.append(field) 57 | 58 | # generate config tree from queue 59 | for field in queue: 60 | branch = tree.add(field, style=style, guide_style=style) 61 | 62 | config_group = cfg[field] 63 | if isinstance(config_group, DictConfig): 64 | branch_content = OmegaConf.to_yaml(config_group, resolve=resolve) 65 | else: 66 | branch_content = str(config_group) 67 | 68 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 69 | 70 | # print config tree 71 | rich.print(tree) 72 | 73 | # save config tree to file 74 | if save_to_file: 75 | with open(Path(cfg.paths.output_dir, "config_tree.log"), "w") as file: 76 | rich.print(tree, file=file) 77 | 78 | 79 | @rank_zero_only 80 | def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: 81 | """Prompts user to input tags from command line if no tags are provided in config. 
82 | 83 | :param cfg: A DictConfig composed by Hydra. 84 | :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 85 | """ 86 | if not cfg.get("tags"): 87 | if "id" in HydraConfig().cfg.hydra.job: 88 | raise ValueError("Specify tags before launching a multirun!") 89 | 90 | log.warning("No tags provided in config. Prompting user to input tags...") 91 | tags = Prompt.ask("Enter a list of comma separated tags", default="dev") 92 | tags = [t.strip() for t in tags.split(",") if t != ""] 93 | 94 | with open_dict(cfg): 95 | cfg.tags = tags 96 | 97 | log.info(f"Tags: {cfg.tags}") 98 | 99 | if save_to_file: 100 | with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file: 101 | rich.print(cfg.tags, file=file) 102 | -------------------------------------------------------------------------------- /cosyvoice/transformer/label_smoothing_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Label smoothing module.""" 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class LabelSmoothingLoss(nn.Module): 22 | """Label-smoothing loss. 23 | 24 | In a standard CE loss, the label's data distribution is: 25 | [0,1,2] -> 26 | [ 27 | [1.0, 0.0, 0.0], 28 | [0.0, 1.0, 0.0], 29 | [0.0, 0.0, 1.0], 30 | ] 31 | 32 | In the smoothing version CE Loss,some probabilities 33 | are taken from the true label prob (1.0) and are divided 34 | among other labels. 35 | 36 | e.g. 37 | smoothing=0.1 38 | [0,1,2] -> 39 | [ 40 | [0.9, 0.05, 0.05], 41 | [0.05, 0.9, 0.05], 42 | [0.05, 0.05, 0.9], 43 | ] 44 | 45 | Args: 46 | size (int): the number of class 47 | padding_idx (int): padding class id which will be ignored for loss 48 | smoothing (float): smoothing rate (0.0 means the conventional CE) 49 | normalize_length (bool): 50 | normalize loss by sequence length if True 51 | normalize loss by batch size if False 52 | """ 53 | 54 | def __init__(self, 55 | size: int, 56 | padding_idx: int, 57 | smoothing: float, 58 | normalize_length: bool = False): 59 | """Construct an LabelSmoothingLoss object.""" 60 | super(LabelSmoothingLoss, self).__init__() 61 | self.criterion = nn.KLDivLoss(reduction="none") 62 | self.padding_idx = padding_idx 63 | self.confidence = 1.0 - smoothing 64 | self.smoothing = smoothing 65 | self.size = size 66 | self.normalize_length = normalize_length 67 | 68 | def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 69 | """Compute loss between x and target. 70 | 71 | The model outputs and data labels tensors are flatten to 72 | (batch*seqlen, class) shape and a mask is applied to the 73 | padding part which should not be calculated for loss. 
74 | 75 | Args: 76 | x (torch.Tensor): prediction (batch, seqlen, class) 77 | target (torch.Tensor): 78 | target signal masked with self.padding_id (batch, seqlen) 79 | Returns: 80 | loss (torch.Tensor) : The KL loss, scalar float value 81 | """ 82 | assert x.size(2) == self.size 83 | batch_size = x.size(0) 84 | x = x.view(-1, self.size) 85 | target = target.view(-1) 86 | # use zeros_like instead of torch.no_grad() for true_dist, 87 | # since no_grad() can not be exported by JIT 88 | true_dist = torch.zeros_like(x) 89 | true_dist.fill_(self.smoothing / (self.size - 1)) 90 | ignore = target == self.padding_idx # (B,) 91 | total = len(target) - ignore.sum().item() 92 | target = target.masked_fill(ignore, 0) # avoid -1 index 93 | true_dist.scatter_(1, target.unsqueeze(1), self.confidence) 94 | kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) 95 | denom = total if self.normalize_length else batch_size 96 | return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom 97 | -------------------------------------------------------------------------------- /cosyvoice/llm/reward_tts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import aiohttp 4 | import re 5 | import logging 6 | logging.getLogger('asyncio').setLevel(logging.WARNING) 7 | 8 | 9 | SERVER = os.getenv("WHISPER_SERVER", "http://172.16.46.216:8080") 10 | SCORE_URL = f"{SERVER.rstrip('/')}/end_to_end" 11 | HEALTH_URL = f"{SERVER.rstrip('/')}/health" 12 | 13 | # export WHISPER_SERVER=http://172.16.46.216:8080 14 | 15 | 16 | def remove_lang_tag(text: str) -> str: 17 | return re.sub(r'^<\|[^|]+?\|>', '', text) 18 | 19 | 20 | async def process_audio_sample(sampling_rate, speech_tokens_str: str, expected_answer: str) -> float: 21 | """ 22 | Process a single sample by calling the combined /end_to_end endpoint, 23 | which decodes speech tokens and transcribes the resulting audio. 24 | The service returns (among others) the transcribed text, WER, and reward. 
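    Illustrative request/response (a sketch only: the field names mirror the payload
    built and the keys parsed below in this function, while the concrete values are
    invented for illustration):

        request  = {"speech_tokens_str": "415 2367 881 ...", "sample_rate": 24000,
                    "expected_text": "hello there"}
        response = {"transcribed_text": "hello there", "language": "en", "cer": 0.0,
                    "nll": 1.2, "cer_reward": 1.0, "nll_reward": 0.8, "reward": 0.9}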
25 | """ 26 | payload = { 27 | "speech_tokens_str": speech_tokens_str, 28 | "sample_rate": sampling_rate, 29 | "expected_text": remove_lang_tag(expected_answer), 30 | } 31 | transcribed_text = "" 32 | language = "" 33 | cer = 0.0 34 | nll = 0.0 35 | reward = 0.0 36 | cer_reward = 0.0 37 | nll_reward = 0.0 38 | try: 39 | async with aiohttp.ClientSession() as session: 40 | async with session.post( 41 | SCORE_URL, json=payload, timeout=300 42 | ) as response: 43 | if response.status == 200: 44 | result = await response.json() 45 | transcribed_text = result.get("transcribed_text", "") 46 | language = result.get("language", "") 47 | cer = result.get("cer", 0.0) 48 | nll = result.get("nll", 0.0) 49 | reward = result.get("reward", 0.0) 50 | cer_reward = result.get("cer_reward", 0.0) 51 | nll_reward = result.get("nll_reward", 0.0) 52 | else: 53 | error_text = await response.text() 54 | print( 55 | f"Error in combined endpoint: {response.status} - {error_text}" 56 | ) 57 | except Exception as e: 58 | print(f"Exception in combined endpoint request: {e}") 59 | 60 | # print( 61 | # "-" * 20, 62 | # f"\nExpected Answer:\n{expected_answer}", 63 | # f"\nTranscribed:\n{transcribed_text}", 64 | # f"\nCER: {cer}, Reward: {reward}", 65 | # f"\nCER Reward: {result.get('cer_reward', None)}, NLL Reward: {result.get('nll_reward', None)}", 66 | # ) 67 | # return reward tuple, first is CER reward, second is NLL reward, third is harmonic mean reward 68 | return cer_reward, nll_reward, reward, cer, nll, language 69 | 70 | 71 | 72 | async def wer_reward_func_async( 73 | sampling_rate, speech_tokens_list: list[str], answers: list[str] 74 | ) -> list[float]: 75 | """ 76 | Async version of the reward function that processes all samples in 77 | parallel using the combined endpoint. 78 | """ 79 | tasks = [ 80 | process_audio_sample(sampling_rate, speech_tokens, answer) 81 | for speech_tokens, answer in zip(speech_tokens_list, answers) 82 | ] 83 | rewards = await asyncio.gather(*tasks) 84 | return rewards 85 | 86 | 87 | def wer_reward_func(sampling_rate, completions, answer, **kwargs) -> list[float]: 88 | """ 89 | Synchronous interface for the async reward function. 90 | Processes all transcription requests in parallel using the combined endpoint. 91 | Expects the completions to be a list where each element is a list/dict 92 | that contains the speech token string in completion[0]['content']. 93 | """ 94 | speech_tokens_list = completions 95 | return asyncio.run(wer_reward_func_async(sampling_rate, speech_tokens_list, answer)) 96 | 97 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/utils/generate_data_statistics.py: -------------------------------------------------------------------------------- 1 | r""" 2 | The file creates a pickle file where the values needed for loading of dataset is stored and the model can load it 3 | when needed. 
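Illustrative invocation (an assumption, not a documented command; the flags come from
the argparse options defined further down, and the config name must exist under
configs/data):

    python matcha/utils/generate_data_statistics.py -i vctk.yaml -b 256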
4 | 5 | Parameters from hparam.py will be used 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | import sys 11 | from pathlib import Path 12 | 13 | import rootutils 14 | import torch 15 | from hydra import compose, initialize 16 | from omegaconf import open_dict 17 | from tqdm.auto import tqdm 18 | 19 | from matcha.data.text_mel_datamodule import TextMelDataModule 20 | from matcha.utils.logging_utils import pylogger 21 | 22 | log = pylogger.get_pylogger(__name__) 23 | 24 | 25 | def compute_data_statistics(data_loader: torch.utils.data.DataLoader, out_channels: int): 26 | """Generate data mean and standard deviation helpful in data normalisation 27 | 28 | Args: 29 | data_loader (torch.utils.data.Dataloader): _description_ 30 | out_channels (int): mel spectrogram channels 31 | """ 32 | total_mel_sum = 0 33 | total_mel_sq_sum = 0 34 | total_mel_len = 0 35 | 36 | for batch in tqdm(data_loader, leave=False): 37 | mels = batch["y"] 38 | mel_lengths = batch["y_lengths"] 39 | 40 | total_mel_len += torch.sum(mel_lengths) 41 | total_mel_sum += torch.sum(mels) 42 | total_mel_sq_sum += torch.sum(torch.pow(mels, 2)) 43 | 44 | data_mean = total_mel_sum / (total_mel_len * out_channels) 45 | data_std = torch.sqrt((total_mel_sq_sum / (total_mel_len * out_channels)) - torch.pow(data_mean, 2)) 46 | 47 | return {"mel_mean": data_mean.item(), "mel_std": data_std.item()} 48 | 49 | 50 | def main(): 51 | parser = argparse.ArgumentParser() 52 | 53 | parser.add_argument( 54 | "-i", 55 | "--input-config", 56 | type=str, 57 | default="vctk.yaml", 58 | help="The name of the yaml config file under configs/data", 59 | ) 60 | 61 | parser.add_argument( 62 | "-b", 63 | "--batch-size", 64 | type=int, 65 | default="256", 66 | help="Can have increased batch size for faster computation", 67 | ) 68 | 69 | parser.add_argument( 70 | "-f", 71 | "--force", 72 | action="store_true", 73 | default=False, 74 | required=False, 75 | help="force overwrite the file", 76 | ) 77 | args = parser.parse_args() 78 | output_file = Path(args.input_config).with_suffix(".json") 79 | 80 | if os.path.exists(output_file) and not args.force: 81 | print("File already exists. Use -f to force overwrite") 82 | sys.exit(1) 83 | 84 | with initialize(version_base="1.3", config_path="../../configs/data"): 85 | cfg = compose(config_name=args.input_config, return_hydra_config=True, overrides=[]) 86 | 87 | root_path = rootutils.find_root(search_from=__file__, indicator=".project-root") 88 | 89 | with open_dict(cfg): 90 | del cfg["hydra"] 91 | del cfg["_target_"] 92 | cfg["data_statistics"] = None 93 | cfg["seed"] = 1234 94 | cfg["batch_size"] = args.batch_size 95 | cfg["train_filelist_path"] = str(os.path.join(root_path, cfg["train_filelist_path"])) 96 | cfg["valid_filelist_path"] = str(os.path.join(root_path, cfg["valid_filelist_path"])) 97 | 98 | text_mel_datamodule = TextMelDataModule(**cfg) 99 | text_mel_datamodule.setup() 100 | data_loader = text_mel_datamodule.train_dataloader() 101 | log.info("Dataloader loaded! 
Now computing stats...") 102 | params = compute_data_statistics(data_loader, cfg["n_feats"]) 103 | print(params) 104 | json.dump( 105 | params, 106 | open(output_file, "w"), 107 | ) 108 | 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /tools/parse_options.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright 2012 Johns Hopkins University (Author: Daniel Povey); 4 | # Arnab Ghoshal, Karel Vesely 5 | 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 13 | # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 14 | # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 15 | # MERCHANTABLITY OR NON-INFRINGEMENT. 16 | # See the Apache 2 License for the specific language governing permissions and 17 | # limitations under the License. 18 | 19 | 20 | # Parse command-line options. 21 | # To be sourced by another script (as in ". parse_options.sh"). 22 | # Option format is: --option-name arg 23 | # and shell variable "option_name" gets set to value "arg." 24 | # The exception is --help, which takes no arguments, but prints the 25 | # $help_message variable (if defined). 26 | 27 | 28 | ### 29 | ### The --config file options have lower priority to command line 30 | ### options, so we need to import them first... 31 | ### 32 | 33 | # Now import all the configs specified by command-line, in left-to-right order 34 | for ((argpos=1; argpos<$#; argpos++)); do 35 | if [ "${!argpos}" == "--config" ]; then 36 | argpos_plus1=$((argpos+1)) 37 | config=${!argpos_plus1} 38 | [ ! -r $config ] && echo "$0: missing config '$config'" && exit 1 39 | . $config # source the config file. 40 | fi 41 | done 42 | 43 | 44 | ### 45 | ### Now we process the command line options 46 | ### 47 | while true; do 48 | [ -z "${1:-}" ] && break; # break if there are no arguments 49 | case "$1" in 50 | # If the enclosing script is called with --help option, print the help 51 | # message and exit. Scripts should put help messages in $help_message 52 | --help|-h) if [ -z "$help_message" ]; then echo "No help found." 1>&2; 53 | else printf "$help_message\n" 1>&2 ; fi; 54 | exit 0 ;; 55 | --*=*) echo "$0: options to scripts must be of the form --name value, got '$1'" 56 | exit 1 ;; 57 | # If the first command-line argument begins with "--" (e.g. --foo-bar), 58 | # then work out the variable name as $name, which will equal "foo_bar". 59 | --*) name=`echo "$1" | sed s/^--// | sed s/-/_/g`; 60 | # Next we test whether the variable in question is undefned-- if so it's 61 | # an invalid option and we die. Note: $0 evaluates to the name of the 62 | # enclosing script. 63 | # The test [ -z ${foo_bar+xxx} ] will return true if the variable foo_bar 64 | # is undefined. We then have to wrap this test inside "eval" because 65 | # foo_bar is itself inside a variable ($name). 66 | eval '[ -z "${'$name'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1; 67 | 68 | oldval="`eval echo \\$$name`"; 69 | # Work out whether we seem to be expecting a Boolean argument. 
70 | if [ "$oldval" == "true" ] || [ "$oldval" == "false" ]; then 71 | was_bool=true; 72 | else 73 | was_bool=false; 74 | fi 75 | 76 | # Set the variable to the right value-- the escaped quotes make it work if 77 | # the option had spaces, like --cmd "queue.pl -sync y" 78 | eval $name=\"$2\"; 79 | 80 | # Check that Boolean-valued arguments are really Boolean. 81 | if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then 82 | echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2 83 | exit 1; 84 | fi 85 | shift 2; 86 | ;; 87 | *) break; 88 | esac 89 | done 90 | 91 | 92 | # Check for an empty argument to the --cmd option, which can easily occur as a 93 | # result of scripting errors. 94 | [ ! -z "${cmd+xxx}" ] && [ -z "$cmd" ] && echo "$0: empty argument to --cmd option" 1>&2 && exit 1; 95 | 96 | 97 | true; # so this script returns exit code 0. 98 | -------------------------------------------------------------------------------- /runtime/python/fastapi/client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import argparse 15 | import logging 16 | import requests 17 | import torch 18 | import torchaudio 19 | import numpy as np 20 | 21 | 22 | def main(): 23 | url = "http://{}:{}/inference_{}".format(args.host, args.port, args.mode) 24 | if args.mode == 'sft': 25 | payload = { 26 | 'tts_text': args.tts_text, 27 | 'spk_id': args.spk_id 28 | } 29 | response = requests.request("GET", url, data=payload, stream=True) 30 | elif args.mode == 'zero_shot': 31 | payload = { 32 | 'tts_text': args.tts_text, 33 | 'prompt_text': args.prompt_text 34 | } 35 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] 36 | response = requests.request("GET", url, data=payload, files=files, stream=True) 37 | elif args.mode == 'cross_lingual': 38 | payload = { 39 | 'tts_text': args.tts_text, 40 | } 41 | files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))] 42 | response = requests.request("GET", url, data=payload, files=files, stream=True) 43 | else: 44 | payload = { 45 | 'tts_text': args.tts_text, 46 | 'spk_id': args.spk_id, 47 | 'instruct_text': args.instruct_text 48 | } 49 | response = requests.request("GET", url, data=payload, stream=True) 50 | tts_audio = b'' 51 | for r in response.iter_content(chunk_size=16000): 52 | tts_audio += r 53 | tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) 54 | logging.info('save response to {}'.format(args.tts_wav)) 55 | torchaudio.save(args.tts_wav, tts_speech, target_sr) 56 | logging.info('get response') 57 | 58 | 59 | if __name__ == "__main__": 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument('--host', 62 | type=str, 63 | default='0.0.0.0') 64 | parser.add_argument('--port', 65 | type=int, 66 | default='50000') 67 | parser.add_argument('--mode', 
68 | default='sft', 69 | choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], 70 | help='request mode') 71 | parser.add_argument('--tts_text', 72 | type=str, 73 | default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?') 74 | parser.add_argument('--spk_id', 75 | type=str, 76 | default='中文女') 77 | parser.add_argument('--prompt_text', 78 | type=str, 79 | default='希望你以后能够做的比我还好呦。') 80 | parser.add_argument('--prompt_wav', 81 | type=str, 82 | default='../../../asset/zero_shot_prompt.wav') 83 | parser.add_argument('--instruct_text', 84 | type=str, 85 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \ 86 | Fights with fervor for justice, but struggles with impulsiveness.') 87 | parser.add_argument('--tts_wav', 88 | type=str, 89 | default='demo.wav') 90 | args = parser.parse_args() 91 | prompt_sr, target_sr = 16000, 22050 92 | main() 93 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/text/cleaners.py: -------------------------------------------------------------------------------- 1 | """ from https://github.com/keithito/tacotron 2 | 3 | Cleaners are transformations that run over the input text at both training and eval time. 4 | 5 | Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" 6 | hyperparameter. Some cleaners are English-specific. You'll typically want to use: 7 | 1. "english_cleaners" for English text 8 | 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using 9 | the Unidecode library (https://pypi.python.org/pypi/Unidecode) 10 | 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update 11 | the symbols in symbols.py to match your data). 12 | """ 13 | 14 | import logging 15 | import re 16 | 17 | import phonemizer 18 | import piper_phonemize 19 | from unidecode import unidecode 20 | 21 | # To avoid excessive logging we set the log level of the phonemizer package to Critical 22 | critical_logger = logging.getLogger("phonemizer") 23 | critical_logger.setLevel(logging.CRITICAL) 24 | 25 | # Initializing the phonemizer globally significantly speeds things up, 26 | # since the phonemizer is no longer initialised at every call. 27 | # Might be less flexible, but it is much, much faster 28 | global_phonemizer = phonemizer.backend.EspeakBackend( 29 | language="en-us", 30 | preserve_punctuation=True, 31 | with_stress=True, 32 | language_switch="remove-flags", 33 | logger=critical_logger, 34 | ) 35 | 36 | 37 | # Regular expression matching whitespace: 38 | _whitespace_re = re.compile(r"\s+") 39 | 40 | # List of (regular expression, replacement) pairs for abbreviations: 41 | _abbreviations = [ 42 | (re.compile("\\b%s\\."
% x[0], re.IGNORECASE), x[1]) 43 | for x in [ 44 | ("mrs", "misess"), 45 | ("mr", "mister"), 46 | ("dr", "doctor"), 47 | ("st", "saint"), 48 | ("co", "company"), 49 | ("jr", "junior"), 50 | ("maj", "major"), 51 | ("gen", "general"), 52 | ("drs", "doctors"), 53 | ("rev", "reverend"), 54 | ("lt", "lieutenant"), 55 | ("hon", "honorable"), 56 | ("sgt", "sergeant"), 57 | ("capt", "captain"), 58 | ("esq", "esquire"), 59 | ("ltd", "limited"), 60 | ("col", "colonel"), 61 | ("ft", "fort"), 62 | ] 63 | ] 64 | 65 | 66 | def expand_abbreviations(text): 67 | for regex, replacement in _abbreviations: 68 | text = re.sub(regex, replacement, text) 69 | return text 70 | 71 | 72 | def lowercase(text): 73 | return text.lower() 74 | 75 | 76 | def collapse_whitespace(text): 77 | return re.sub(_whitespace_re, " ", text) 78 | 79 | 80 | def convert_to_ascii(text): 81 | return unidecode(text) 82 | 83 | 84 | def basic_cleaners(text): 85 | """Basic pipeline that lowercases and collapses whitespace without transliteration.""" 86 | text = lowercase(text) 87 | text = collapse_whitespace(text) 88 | return text 89 | 90 | 91 | def transliteration_cleaners(text): 92 | """Pipeline for non-English text that transliterates to ASCII.""" 93 | text = convert_to_ascii(text) 94 | text = lowercase(text) 95 | text = collapse_whitespace(text) 96 | return text 97 | 98 | 99 | def english_cleaners2(text): 100 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 101 | text = convert_to_ascii(text) 102 | text = lowercase(text) 103 | text = expand_abbreviations(text) 104 | phonemes = global_phonemizer.phonemize([text], strip=True, njobs=1)[0] 105 | phonemes = collapse_whitespace(phonemes) 106 | return phonemes 107 | 108 | 109 | def english_cleaners_piper(text): 110 | """Pipeline for English text, including abbreviation expansion. + punctuation + stress""" 111 | text = convert_to_ascii(text) 112 | text = lowercase(text) 113 | text = expand_abbreviations(text) 114 | phonemes = "".join(piper_phonemize.phonemize_espeak(text=text, voice="en-US")[0]) 115 | phonemes = collapse_whitespace(phonemes) 116 | return phonemes 117 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_jit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
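# Illustrative usage of this exporter (an assumed invocation, not project documentation;
# --model_dir simply echoes the argparse default below and must point to a downloaded model):
#
#   python cosyvoice/bin/export_jit.py --model_dir pretrained_models/CosyVoice-300M
#
# For a CosyVoice model this writes llm.text_encoder.*.zip, llm.llm.*.zip and flow.encoder.*.zip
# TorchScript archives into --model_dir; for CosyVoice2 only the flow encoder is exported,
# as the branches in main() below show.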
14 | 15 | from __future__ import print_function 16 | 17 | import argparse 18 | import logging 19 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 20 | import os 21 | import sys 22 | import torch 23 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 24 | sys.path.append('{}/../..'.format(ROOT_DIR)) 25 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 26 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 27 | from cosyvoice.utils.file_utils import logging 28 | 29 | 30 | def get_args(): 31 | parser = argparse.ArgumentParser(description='export your model for deployment') 32 | parser.add_argument('--model_dir', 33 | type=str, 34 | default='pretrained_models/CosyVoice-300M', 35 | help='local path') 36 | args = parser.parse_args() 37 | print(args) 38 | return args 39 | 40 | 41 | def get_optimized_script(model, preserved_attrs=[]): 42 | script = torch.jit.script(model) 43 | if preserved_attrs != []: 44 | script = torch.jit.freeze(script, preserved_attrs=preserved_attrs) 45 | else: 46 | script = torch.jit.freeze(script) 47 | script = torch.jit.optimize_for_inference(script) 48 | return script 49 | 50 | 51 | def main(): 52 | args = get_args() 53 | logging.basicConfig(level=logging.DEBUG, 54 | format='%(asctime)s %(levelname)s %(message)s') 55 | 56 | torch._C._jit_set_fusion_strategy([('STATIC', 1)]) 57 | torch._C._jit_set_profiling_mode(False) 58 | torch._C._jit_set_profiling_executor(False) 59 | 60 | try: 61 | model = CosyVoice(args.model_dir) 62 | except Exception: 63 | try: 64 | model = CosyVoice2(args.model_dir) 65 | except Exception: 66 | raise TypeError('no valid model_type!') 67 | 68 | if not isinstance(model, CosyVoice2): 69 | # 1. export llm text_encoder 70 | llm_text_encoder = model.model.llm.text_encoder 71 | script = get_optimized_script(llm_text_encoder) 72 | script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir)) 73 | script = get_optimized_script(llm_text_encoder.half()) 74 | script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir)) 75 | logging.info('successfully export llm_text_encoder') 76 | 77 | # 2. export llm llm 78 | llm_llm = model.model.llm.llm 79 | script = get_optimized_script(llm_llm, ['forward_chunk']) 80 | script.save('{}/llm.llm.fp32.zip'.format(args.model_dir)) 81 | script = get_optimized_script(llm_llm.half(), ['forward_chunk']) 82 | script.save('{}/llm.llm.fp16.zip'.format(args.model_dir)) 83 | logging.info('successfully export llm_llm') 84 | 85 | # 3. export flow encoder 86 | flow_encoder = model.model.flow.encoder 87 | script = get_optimized_script(flow_encoder) 88 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 89 | script = get_optimized_script(flow_encoder.half()) 90 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) 91 | logging.info('successfully export flow_encoder') 92 | else: 93 | # 3. 
export flow encoder 94 | flow_encoder = model.model.flow.encoder 95 | script = get_optimized_script(flow_encoder) 96 | script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir)) 97 | script = get_optimized_script(flow_encoder.half()) 98 | script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir)) 99 | logging.info('successfully export flow_encoder') 100 | 101 | 102 | if __name__ == '__main__': 103 | main() 104 | -------------------------------------------------------------------------------- /runtime/python/fastapi/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | import argparse 17 | import logging 18 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 19 | from fastapi import FastAPI, UploadFile, Form, File 20 | from fastapi.responses import StreamingResponse 21 | from fastapi.middleware.cors import CORSMiddleware 22 | import uvicorn 23 | import numpy as np 24 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 25 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 26 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 27 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 28 | from cosyvoice.utils.file_utils import load_wav 29 | 30 | app = FastAPI() 31 | # set cross region allowance 32 | app.add_middleware( 33 | CORSMiddleware, 34 | allow_origins=["*"], 35 | allow_credentials=True, 36 | allow_methods=["*"], 37 | allow_headers=["*"]) 38 | 39 | 40 | def generate_data(model_output): 41 | for i in model_output: 42 | tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() 43 | yield tts_audio 44 | 45 | 46 | @app.get("/inference_sft") 47 | @app.post("/inference_sft") 48 | async def inference_sft(tts_text: str = Form(), spk_id: str = Form()): 49 | model_output = cosyvoice.inference_sft(tts_text, spk_id) 50 | return StreamingResponse(generate_data(model_output)) 51 | 52 | 53 | @app.get("/inference_zero_shot") 54 | @app.post("/inference_zero_shot") 55 | async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()): 56 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 57 | model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k) 58 | return StreamingResponse(generate_data(model_output)) 59 | 60 | 61 | @app.get("/inference_cross_lingual") 62 | @app.post("/inference_cross_lingual") 63 | async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()): 64 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 65 | model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k) 66 | return StreamingResponse(generate_data(model_output)) 67 | 68 | 69 | @app.get("/inference_instruct") 70 | @app.post("/inference_instruct") 71 | async def inference_instruct(tts_text: str = Form(), spk_id: str = 
Form(), instruct_text: str = Form()): 72 | model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text) 73 | return StreamingResponse(generate_data(model_output)) 74 | 75 | 76 | @app.get("/inference_instruct2") 77 | @app.post("/inference_instruct2") 78 | async def inference_instruct2(tts_text: str = Form(), instruct_text: str = Form(), prompt_wav: UploadFile = File()): 79 | prompt_speech_16k = load_wav(prompt_wav.file, 16000) 80 | model_output = cosyvoice.inference_instruct2(tts_text, instruct_text, prompt_speech_16k) 81 | return StreamingResponse(generate_data(model_output)) 82 | 83 | 84 | if __name__ == '__main__': 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--port', 87 | type=int, 88 | default=50000) 89 | parser.add_argument('--model_dir', 90 | type=str, 91 | default='iic/CosyVoice-300M', 92 | help='local path or modelscope repo id') 93 | args = parser.parse_args() 94 | try: 95 | cosyvoice = CosyVoice(args.model_dir) 96 | except Exception: 97 | try: 98 | cosyvoice = CosyVoice2(args.model_dir) 99 | except Exception: 100 | raise TypeError('no valid model_type!') 101 | uvicorn.run(app, host="0.0.0.0", port=args.port) 102 | -------------------------------------------------------------------------------- /cosyvoice/vllm/cosyvoice2.py: -------------------------------------------------------------------------------- 1 | # SPDX-License-Identifier: Apache-2.0 2 | 3 | # Adapted from 4 | # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py 5 | # Copyright 2024 The Qwen team. 6 | # Copyright 2023 The vLLM team. 7 | # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. 8 | # 9 | # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX 10 | # and OPT implementations in this library. It has been modified from its 11 | # original forms to accommodate minor architectural differences compared 12 | # to GPT-NeoX and OPT used by the Meta AI team that trained the model. 13 | # 14 | # Licensed under the Apache License, Version 2.0 (the "License"); 15 | # you may not use this file except in compliance with the License. 16 | # You may obtain a copy of the License at 17 | # 18 | # http://www.apache.org/licenses/LICENSE-2.0 19 | # 20 | # Unless required by applicable law or agreed to in writing, software 21 | # distributed under the License is distributed on an "AS IS" BASIS, 22 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 23 | # See the License for the specific language governing permissions and 24 | # limitations under the License. 
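# A minimal sketch of how the class defined below is typically exposed to vLLM as an
# out-of-tree architecture (an assumption about deployment, not something this file does
# on its own):
#
#   from vllm import ModelRegistry
#   from cosyvoice.vllm.cosyvoice2 import CosyVoice2ForCausalLM
#   ModelRegistry.register_model("CosyVoice2ForCausalLM", CosyVoice2ForCausalLM)
#
# Once registered, vLLM can serve the CosyVoice2 speech LLM weights through this class.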
25 | """Inference-only Qwen2 model compatible with HuggingFace weights.""" 26 | from vllm.model_executor.models.qwen2 import * 27 | 28 | 29 | class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): 30 | packed_modules_mapping = { 31 | "qkv_proj": [ 32 | "q_proj", 33 | "k_proj", 34 | "v_proj", 35 | ], 36 | "gate_up_proj": [ 37 | "gate_proj", 38 | "up_proj", 39 | ], 40 | } 41 | 42 | def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): 43 | super().__init__() 44 | config = vllm_config.model_config.hf_config 45 | quant_config = vllm_config.quant_config 46 | lora_config = vllm_config.lora_config 47 | 48 | self.config = config 49 | self.lora_config = lora_config 50 | 51 | self.quant_config = quant_config 52 | self.model = Qwen2Model(vllm_config=vllm_config, 53 | prefix=maybe_prefix(prefix, "model")) 54 | 55 | if get_pp_group().is_last_rank: 56 | if config.tie_word_embeddings: 57 | self.lm_head = self.model.embed_tokens 58 | else: 59 | self.lm_head = ParallelLMHead(config.vocab_size, 60 | config.hidden_size, 61 | True, 62 | quant_config=quant_config, 63 | prefix=maybe_prefix( 64 | prefix, "lm_head")) 65 | else: 66 | self.lm_head = PPMissingLayer() 67 | 68 | self.logits_processor = LogitsProcessor(config.vocab_size) 69 | 70 | self.make_empty_intermediate_tensors = ( 71 | self.model.make_empty_intermediate_tensors) 72 | 73 | def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: 74 | return self.model.get_input_embeddings(input_ids) 75 | 76 | def forward( 77 | self, 78 | input_ids: torch.Tensor, 79 | positions: torch.Tensor, 80 | intermediate_tensors: Optional[IntermediateTensors] = None, 81 | inputs_embeds: Optional[torch.Tensor] = None, 82 | ) -> Union[torch.Tensor, IntermediateTensors]: 83 | hidden_states = self.model(input_ids, positions, intermediate_tensors, 84 | inputs_embeds) 85 | return hidden_states 86 | 87 | def compute_logits( 88 | self, 89 | hidden_states: torch.Tensor, 90 | sampling_metadata: SamplingMetadata, 91 | ) -> Optional[torch.Tensor]: 92 | logits = self.logits_processor(self.lm_head, hidden_states, 93 | sampling_metadata, self.lm_head.bias) 94 | return logits 95 | 96 | def load_weights(self, weights: Iterable[tuple[str, 97 | torch.Tensor]]) -> set[str]: 98 | loader = AutoWeightsLoader( 99 | self, 100 | skip_prefixes=(["lm_head."] 101 | if self.config.tie_word_embeddings else None), 102 | ) 103 | return loader.load_weights(weights) 104 | -------------------------------------------------------------------------------- /cosyvoice/transformer/positionwise_feed_forward.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Positionwise feed forward layer definition.""" 16 | 17 | import torch 18 | 19 | 20 | class PositionwiseFeedForward(torch.nn.Module): 21 | """Positionwise feed forward layer. 
22 | 23 | FeedForward is applied at each position of the sequence. 24 | The output dim is the same as the input dim. 25 | 26 | Args: 27 | idim (int): Input dimension. 28 | hidden_units (int): The number of hidden units. 29 | dropout_rate (float): Dropout rate. 30 | activation (torch.nn.Module): Activation function 31 | """ 32 | 33 | def __init__( 34 | self, 35 | idim: int, 36 | hidden_units: int, 37 | dropout_rate: float, 38 | activation: torch.nn.Module = torch.nn.ReLU(), 39 | ): 40 | """Construct a PositionwiseFeedForward object.""" 41 | super(PositionwiseFeedForward, self).__init__() 42 | self.w_1 = torch.nn.Linear(idim, hidden_units) 43 | self.activation = activation 44 | self.dropout = torch.nn.Dropout(dropout_rate) 45 | self.w_2 = torch.nn.Linear(hidden_units, idim) 46 | 47 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 48 | """Forward function. 49 | 50 | Args: 51 | xs: input tensor (B, L, D) 52 | Returns: 53 | output tensor, (B, L, D) 54 | """ 55 | return self.w_2(self.dropout(self.activation(self.w_1(xs)))) 56 | 57 | 58 | class MoEFFNLayer(torch.nn.Module): 59 | """ 60 | Mixture of experts with Positionwise feed forward layer 61 | See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf 62 | The output dim is the same as the input dim. 63 | 64 | Modified from https://github.com/Lightning-AI/lit-gpt/pull/823 65 | https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 66 | Args: 67 | n_expert: number of experts. 68 | n_expert_per_token: The actual number of experts used for each frame 69 | idim (int): Input dimension. 70 | hidden_units (int): The number of hidden units. 71 | dropout_rate (float): Dropout rate. 72 | activation (torch.nn.Module): Activation function 73 | """ 74 | 75 | def __init__( 76 | self, 77 | n_expert: int, 78 | n_expert_per_token: int, 79 | idim: int, 80 | hidden_units: int, 81 | dropout_rate: float, 82 | activation: torch.nn.Module = torch.nn.ReLU(), 83 | ): 84 | super(MoEFFNLayer, self).__init__() 85 | self.gate = torch.nn.Linear(idim, n_expert, bias=False) 86 | self.experts = torch.nn.ModuleList( 87 | PositionwiseFeedForward(idim, hidden_units, dropout_rate, 88 | activation) for _ in range(n_expert)) 89 | self.n_expert_per_token = n_expert_per_token 90 | 91 | def forward(self, xs: torch.Tensor) -> torch.Tensor: 92 | """Forward function.
93 | Args: 94 | xs: input tensor (B, L, D) 95 | Returns: 96 | output tensor, (B, L, D) 97 | 98 | """ 99 | B, L, D = xs.size( 100 | ) # batch size, sequence length, embedding dimension (idim) 101 | xs = xs.view(-1, D) # (B*L, D) 102 | router = self.gate(xs) # (B*L, n_expert) 103 | logits, indices = torch.topk( 104 | router, self.n_expert_per_token 105 | ) # probs:(B*L, n_expert), indices: (B*L, n_expert) 106 | weights = torch.nn.functional.softmax( 107 | logits, dim=1, 108 | dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token) 109 | output = torch.zeros_like(xs) # (B*L, D) 110 | for i, expert in enumerate(self.experts): 111 | mask = indices == i 112 | batch_idx, ith_expert = torch.where(mask) 113 | output[batch_idx] += weights[batch_idx, ith_expert, None] * expert( 114 | xs[batch_idx]) 115 | return output.view(B, L, D) 116 | -------------------------------------------------------------------------------- /runtime/python/grpc/server.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | from concurrent import futures 17 | import argparse 18 | import cosyvoice_pb2 19 | import cosyvoice_pb2_grpc 20 | import logging 21 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 22 | import grpc 23 | import torch 24 | import numpy as np 25 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 26 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 27 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 28 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 29 | 30 | logging.basicConfig(level=logging.DEBUG, 31 | format='%(asctime)s %(levelname)s %(message)s') 32 | 33 | 34 | class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer): 35 | def __init__(self, args): 36 | try: 37 | self.cosyvoice = CosyVoice(args.model_dir, trt_concurrent=args.max_conc) 38 | except Exception: 39 | try: 40 | self.cosyvoice = CosyVoice2(args.model_dir, trt_concurrent=args.max_conc) 41 | except Exception: 42 | raise TypeError('no valid model_type!') 43 | logging.info('grpc service initialized') 44 | 45 | def Inference(self, request, context): 46 | if request.HasField('sft_request'): 47 | logging.info('get sft inference request') 48 | model_output = self.cosyvoice.inference_sft(request.sft_request.tts_text, request.sft_request.spk_id) 49 | elif request.HasField('zero_shot_request'): 50 | logging.info('get zero_shot inference request') 51 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) 52 | prompt_speech_16k = prompt_speech_16k.float() / (2**15) 53 | model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, 54 | request.zero_shot_request.prompt_text, 55 | prompt_speech_16k) 56 | elif request.HasField('cross_lingual_request'): 57 | logging.info('get 
cross_lingual inference request') 58 | prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0) 59 | prompt_speech_16k = prompt_speech_16k.float() / (2**15) 60 | model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k) 61 | else: 62 | logging.info('get instruct inference request') 63 | model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, 64 | request.instruct_request.spk_id, 65 | request.instruct_request.instruct_text) 66 | 67 | logging.info('send inference response') 68 | for i in model_output: 69 | response = cosyvoice_pb2.Response() 70 | response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes() 71 | yield response 72 | 73 | 74 | def main(): 75 | grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc) 76 | cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer) 77 | grpcServer.add_insecure_port('0.0.0.0:{}'.format(args.port)) 78 | grpcServer.start() 79 | logging.info("server listening on 0.0.0.0:{}".format(args.port)) 80 | grpcServer.wait_for_termination() 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--port', 86 | type=int, 87 | default=50000) 88 | parser.add_argument('--max_conc', 89 | type=int, 90 | default=4) 91 | parser.add_argument('--model_dir', 92 | type=str, 93 | default='iic/CosyVoice-300M', 94 | help='local path or modelscope repo id') 95 | args = parser.parse_args() 96 | main() 97 | -------------------------------------------------------------------------------- /cosyvoice/utils/frontend_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
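# A small usage sketch for the helpers defined below (illustrative only; `inflect` is an
# assumed extra dependency and the whitespace tokenizer is just a stand-in):
#
#   import inflect
#   from cosyvoice.utils.frontend_utils import spell_out_number, split_paragraph
#   text = spell_out_number("room 42 is ready", inflect.engine())
#   pieces = split_paragraph(text, tokenize=lambda s: s.split(), lang="en")
#
# spell_out_number expands digits into words and split_paragraph cuts the result into
# punctuation-bounded chunks sized for synthesis.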
14 | 15 | import re 16 | import regex 17 | chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+') 18 | 19 | 20 | # whether contain chinese character 21 | def contains_chinese(text): 22 | return bool(chinese_char_pattern.search(text)) 23 | 24 | 25 | # replace special symbol 26 | def replace_corner_mark(text): 27 | text = text.replace('²', '平方') 28 | text = text.replace('³', '立方') 29 | return text 30 | 31 | 32 | # remove meaningless symbol 33 | def remove_bracket(text): 34 | text = text.replace('(', '').replace(')', '') 35 | text = text.replace('【', '').replace('】', '') 36 | text = text.replace('`', '').replace('`', '') 37 | text = text.replace("——", " ") 38 | return text 39 | 40 | 41 | # spell Arabic numerals 42 | def spell_out_number(text: str, inflect_parser): 43 | new_text = [] 44 | st = None 45 | for i, c in enumerate(text): 46 | if not c.isdigit(): 47 | if st is not None: 48 | num_str = inflect_parser.number_to_words(text[st: i]) 49 | new_text.append(num_str) 50 | st = None 51 | new_text.append(c) 52 | else: 53 | if st is None: 54 | st = i 55 | if st is not None and st < len(text): 56 | num_str = inflect_parser.number_to_words(text[st:]) 57 | new_text.append(num_str) 58 | return ''.join(new_text) 59 | 60 | 61 | # split paragrah logic: 62 | # 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len 63 | # 2. cal sentence len according to lang 64 | # 3. split sentence according to puncatation 65 | def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False): 66 | def calc_utt_length(_text: str): 67 | if lang == "zh": 68 | return len(_text) 69 | else: 70 | return len(tokenize(_text)) 71 | 72 | def should_merge(_text: str): 73 | if lang == "zh": 74 | return len(_text) < merge_len 75 | else: 76 | return len(tokenize(_text)) < merge_len 77 | 78 | if lang == "zh": 79 | pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';'] 80 | else: 81 | pounc = ['.', '?', '!', ';', ':'] 82 | if comma_split: 83 | pounc.extend([',', ',']) 84 | 85 | if text[-1] not in pounc: 86 | if lang == "zh": 87 | text += "。" 88 | else: 89 | text += "." 90 | 91 | st = 0 92 | utts = [] 93 | for i, c in enumerate(text): 94 | if c in pounc: 95 | if len(text[st: i]) > 0: 96 | utts.append(text[st: i] + c) 97 | if i + 1 < len(text) and text[i + 1] in ['"', '”']: 98 | tmp = utts.pop(-1) 99 | utts.append(tmp + text[i + 1]) 100 | st = i + 2 101 | else: 102 | st = i + 1 103 | 104 | final_utts = [] 105 | cur_utt = "" 106 | for utt in utts: 107 | if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n: 108 | final_utts.append(cur_utt) 109 | cur_utt = "" 110 | cur_utt = cur_utt + utt 111 | if len(cur_utt) > 0: 112 | if should_merge(cur_utt) and len(final_utts) != 0: 113 | final_utts[-1] = final_utts[-1] + cur_utt 114 | else: 115 | final_utts.append(cur_utt) 116 | 117 | return final_utts 118 | 119 | 120 | # remove blank between chinese character 121 | def replace_blank(text: str): 122 | out_str = [] 123 | for i, c in enumerate(text): 124 | if c == " ": 125 | if ((text[i + 1].isascii() and text[i + 1] != " ") and 126 | (text[i - 1].isascii() and text[i - 1] != " ")): 127 | out_str.append(c) 128 | else: 129 | out_str.append(c) 130 | return "".join(out_str) 131 | 132 | 133 | def is_only_punctuation(text): 134 | # Regular expression: Match strings that consist only of punctuation marks or are empty. 
135 | punctuation_pattern = r'^[\p{P}\p{S}]*$' 136 | return bool(regex.fullmatch(punctuation_pattern, text)) 137 | -------------------------------------------------------------------------------- /runtime/python/grpc/client.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import os 15 | import sys 16 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 17 | sys.path.append('{}/../../..'.format(ROOT_DIR)) 18 | sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR)) 19 | import logging 20 | import argparse 21 | import torchaudio 22 | import cosyvoice_pb2 23 | import cosyvoice_pb2_grpc 24 | import grpc 25 | import torch 26 | import numpy as np 27 | from cosyvoice.utils.file_utils import load_wav 28 | 29 | 30 | def main(): 31 | with grpc.insecure_channel("{}:{}".format(args.host, args.port)) as channel: 32 | stub = cosyvoice_pb2_grpc.CosyVoiceStub(channel) 33 | request = cosyvoice_pb2.Request() 34 | if args.mode == 'sft': 35 | logging.info('send sft request') 36 | sft_request = cosyvoice_pb2.sftRequest() 37 | sft_request.spk_id = args.spk_id 38 | sft_request.tts_text = args.tts_text 39 | request.sft_request.CopyFrom(sft_request) 40 | elif args.mode == 'zero_shot': 41 | logging.info('send zero_shot request') 42 | zero_shot_request = cosyvoice_pb2.zeroshotRequest() 43 | zero_shot_request.tts_text = args.tts_text 44 | zero_shot_request.prompt_text = args.prompt_text 45 | prompt_speech = load_wav(args.prompt_wav, 16000) 46 | zero_shot_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes() 47 | request.zero_shot_request.CopyFrom(zero_shot_request) 48 | elif args.mode == 'cross_lingual': 49 | logging.info('send cross_lingual request') 50 | cross_lingual_request = cosyvoice_pb2.crosslingualRequest() 51 | cross_lingual_request.tts_text = args.tts_text 52 | prompt_speech = load_wav(args.prompt_wav, 16000) 53 | cross_lingual_request.prompt_audio = (prompt_speech.numpy() * (2**15)).astype(np.int16).tobytes() 54 | request.cross_lingual_request.CopyFrom(cross_lingual_request) 55 | else: 56 | logging.info('send instruct request') 57 | instruct_request = cosyvoice_pb2.instructRequest() 58 | instruct_request.tts_text = args.tts_text 59 | instruct_request.spk_id = args.spk_id 60 | instruct_request.instruct_text = args.instruct_text 61 | request.instruct_request.CopyFrom(instruct_request) 62 | 63 | response = stub.Inference(request) 64 | tts_audio = b'' 65 | for r in response: 66 | tts_audio += r.tts_audio 67 | tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0) 68 | logging.info('save response to {}'.format(args.tts_wav)) 69 | torchaudio.save(args.tts_wav, tts_speech, target_sr) 70 | logging.info('get response') 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | parser.add_argument('--host', 76 | 
type=str, 77 | default='0.0.0.0') 78 | parser.add_argument('--port', 79 | type=int, 80 | default='50000') 81 | parser.add_argument('--mode', 82 | default='sft', 83 | choices=['sft', 'zero_shot', 'cross_lingual', 'instruct'], 84 | help='request mode') 85 | parser.add_argument('--tts_text', 86 | type=str, 87 | default='你好,我是通义千问语音合成大模型,请问有什么可以帮您的吗?') 88 | parser.add_argument('--spk_id', 89 | type=str, 90 | default='中文女') 91 | parser.add_argument('--prompt_text', 92 | type=str, 93 | default='希望你以后能够做的比我还好呦。') 94 | parser.add_argument('--prompt_wav', 95 | type=str, 96 | default='../../../asset/zero_shot_prompt.wav') 97 | parser.add_argument('--instruct_text', 98 | type=str, 99 | default='Theo \'Crimson\', is a fiery, passionate rebel leader. \ 100 | Fights with fervor for justice, but struggles with impulsiveness.') 101 | parser.add_argument('--tts_wav', 102 | type=str, 103 | default='demo.wav') 104 | args = parser.parse_args() 105 | prompt_sr, target_sr = 16000, 22050 106 | main() 107 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/train.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Optional, Tuple 2 | 3 | import hydra 4 | import lightning as L 5 | import rootutils 6 | from lightning import Callback, LightningDataModule, LightningModule, Trainer 7 | from lightning.pytorch.loggers import Logger 8 | from omegaconf import DictConfig 9 | 10 | from matcha import utils 11 | 12 | rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 13 | # ------------------------------------------------------------------------------------ # 14 | # the setup_root above is equivalent to: 15 | # - adding project root dir to PYTHONPATH 16 | # (so you don't need to force user to install project as a package) 17 | # (necessary before importing any local modules e.g. `from src import utils`) 18 | # - setting up PROJECT_ROOT environment variable 19 | # (which is used as a base for paths in "configs/paths/default.yaml") 20 | # (this way all filepaths are the same no matter where you run the code) 21 | # - loading environment variables from ".env" in root dir 22 | # 23 | # you can remove it if you: 24 | # 1. either install project as a package or move entry files to project root dir 25 | # 2. set `root_dir` to "." in "configs/paths/default.yaml" 26 | # 27 | # more info: https://github.com/ashleve/rootutils 28 | # ------------------------------------------------------------------------------------ # 29 | 30 | 31 | log = utils.get_pylogger(__name__) 32 | 33 | 34 | @utils.task_wrapper 35 | def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: 36 | """Trains the model. Can additionally evaluate on a testset, using best weights obtained during 37 | training. 38 | 39 | This method is wrapped in optional @task_wrapper decorator, that controls the behavior during 40 | failure. Useful for multiruns, saving info about the crash, etc. 41 | 42 | :param cfg: A DictConfig configuration composed by Hydra. 43 | :return: A tuple with metrics and dict with all instantiated objects. 
44 | """ 45 | # set seed for random number generators in pytorch, numpy and python.random 46 | if cfg.get("seed"): 47 | L.seed_everything(cfg.seed, workers=True) 48 | 49 | log.info(f"Instantiating datamodule <{cfg.data._target_}>") # pylint: disable=protected-access 50 | datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data) 51 | 52 | log.info(f"Instantiating model <{cfg.model._target_}>") # pylint: disable=protected-access 53 | model: LightningModule = hydra.utils.instantiate(cfg.model) 54 | 55 | log.info("Instantiating callbacks...") 56 | callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks")) 57 | 58 | log.info("Instantiating loggers...") 59 | logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger")) 60 | 61 | log.info(f"Instantiating trainer <{cfg.trainer._target_}>") # pylint: disable=protected-access 62 | trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger) 63 | 64 | object_dict = { 65 | "cfg": cfg, 66 | "datamodule": datamodule, 67 | "model": model, 68 | "callbacks": callbacks, 69 | "logger": logger, 70 | "trainer": trainer, 71 | } 72 | 73 | if logger: 74 | log.info("Logging hyperparameters!") 75 | utils.log_hyperparameters(object_dict) 76 | 77 | if cfg.get("train"): 78 | log.info("Starting training!") 79 | trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) 80 | 81 | train_metrics = trainer.callback_metrics 82 | 83 | if cfg.get("test"): 84 | log.info("Starting testing!") 85 | ckpt_path = trainer.checkpoint_callback.best_model_path 86 | if ckpt_path == "": 87 | log.warning("Best ckpt not found! Using current weights for testing...") 88 | ckpt_path = None 89 | trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) 90 | log.info(f"Best ckpt path: {ckpt_path}") 91 | 92 | test_metrics = trainer.callback_metrics 93 | 94 | # merge train and test metrics 95 | metric_dict = {**train_metrics, **test_metrics} 96 | 97 | return metric_dict, object_dict 98 | 99 | 100 | @hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") 101 | def main(cfg: DictConfig) -> Optional[float]: 102 | """Main entry point for training. 103 | 104 | :param cfg: DictConfig configuration composed by Hydra. 105 | :return: Optional[float] with optimized metric value. 106 | """ 107 | # apply extra utilities 108 | # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) 109 | utils.extras(cfg) 110 | 111 | # train the model 112 | metric_dict, _ = train(cfg) 113 | 114 | # safely retrieve metric value for hydra-based hyperparameter optimization 115 | metric_value = utils.get_metric_value(metric_dict=metric_dict, metric_name=cfg.get("optimized_metric")) 116 | 117 | # return optimized metric 118 | return metric_value 119 | 120 | 121 | if __name__ == "__main__": 122 | main() # pylint: disable=no-value-for-parameter 123 | -------------------------------------------------------------------------------- /cosyvoice/bin/export_onnx.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com) 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from __future__ import print_function 17 | 18 | import argparse 19 | import logging 20 | logging.getLogger('matplotlib').setLevel(logging.WARNING) 21 | import os 22 | import sys 23 | import onnxruntime 24 | import random 25 | import torch 26 | from tqdm import tqdm 27 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 28 | sys.path.append('{}/../..'.format(ROOT_DIR)) 29 | sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR)) 30 | from cosyvoice.cli.cosyvoice import CosyVoice, CosyVoice2 31 | from cosyvoice.utils.file_utils import logging 32 | 33 | 34 | def get_dummy_input(batch_size, seq_len, out_channels, device): 35 | x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 36 | mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device) 37 | mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 38 | t = torch.rand((batch_size), dtype=torch.float32, device=device) 39 | spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device) 40 | cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device) 41 | return x, mask, mu, t, spks, cond 42 | 43 | 44 | def get_args(): 45 | parser = argparse.ArgumentParser(description='export your model for deployment') 46 | parser.add_argument('--model_dir', 47 | type=str, 48 | default='pretrained_models/CosyVoice-300M', 49 | help='local path') 50 | args = parser.parse_args() 51 | print(args) 52 | return args 53 | 54 | 55 | @torch.no_grad() 56 | def main(): 57 | args = get_args() 58 | logging.basicConfig(level=logging.DEBUG, 59 | format='%(asctime)s %(levelname)s %(message)s') 60 | 61 | try: 62 | model = CosyVoice(args.model_dir) 63 | except Exception: 64 | try: 65 | model = CosyVoice2(args.model_dir) 66 | except Exception: 67 | raise TypeError('no valid model_type!') 68 | 69 | # 1. export flow decoder estimator 70 | estimator = model.model.flow.decoder.estimator 71 | estimator.eval() 72 | 73 | device = model.model.device 74 | batch_size, seq_len = 2, 256 75 | out_channels = model.model.flow.decoder.estimator.out_channels 76 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device) 77 | torch.onnx.export( 78 | estimator, 79 | (x, mask, mu, t, spks, cond), 80 | '{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 81 | export_params=True, 82 | opset_version=18, 83 | do_constant_folding=True, 84 | input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'], 85 | output_names=['estimator_out'], 86 | dynamic_axes={ 87 | 'x': {2: 'seq_len'}, 88 | 'mask': {2: 'seq_len'}, 89 | 'mu': {2: 'seq_len'}, 90 | 'cond': {2: 'seq_len'}, 91 | 'estimator_out': {2: 'seq_len'}, 92 | } 93 | ) 94 | 95 | # 2. 
test computation consistency 96 | option = onnxruntime.SessionOptions() 97 | option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 98 | option.intra_op_num_threads = 1 99 | providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider'] 100 | estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), 101 | sess_options=option, providers=providers) 102 | 103 | for _ in tqdm(range(10)): 104 | x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device) 105 | output_pytorch = estimator(x, mask, mu, t, spks, cond) 106 | ort_inputs = { 107 | 'x': x.cpu().numpy(), 108 | 'mask': mask.cpu().numpy(), 109 | 'mu': mu.cpu().numpy(), 110 | 't': t.cpu().numpy(), 111 | 'spks': spks.cpu().numpy(), 112 | 'cond': cond.cpu().numpy() 113 | } 114 | output_onnx = estimator_onnx.run(None, ort_inputs)[0] 115 | torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4) 116 | logging.info('successfully export estimator') 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/models/components/flow_matching.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from matcha.models.components.decoder import Decoder 7 | from matcha.utils.pylogger import get_pylogger 8 | 9 | log = get_pylogger(__name__) 10 | 11 | 12 | class BASECFM(torch.nn.Module, ABC): 13 | def __init__( 14 | self, 15 | n_feats, 16 | cfm_params, 17 | n_spks=1, 18 | spk_emb_dim=128, 19 | ): 20 | super().__init__() 21 | self.n_feats = n_feats 22 | self.n_spks = n_spks 23 | self.spk_emb_dim = spk_emb_dim 24 | self.solver = cfm_params.solver 25 | if hasattr(cfm_params, "sigma_min"): 26 | self.sigma_min = cfm_params.sigma_min 27 | else: 28 | self.sigma_min = 1e-4 29 | 30 | self.estimator = None 31 | 32 | @torch.inference_mode() 33 | def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None): 34 | """Forward diffusion 35 | 36 | Args: 37 | mu (torch.Tensor): output of encoder 38 | shape: (batch_size, n_feats, mel_timesteps) 39 | mask (torch.Tensor): output_mask 40 | shape: (batch_size, 1, mel_timesteps) 41 | n_timesteps (int): number of diffusion steps 42 | temperature (float, optional): temperature for scaling noise. Defaults to 1.0. 43 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 44 | shape: (batch_size, spk_emb_dim) 45 | cond: Not used but kept for future purposes 46 | 47 | Returns: 48 | sample: generated mel-spectrogram 49 | shape: (batch_size, n_feats, mel_timesteps) 50 | """ 51 | z = torch.randn_like(mu) * temperature 52 | t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device) 53 | return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond) 54 | 55 | def solve_euler(self, x, t_span, mu, mask, spks, cond): 56 | """ 57 | Fixed euler solver for ODEs. 58 | Args: 59 | x (torch.Tensor): random noise 60 | t_span (torch.Tensor): n_timesteps interpolated 61 | shape: (n_timesteps + 1,) 62 | mu (torch.Tensor): output of encoder 63 | shape: (batch_size, n_feats, mel_timesteps) 64 | mask (torch.Tensor): output_mask 65 | shape: (batch_size, 1, mel_timesteps) 66 | spks (torch.Tensor, optional): speaker ids. Defaults to None. 
67 | shape: (batch_size, spk_emb_dim) 68 | cond: Not used but kept for future purposes 69 | """ 70 | t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0] 71 | 72 | # I am storing this because I can later plot it by putting a debugger here and saving it to a file 73 | # Or in future might add like a return_all_steps flag 74 | sol = [] 75 | 76 | for step in range(1, len(t_span)): 77 | dphi_dt = self.estimator(x, mask, mu, t, spks, cond) 78 | 79 | x = x + dt * dphi_dt 80 | t = t + dt 81 | sol.append(x) 82 | if step < len(t_span) - 1: 83 | dt = t_span[step + 1] - t 84 | 85 | return sol[-1] 86 | 87 | def compute_loss(self, x1, mask, mu, spks=None, cond=None): 88 | """Computes diffusion loss 89 | 90 | Args: 91 | x1 (torch.Tensor): Target 92 | shape: (batch_size, n_feats, mel_timesteps) 93 | mask (torch.Tensor): target mask 94 | shape: (batch_size, 1, mel_timesteps) 95 | mu (torch.Tensor): output of encoder 96 | shape: (batch_size, n_feats, mel_timesteps) 97 | spks (torch.Tensor, optional): speaker embedding. Defaults to None. 98 | shape: (batch_size, spk_emb_dim) 99 | 100 | Returns: 101 | loss: conditional flow matching loss 102 | y: conditional flow 103 | shape: (batch_size, n_feats, mel_timesteps) 104 | """ 105 | b, _, t = mu.shape 106 | 107 | # random timestep 108 | t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype) 109 | # sample noise p(x_0) 110 | z = torch.randn_like(x1) 111 | 112 | y = (1 - (1 - self.sigma_min) * t) * z + t * x1 113 | u = x1 - (1 - self.sigma_min) * z 114 | 115 | loss = F.mse_loss(self.estimator(y, mask, mu, t.squeeze(), spks), u, reduction="sum") / ( 116 | torch.sum(mask) * u.shape[1] 117 | ) 118 | return loss, y 119 | 120 | 121 | class CFM(BASECFM): 122 | def __init__(self, in_channels, out_channel, cfm_params, decoder_params, n_spks=1, spk_emb_dim=64): 123 | super().__init__( 124 | n_feats=in_channels, 125 | cfm_params=cfm_params, 126 | n_spks=n_spks, 127 | spk_emb_dim=spk_emb_dim, 128 | ) 129 | 130 | in_channels = in_channels + (spk_emb_dim if n_spks > 1 else 0) 131 | # Just change the architecture of the estimator here 132 | self.estimator = Decoder(in_channels=in_channels, out_channels=out_channel, **decoder_params) 133 | -------------------------------------------------------------------------------- /cosyvoice/transformer/decoder_layer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019 Shigeki Karita 2 | # 2020 Mobvoi Inc (Binbin Zhang) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Decoder self-attention layer definition.""" 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | from torch import nn 20 | 21 | 22 | class DecoderLayer(nn.Module): 23 | """Single decoder layer module. 24 | 25 | Args: 26 | size (int): Input dimension. 27 | self_attn (torch.nn.Module): Self-attention module instance. 28 | `MultiHeadedAttention` instance can be used as the argument. 29 | src_attn (torch.nn.Module): Inter-attention module instance. 
30 | `MultiHeadedAttention` instance can be used as the argument. 31 | If `None` is passed, Inter-attention is not used, such as 32 | CIF, GPT, and other decoder only model. 33 | feed_forward (torch.nn.Module): Feed-forward module instance. 34 | `PositionwiseFeedForward` instance can be used as the argument. 35 | dropout_rate (float): Dropout rate. 36 | normalize_before (bool): 37 | True: use layer_norm before each sub-block. 38 | False: to use layer_norm after each sub-block. 39 | """ 40 | 41 | def __init__( 42 | self, 43 | size: int, 44 | self_attn: nn.Module, 45 | src_attn: Optional[nn.Module], 46 | feed_forward: nn.Module, 47 | dropout_rate: float, 48 | normalize_before: bool = True, 49 | ): 50 | """Construct an DecoderLayer object.""" 51 | super().__init__() 52 | self.size = size 53 | self.self_attn = self_attn 54 | self.src_attn = src_attn 55 | self.feed_forward = feed_forward 56 | self.norm1 = nn.LayerNorm(size, eps=1e-5) 57 | self.norm2 = nn.LayerNorm(size, eps=1e-5) 58 | self.norm3 = nn.LayerNorm(size, eps=1e-5) 59 | self.dropout = nn.Dropout(dropout_rate) 60 | self.normalize_before = normalize_before 61 | 62 | def forward( 63 | self, 64 | tgt: torch.Tensor, 65 | tgt_mask: torch.Tensor, 66 | memory: torch.Tensor, 67 | memory_mask: torch.Tensor, 68 | cache: Optional[torch.Tensor] = None 69 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 70 | """Compute decoded features. 71 | 72 | Args: 73 | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). 74 | tgt_mask (torch.Tensor): Mask for input tensor 75 | (#batch, maxlen_out). 76 | memory (torch.Tensor): Encoded memory 77 | (#batch, maxlen_in, size). 78 | memory_mask (torch.Tensor): Encoded memory mask 79 | (#batch, maxlen_in). 80 | cache (torch.Tensor): cached tensors. 81 | (#batch, maxlen_out - 1, size). 82 | 83 | Returns: 84 | torch.Tensor: Output tensor (#batch, maxlen_out, size). 85 | torch.Tensor: Mask for output tensor (#batch, maxlen_out). 86 | torch.Tensor: Encoded memory (#batch, maxlen_in, size). 87 | torch.Tensor: Encoded memory mask (#batch, maxlen_in). 
88 | 89 | """ 90 | residual = tgt 91 | if self.normalize_before: 92 | tgt = self.norm1(tgt) 93 | 94 | if cache is None: 95 | tgt_q = tgt 96 | tgt_q_mask = tgt_mask 97 | else: 98 | # compute only the last frame query keeping dim: max_time_out -> 1 99 | assert cache.shape == ( 100 | tgt.shape[0], 101 | tgt.shape[1] - 1, 102 | self.size, 103 | ), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" 104 | tgt_q = tgt[:, -1:, :] 105 | residual = residual[:, -1:, :] 106 | tgt_q_mask = tgt_mask[:, -1:, :] 107 | 108 | x = residual + self.dropout( 109 | self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0]) 110 | if not self.normalize_before: 111 | x = self.norm1(x) 112 | 113 | if self.src_attn is not None: 114 | residual = x 115 | if self.normalize_before: 116 | x = self.norm2(x) 117 | x = residual + self.dropout( 118 | self.src_attn(x, memory, memory, memory_mask)[0]) 119 | if not self.normalize_before: 120 | x = self.norm2(x) 121 | 122 | residual = x 123 | if self.normalize_before: 124 | x = self.norm3(x) 125 | x = residual + self.dropout(self.feed_forward(x)) 126 | if not self.normalize_before: 127 | x = self.norm3(x) 128 | 129 | if cache is not None: 130 | x = torch.cat([cache, x], dim=1) 131 | 132 | return x, tgt_mask, memory, memory_mask 133 | -------------------------------------------------------------------------------- /cosyvoice/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang) 2 | # 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import random 17 | import math 18 | from functools import partial 19 | 20 | import torch 21 | import torch.distributed as dist 22 | from torch.utils.data import IterableDataset 23 | from cosyvoice.utils.file_utils import read_lists 24 | 25 | 26 | class Processor(IterableDataset): 27 | 28 | def __init__(self, source, f, *args, **kw): 29 | assert callable(f) 30 | self.source = source 31 | self.f = f 32 | self.args = args 33 | self.kw = kw 34 | 35 | def set_epoch(self, epoch): 36 | self.source.set_epoch(epoch) 37 | 38 | def __iter__(self): 39 | """ Return an iterator over the source dataset processed by the 40 | given processor. 
41 | """ 42 | assert self.source is not None 43 | assert callable(self.f) 44 | return self.f(iter(self.source), *self.args, **self.kw) 45 | 46 | def apply(self, f): 47 | assert callable(f) 48 | return Processor(self, f, *self.args, **self.kw) 49 | 50 | 51 | class DistributedSampler: 52 | 53 | def __init__(self, shuffle=True, partition=True): 54 | self.epoch = -1 55 | self.update() 56 | self.shuffle = shuffle 57 | self.partition = partition 58 | 59 | def update(self): 60 | assert dist.is_available() 61 | if dist.is_initialized(): 62 | self.rank = dist.get_rank() 63 | self.world_size = dist.get_world_size() 64 | else: 65 | self.rank = 0 66 | self.world_size = 1 67 | worker_info = torch.utils.data.get_worker_info() 68 | if worker_info is None: 69 | self.worker_id = 0 70 | self.num_workers = 1 71 | else: 72 | self.worker_id = worker_info.id 73 | self.num_workers = worker_info.num_workers 74 | return dict(rank=self.rank, 75 | world_size=self.world_size, 76 | worker_id=self.worker_id, 77 | num_workers=self.num_workers) 78 | 79 | def set_epoch(self, epoch): 80 | self.epoch = epoch 81 | 82 | def sample(self, data): 83 | """ Sample data according to rank/world_size/num_workers 84 | 85 | Args: 86 | data(List): input data list 87 | 88 | Returns: 89 | List: data list after sample 90 | """ 91 | data = list(range(len(data))) 92 | # force datalist even 93 | if self.partition: 94 | if self.shuffle: 95 | random.Random(self.epoch).shuffle(data) 96 | if len(data) < self.world_size: 97 | data = data * math.ceil(self.world_size / len(data)) 98 | data = data[:self.world_size] 99 | data = data[self.rank::self.world_size] 100 | if len(data) < self.num_workers: 101 | data = data * math.ceil(self.num_workers / len(data)) 102 | data = data[:self.num_workers] 103 | data = data[self.worker_id::self.num_workers] 104 | return data 105 | 106 | 107 | class DataList(IterableDataset): 108 | 109 | def __init__(self, lists, shuffle=True, partition=True): 110 | self.lists = lists 111 | self.sampler = DistributedSampler(shuffle, partition) 112 | 113 | def set_epoch(self, epoch): 114 | self.sampler.set_epoch(epoch) 115 | 116 | def __iter__(self): 117 | sampler_info = self.sampler.update() 118 | indexes = self.sampler.sample(self.lists) 119 | for index in indexes: 120 | data = dict(src=self.lists[index]) 121 | data.update(sampler_info) 122 | yield data 123 | 124 | 125 | def Dataset(data_list_file, 126 | data_pipeline, 127 | mode='train', 128 | gan=False, 129 | dpo=False, 130 | shuffle=True, 131 | partition=True): 132 | """ Construct dataset from arguments 133 | 134 | We have two shuffle stage in the Dataset. The first is global 135 | shuffle at shards tar/raw file level. The second is global shuffle 136 | at training samples level. 
137 | 138 | Args: 139 | data_type(str): raw/shard 140 | tokenizer (BaseTokenizer): tokenizer to tokenize 141 | partition(bool): whether to do data partition in terms of rank 142 | """ 143 | lists = read_lists(data_list_file) 144 | dataset = DataList(lists, 145 | shuffle=shuffle, 146 | partition=partition) 147 | # map partial arg to padding func 148 | data_pipeline[-1] = partial(data_pipeline[-1], gan=gan, dpo=dpo) 149 | for func in data_pipeline: 150 | dataset = Processor(dataset, func, mode=mode) 151 | return dataset 152 | -------------------------------------------------------------------------------- /tools/make_parquet_list.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import argparse 16 | import logging 17 | import os 18 | import json 19 | from tqdm import tqdm 20 | import pandas as pd 21 | import multiprocessing 22 | import time 23 | import torch 24 | 25 | 26 | def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file): 27 | start_time = time.time() 28 | data_list = [] 29 | for utt in tqdm(utt_list): 30 | data = open(utt2wav[utt], 'rb').read() 31 | data_list.append(data) 32 | wav_list = [utt2wav[utt] for utt in utt_list] 33 | text_list = [utt2text[utt] for utt in utt_list] 34 | spk_list = [utt2spk[utt] for utt in utt_list] 35 | uttembedding_list = [utt2embedding[utt] for utt in utt_list] 36 | spkembedding_list = [spk2embedding[utt2spk[utt]] for utt in utt_list] 37 | speech_token_list = [utt2speech_token.get(utt, []) for utt in utt_list] 38 | if args.dpo: 39 | reject_speech_token_list = [utt2reject_speech_token[utt] for utt in utt_list] 40 | 41 | # 保存到parquet,utt2parquet_file,spk2parquet_file 42 | df = pd.DataFrame() 43 | df['utt'] = utt_list 44 | df['wav'] = wav_list 45 | df['audio_data'] = data_list 46 | df['text'] = text_list 47 | df['spk'] = spk_list 48 | df['utt_embedding'] = uttembedding_list 49 | df['spk_embedding'] = spkembedding_list 50 | df['speech_token'] = speech_token_list 51 | if args.dpo: 52 | df['reject_speech_token'] = reject_speech_token_list 53 | df.to_parquet(parquet_file) 54 | with open(utt2parquet_file, 'w') as f: 55 | json.dump({k: parquet_file for k in utt_list}, f, ensure_ascii=False, indent=2) 56 | with open(spk2parquet_file, 'w') as f: 57 | json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2) 58 | logging.info('spend time {}'.format(time.time() - start_time)) 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument('--num_utts_per_parquet', 64 | type=int, 65 | default=1000, 66 | help='num utts per parquet') 67 | parser.add_argument('--num_processes', 68 | type=int, 69 | default=1, 70 | help='num processes for make parquets') 71 | parser.add_argument('--src_dir', 72 | type=str) 73 | parser.add_argument('--des_dir', 74 | type=str) 75 | 
parser.add_argument('--dpo', 76 | action='store_true', 77 | default=False, 78 | help='Use Direct Preference Optimization') 79 | args = parser.parse_args() 80 | 81 | utt2wav, utt2text, utt2spk = {}, {}, {} 82 | with open('{}/wav.scp'.format(args.src_dir)) as f: 83 | for l in f: 84 | l = l.replace('\n', '').split() 85 | utt2wav[l[0]] = l[1] 86 | with open('{}/text'.format(args.src_dir)) as f: 87 | for l in f: 88 | l = l.replace('\n', '').split() 89 | utt2text[l[0]] = ' '.join(l[1:]) 90 | with open('{}/utt2spk'.format(args.src_dir)) as f: 91 | for l in f: 92 | l = l.replace('\n', '').split() 93 | utt2spk[l[0]] = l[1] 94 | utt2embedding = torch.load('{}/utt2embedding.pt'.format(args.src_dir)) 95 | spk2embedding = torch.load('{}/spk2embedding.pt'.format(args.src_dir)) 96 | utt2speech_token = torch.load('{}/utt2speech_token.pt'.format(args.src_dir)) 97 | if args.dpo: 98 | utt2reject_speech_token = torch.load('{}_reject/utt2speech_token.pt'.format(args.src_dir)) 99 | utts = list(utt2wav.keys()) 100 | 101 | # Using process pool to speedup 102 | pool = multiprocessing.Pool(processes=args.num_processes) 103 | parquet_list, utt2parquet_list, spk2parquet_list = [], [], [] 104 | for i, j in enumerate(range(0, len(utts), args.num_utts_per_parquet)): 105 | parquet_file = os.path.join(args.des_dir, 'parquet_{:09d}.tar'.format(i)) 106 | utt2parquet_file = os.path.join(args.des_dir, 'utt2parquet_{:09d}.json'.format(i)) 107 | spk2parquet_file = os.path.join(args.des_dir, 'spk2parquet_{:09d}.json'.format(i)) 108 | parquet_list.append(parquet_file) 109 | utt2parquet_list.append(utt2parquet_file) 110 | spk2parquet_list.append(spk2parquet_file) 111 | pool.apply_async(job, (utts[j: j + args.num_utts_per_parquet], parquet_file, utt2parquet_file, spk2parquet_file)) 112 | pool.close() 113 | pool.join() 114 | 115 | with open('{}/data.list'.format(args.des_dir), 'w', encoding='utf8') as f1, \ 116 | open('{}/utt2data.list'.format(args.des_dir), 'w', encoding='utf8') as f2, \ 117 | open('{}/spk2data.list'.format(args.des_dir), 'w', encoding='utf8') as f3: 118 | for name in parquet_list: 119 | f1.write(name + '\n') 120 | for name in utt2parquet_list: 121 | f2.write(name + '\n') 122 | for name in spk2parquet_list: 123 | f3.write(name + '\n') 124 | -------------------------------------------------------------------------------- /cosyvoice/transformer/convolution.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu) 2 | # 2024 Alibaba Inc (Xiang Lyu) 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
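A quick note on the parquet-packing script above (tools/make_parquet_list.py): it expects Kaldi-style metadata plus a few torch-saved dictionaries under `--src_dir`. The sketch below only illustrates those inputs; the utterance IDs, paths and speaker IDs are made up, and the snippet itself is not part of the repository:

```
from pathlib import Path

# Fabricate the Kaldi-style metadata that make_parquet_list.py reads from --src_dir.
src = Path("data/train")  # hypothetical src_dir
src.mkdir(parents=True, exist_ok=True)
(src / "wav.scp").write_text("utt_0001 /data/wavs/utt_0001.wav\n", encoding="utf-8")  # <utt_id> <wav_path>
(src / "text").write_text("utt_0001 今天天气不错\n", encoding="utf-8")                 # <utt_id> <transcription>
(src / "utt2spk").write_text("utt_0001 spk_001\n", encoding="utf-8")                  # <utt_id> <spk_id>
# utt2embedding.pt, spk2embedding.pt and utt2speech_token.pt are torch.load-able dicts
# keyed by utt_id (or spk_id), typically produced by the extraction scripts in tools/.
```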
15 | # Modified from ESPnet(https://github.com/espnet/espnet) 16 | """ConvolutionModule definition.""" 17 | 18 | from typing import Tuple 19 | 20 | import torch 21 | from torch import nn 22 | 23 | 24 | class ConvolutionModule(nn.Module): 25 | """ConvolutionModule in Conformer model.""" 26 | 27 | def __init__(self, 28 | channels: int, 29 | kernel_size: int = 15, 30 | activation: nn.Module = nn.ReLU(), 31 | norm: str = "batch_norm", 32 | causal: bool = False, 33 | bias: bool = True): 34 | """Construct an ConvolutionModule object. 35 | Args: 36 | channels (int): The number of channels of conv layers. 37 | kernel_size (int): Kernel size of conv layers. 38 | causal (int): Whether use causal convolution or not 39 | """ 40 | super().__init__() 41 | 42 | self.pointwise_conv1 = nn.Conv1d( 43 | channels, 44 | 2 * channels, 45 | kernel_size=1, 46 | stride=1, 47 | padding=0, 48 | bias=bias, 49 | ) 50 | # self.lorder is used to distinguish if it's a causal convolution, 51 | # if self.lorder > 0: it's a causal convolution, the input will be 52 | # padded with self.lorder frames on the left in forward. 53 | # else: it's a symmetrical convolution 54 | if causal: 55 | padding = 0 56 | self.lorder = kernel_size - 1 57 | else: 58 | # kernel_size should be an odd number for none causal convolution 59 | assert (kernel_size - 1) % 2 == 0 60 | padding = (kernel_size - 1) // 2 61 | self.lorder = 0 62 | self.depthwise_conv = nn.Conv1d( 63 | channels, 64 | channels, 65 | kernel_size, 66 | stride=1, 67 | padding=padding, 68 | groups=channels, 69 | bias=bias, 70 | ) 71 | 72 | assert norm in ['batch_norm', 'layer_norm'] 73 | if norm == "batch_norm": 74 | self.use_layer_norm = False 75 | self.norm = nn.BatchNorm1d(channels) 76 | else: 77 | self.use_layer_norm = True 78 | self.norm = nn.LayerNorm(channels) 79 | 80 | self.pointwise_conv2 = nn.Conv1d( 81 | channels, 82 | channels, 83 | kernel_size=1, 84 | stride=1, 85 | padding=0, 86 | bias=bias, 87 | ) 88 | self.activation = activation 89 | 90 | def forward( 91 | self, 92 | x: torch.Tensor, 93 | mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool), 94 | cache: torch.Tensor = torch.zeros((0, 0, 0)), 95 | ) -> Tuple[torch.Tensor, torch.Tensor]: 96 | """Compute convolution module. 97 | Args: 98 | x (torch.Tensor): Input tensor (#batch, time, channels). 99 | mask_pad (torch.Tensor): used for batch padding (#batch, 1, time), 100 | (0, 0, 0) means fake mask. 101 | cache (torch.Tensor): left context cache, it is only 102 | used in causal convolution (#batch, channels, cache_t), 103 | (0, 0, 0) meas fake cache. 104 | Returns: 105 | torch.Tensor: Output tensor (#batch, time, channels). 106 | """ 107 | # exchange the temporal dimension and the feature dimension 108 | x = x.transpose(1, 2) # (#batch, channels, time) 109 | 110 | # mask batch padding 111 | if mask_pad.size(2) > 0: # time > 0 112 | x.masked_fill_(~mask_pad, 0.0) 113 | 114 | if self.lorder > 0: 115 | if cache.size(2) == 0: # cache_t == 0 116 | x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0) 117 | else: 118 | assert cache.size(0) == x.size(0) # equal batch 119 | assert cache.size(1) == x.size(1) # equal channel 120 | x = torch.cat((cache, x), dim=2) 121 | assert (x.size(2) > self.lorder) 122 | new_cache = x[:, :, -self.lorder:] 123 | else: 124 | # It's better we just return None if no cache is required, 125 | # However, for JIT export, here we just fake one tensor instead of 126 | # None. 
127 | new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device) 128 | 129 | # GLU mechanism 130 | x = self.pointwise_conv1(x) # (batch, 2*channel, dim) 131 | x = nn.functional.glu(x, dim=1) # (batch, channel, dim) 132 | 133 | # 1D Depthwise Conv 134 | x = self.depthwise_conv(x) 135 | if self.use_layer_norm: 136 | x = x.transpose(1, 2) 137 | x = self.activation(self.norm(x)) 138 | if self.use_layer_norm: 139 | x = x.transpose(1, 2) 140 | x = self.pointwise_conv2(x) 141 | # mask batch padding 142 | if mask_pad.size(2) > 0: # time > 0 143 | x.masked_fill_(~mask_pad, 0.0) 144 | 145 | return x.transpose(1, 2), new_cache 146 | -------------------------------------------------------------------------------- /third_party/Matcha-TTS/matcha/hifigan/README.md: -------------------------------------------------------------------------------- 1 | # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis 2 | 3 | ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae 4 | 5 | In our [paper](https://arxiv.org/abs/2010.05646), 6 | we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.
7 | We provide our implementation and pretrained models as open source in this repository. 8 | 9 | **Abstract :** 10 | Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms. 11 | Although such methods improve the sampling efficiency and memory usage, 12 | their sample quality has not yet reached that of autoregressive and flow-based generative models. 13 | In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis. 14 | As speech audio consists of sinusoidal signals with various periods, 15 | we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality. 16 | A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method 17 | demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than 18 | real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen 19 | speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times 20 | faster than real-time on CPU with comparable quality to an autoregressive counterpart. 21 | 22 | Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples. 23 | 24 | ## Pre-requisites 25 | 26 | 1. Python >= 3.6 27 | 2. Clone this repository. 28 | 3. Install python requirements. Please refer [requirements.txt](requirements.txt) 29 | 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/). 30 | And move all wav files to `LJSpeech-1.1/wavs` 31 | 32 | ## Training 33 | 34 | ``` 35 | python train.py --config config_v1.json 36 | ``` 37 | 38 | To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.
39 | Checkpoints and a copy of the configuration file are saved in the `cp_hifigan` directory by default.<br>
40 | You can change the path by adding the `--checkpoint_path` option. 41 | 42 | Validation loss during training with the V1 generator.<br>
43 | ![validation loss](./validation_loss.png) 44 | 45 | ## Pretrained Model 46 | 47 | You can also use pretrained models we provide.
48 | [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)
49 | Details of each folder are as follows: 50 | 51 | | Folder Name | Generator | Dataset | Fine-Tuned | 52 | | ------------ | --------- | --------- | ------------------------------------------------------ | 53 | | LJ_V1 | V1 | LJSpeech | No | 54 | | LJ_V2 | V2 | LJSpeech | No | 55 | | LJ_V3 | V3 | LJSpeech | No | 56 | | LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 57 | | LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 58 | | LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) | 59 | | VCTK_V1 | V1 | VCTK | No | 60 | | VCTK_V2 | V2 | VCTK | No | 61 | | VCTK_V3 | V3 | VCTK | No | 62 | | UNIVERSAL_V1 | V1 | Universal | No | 63 | 64 | We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets. 65 | 66 | ## Fine-Tuning 67 | 68 | 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br>
69 | The file name of the generated mel-spectrogram should match that of the audio file, and the extension should be `.npy`.<br>
70 | Example: 71 | ` Audio File : LJ001-0001.wav 72 | Mel-Spectrogram File : LJ001-0001.npy` 73 | 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.
74 | 3. Run the following command. 75 | ``` 76 | python train.py --fine_tuning True --config config_v1.json 77 | ``` 78 | For other command line options, please refer to the training section. 79 | 80 | ## Inference from wav file 81 | 82 | 1. Make a `test_files` directory and copy wav files into the directory. 83 | 2. Run the following command. 84 | ` python inference.py --checkpoint_file [generator checkpoint file path]` 85 | Generated wav files are saved in `generated_files` by default.<br>
86 | You can change the path by adding the `--output_dir` option. 87 | 88 | ## Inference for end-to-end speech synthesis 89 | 90 | 1. Make a `test_mel_files` directory and copy the generated mel-spectrogram files into it.<br>
91 | You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2), 92 | [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth. 93 | 2. Run the following command. 94 | ` python inference_e2e.py --checkpoint_file [generator checkpoint file path]` 95 | Generated wav files are saved in `generated_files_from_mel` by default.
96 | You can change the path by adding the `--output_dir` option. 97 | 98 | ## Acknowledgements 99 | 100 | We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips) 101 | and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this. 102 | --------------------------------------------------------------------------------