├── .gitignore ├── LICENSE ├── README.md ├── README_ZH.md ├── assets ├── docs │ ├── voice_clone.md │ └── voice_clone_ZH.md └── speakers │ └── 2222.pt ├── chattts_plus ├── __init__.py ├── commons │ ├── __init__.py │ ├── constants.py │ ├── logger.py │ ├── norm.py │ ├── onnx2trt.py │ ├── text_utils.py │ └── utils.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ └── collator.py ├── models │ ├── __init__.py │ ├── dvae.py │ ├── gpt.py │ ├── llama.py │ ├── processors.py │ └── tokenizer.py ├── pipelines │ ├── __init__.py │ └── chattts_plus_pipeline.py └── trt_models │ ├── __init__.py │ ├── base_model.py │ ├── gpt_trt.py │ ├── llama_trt_model.py │ └── predictor.py ├── configs ├── accelerate │ └── deepspeed_config.yaml ├── infer │ ├── chattts_plus.yaml │ └── chattts_plus_trt.yaml └── train │ ├── train_speaker_embedding.yaml │ └── train_voice_clone_lora.yaml ├── demos └── notebooklm-podcast │ ├── extract_files_to_texts.py │ ├── llm_api.py │ ├── requirements.txt │ ├── speaker_pt │ ├── en_man_5200.pt │ ├── en_man_8200.pt │ ├── en_man_9400.pt │ ├── en_man_9500.pt │ ├── en_woman_1200.pt │ ├── en_woman_4600.pt │ ├── en_woman_5600.pt │ ├── zh_man_1888.pt │ ├── zh_man_2155.pt │ ├── zh_man_54.pt │ ├── zh_woman_1528.pt │ ├── zh_woman_492.pt │ └── zh_woman_621.pt │ ├── tts.py │ └── utils.py ├── requirements.txt ├── scripts └── conversions │ └── convert_train_list.py ├── setup.py ├── tests ├── test_models.py └── test_pipelines.py ├── train_lora.py ├── update.bat ├── webui.bat └── webui.py /.gitignore: -------------------------------------------------------------------------------- 1 | results 2 | checkpoints 3 | .idea 4 | data 5 | *.egg-info 6 | .DS_store 7 | logs 8 | */__pycache__/ 9 | scripts/conversions 10 | *.pyc 11 | venv 12 | dist 13 | build 14 | outputs -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ChatTTSPlus: Extension of ChatTTS 2 | 3 | 中文 | English 4 | 5 | ChatTTSPlus is an extension of [ChatTTS](https://github.com/2noise/ChatTTS), adding features such as TensorRT acceleration, voice cloning, and mobile model deployment. 6 | 7 | **If you find this project useful, please give it a star! ✨✨** 8 | 9 | ### Some fun demos based on ChatTTSPlus 10 | * NotebookLM podcast: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jz8HPdRe_igoNjMSv0RaTn3l2c3seYFT?usp=sharing). 11 | Use ChatTTSPlus to turn the `AnimateAnyone` paper into a podcast. 12 | 13 | 14 | 15 | ### New Features 16 | - [x] Refactored ChatTTS code in a way I'm familiar with. 17 | - [x] **Achieved over 3x acceleration with TensorRT**, increasing performance on a Windows 3060 GPU from 28 tokens/s to 110 tokens/s. 18 | - [x] Windows integration package for one-click extraction and use. 19 | - [x] Implemented voice cloning using technologies like LoRA. Please reference [voice_clone](assets/docs/voice_clone.md). 20 | - [ ] Model compression and acceleration using techniques like pruning and knowledge distillation, targeting mobile deployment. 21 | 22 | ### Environment Setup 23 | * Install Python 3; it's recommended to use [Miniforge](https://github.com/conda-forge/miniforge). 
Run: `conda create -n chattts_plus python=3.10 && conda activate chattts_plus` 24 | * Download the source code: `git clone https://github.com/warmshao/ChatTTSPlus`, and navigate to the project root directory: `cd ChatTTSPlus` 25 | * Install necessary Python libraries: `pip install -r requirements.txt` 26 | * [Optional] If you want to use TensorRT, please install [tensorrt10](https://developer.nvidia.com/tensorrt/download) 27 | * [Recommended for Windows users] Download the integration package directly from [Google Drive Link](https://drive.google.com/file/d/1yOnU5dRTJvFnc4wyw02nAeJH5_FgNod2/view?usp=sharing), extract it, and double-click `webui.bat` to use. If you want to update the code, please double-click `update.bat`. Note: **This will overwrite all your local code modifications.** 28 | 29 | ### Demo 30 | * Web UI with TensorRT: `python webui.py --cfg configs/infer/chattts_plus_trt.yaml`. 31 | * Web UI with PyTorch: `python webui.py --cfg configs/infer/chattts_plus.yaml` 32 | 33 | 34 | 35 | ### License 36 | ChatTTSPlus inherits the license from [ChatTTS](https://github.com/2noise/ChatTTS); please refer to [ChatTTS](https://github.com/2noise/ChatTTS) as the standard. 37 | 38 | The code is published under the AGPLv3+ license. 39 | 40 | The model is published under the CC BY-NC 4.0 license. It is intended for educational and research use and should not be used for any commercial or illegal purposes. The authors do not guarantee the accuracy, completeness, or reliability of the information. The information and data used in this repository are for academic and research purposes only. The data is obtained from publicly available sources, and the authors do not claim any ownership or copyright over the data. 41 | 42 | ### Disclaimer 43 | We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws. -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 | ## ChatTTSPlus: Extension of ChatTTS 2 | 3 | 中文 | English 4 | 5 | ChatTTSPlus是[ChatTTS](https://github.com/2noise/ChatTTS)的扩展,增加使用TensorRT加速、声音克隆和模型移动端运行等功能。 6 | 7 | **如果你觉得这个项目有用,帮我点个star吧✨✨** 8 | 9 | ### 基于ChatTTSPlus做的有趣的demo 10 | * NotebookLM播客: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jz8HPdRe_igoNjMSv0RaTn3l2c3seYFT?usp=sharing) 11 | ,使用ChatTTS把 `AnimateAnyone` 这篇文章变成播客。 12 | 13 | 14 | 15 | ### 新增功能 16 | - [x] 将ChatTTS的代码以我熟悉的方式重构。 17 | - [x] **使用TensorRT实现3倍以上的加速**, 在windows的3060显卡上从28token/s提升到110token/s。 18 | - [x] windows整合包,一键解压使用。 19 | - [x] 使用Lora等技术实现声音克隆。请参考 [声音克隆](assets/docs/voice_clone_ZH.md) 20 | - [ ] 使用剪枝、知识蒸馏等做模型压缩和加速,目标在移动端运行。 21 | 22 | ### 环境安装 23 | * 安装python3,推荐可以用[miniforge](https://github.com/conda-forge/miniforge).`conda create -n chattts_plus python=3.10 && conda activate chattts_plus` 24 | * 下载源码: `git clone https://github.com/warmshao/ChatTTSPlus`, 并到项目根目录下: `cd ChatTTSPlus` 25 | * 安装必要的python库, `pip install -r requirements.txt` 26 | * 【可选】如果你要使用tensorrt的话,请安装[tensorrt10](https://developer.nvidia.com/tensorrt/download) 27 | * 【windows用户推荐】直接从[Google Drive链接](https://drive.google.com/file/d/1yOnU5dRTJvFnc4wyw02nAeJH5_FgNod2/view?usp=sharing)下载整合包,解压后双击`webui.bat`即可使用。如果要更新代码的话,请先双击`update.bat`, 注意:**这会覆盖你本地所有的代码修改**。 28 | 29 | ### Demo 30 | * Webui with TensorRT: `python webui.py --cfg configs/infer/chattts_plus_trt.yaml`. 
31 | * Webui with Pytorch: `python webui.py --cfg configs/infer/chattts_plus.yaml` 32 | 33 | 34 | 35 | ### License 36 | ChatTTSPlus继承[ChatTTS](https://github.com/2noise/ChatTTS)的license,请以[ChatTTS](https://github.com/2noise/ChatTTS)为标准。 37 | 38 | The code is published under AGPLv3+ license. 39 | 40 | The model is published under CC BY-NC 4.0 license. It is intended for educational and research use, and should not be used for any commercial or illegal purposes. The authors do not guarantee the accuracy, completeness, or reliability of the information. The information and data used in this repo, are for academic and research purposes only. The data obtained from publicly available sources, and the authors do not claim any ownership or copyright over the data. 41 | 42 | ### 免责声明 43 | 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规. -------------------------------------------------------------------------------- /assets/docs/voice_clone.md: -------------------------------------------------------------------------------- 1 | ## ChatTTSPlus Voice Cloning 2 | Currently supports two methods of voice cloning: 3 | * Training with lora 4 | * Training speaker embedding 5 | 6 | Note: 7 | * The voice cloning feature of ChatTTSPlus is for learning purposes only. Please do not use it for illegal or criminal activities. We take no responsibility for any illegal use of the codebase. 8 | * Some code references: [ChatTTS PR](https://github.com/2noise/ChatTTS/pull/680) 9 | 10 | ### Data Collection and Preprocessing 11 | * Prepare over 30 minutes of audio from the person you want to clone. 12 | * Process using GPT-SoViTs preprocessing workflow, executing in sequence: audio splitting, UVR5 background separation, noise reduction, speech-to-text, etc. 13 | * Finally get a `.list` file in this format: speaker_name | audio_path | lang | text, like this: 14 | ```text 15 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000000000_0000152640.wav|EN|Hehe, I watched Parasite recently, really recommend it. 16 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000152640_0000323520.wav|EN|The plot is tight and the filming technique is very unique. 17 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000323520_0000474880.wav|EN|It won many awards too, what type of movies do you like? 18 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_2.WAV_10.wav_0000000000_0000114560.wav|EN|I like mystery films, any other recommendations? 19 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_3.WAV_10.wav_0000000000_0000133760.wav|EN|Woof woof woof, then you must watch And Then There Were None. 20 | ``` 21 | 22 | ### Model Training 23 | #### Lora Training (Recommended) 24 | * Modify `DATA/meta_infos` in `configs/train/train_voice_clone_lora.yaml` to the `.list` file processed in previous step 25 | * Modify `exp_name` in `configs/train/train_voice_clone_lora.yaml` to the experiment name, preferably using speaker_name for identification. 26 | * Then run `accelerate launch train_lora.py --config configs/train/train_voice_clone_lora.yaml` to start training. 
27 | * Trained models will be saved in the `outputs` folder, like `outputs/xionger_lora-1732809910.2932503/checkpoints/step-900` 28 | * You can visualize training logs using tensorboard, e.g., `tensorboard --logdir=outputs/xionger_lora-1732809910.2932503/tf_logs` 29 | 30 | #### Speaker Embedding Training (Not Recommended, Hard to Converge) 31 | * Modify `DATA/meta_infos` in `configs/train/train_speaker_embedding.yaml` to the `.list` file processed in previous step 32 | * Modify `exp_name` in `configs/train/train_speaker_embedding.yaml` to the experiment name, preferably using speaker_name for identification. 33 | * Then run `accelerate launch train_lora.py --config configs/train/train_speaker_embedding.yaml` to start training. 34 | * Trained speaker embeddings will be saved in the `outputs` folder, like `outputs/xionger_speaker_emb-1732931630.7137222/checkpoints/step-1/xionger.pt` 35 | * You can visualize training logs using tensorboard, e.g., `tensorboard --logdir=outputs/xionger_speaker_emb-1732931630.7137222/tf_logs` 36 | 37 | #### Some Tips 38 | * For better results, it's best to prepare more than 1 hour of audio. I tried training lora with 1 minute of audio, but it was prone to overfitting and the results were mediocre. 39 | * Don't train for too long, otherwise it can easily overfit. When I trained with 1 hour of Lei Jun's audio, it converged between 2000 and 3000 steps. 40 | * If you understand lora training, you can try adjusting the parameters in the config file. 41 | 42 | ### Model Inference 43 | * Launch webui: `python webui.py --cfg configs/infer/chattts_plus.yaml` 44 | * Refer to the following video tutorial for usage: 45 | 46 | -------------------------------------------------------------------------------- /assets/docs/voice_clone_ZH.md: -------------------------------------------------------------------------------- 1 | ## ChatTTSPlus 声音克隆 2 | 目前支持两种声音克隆的方式: 3 | * 训练lora 4 | * 训练speaker embedding 5 | 6 | 注意: 7 | * ChatTTSPlus 的声音克隆功能仅供学习使用,请勿用于非法或犯罪活动。我们对代码库的任何非法使用不承担责任。 8 | * 部分代码参考: [ChatTTS PR](https://github.com/2noise/ChatTTS/pull/680) 9 | 10 | 11 | ### 数据收集和预处理 12 | * 准备想要克隆的某个人的30分钟以上的音频。 13 | * 使用GPT-SoViTs的预处理流程处理,依次执行:音频切分、UVR5背景声分离、语音降噪、语音识别文本等。 14 | * 最后得到这样的`.list`文件: speaker_name | audio_path | lang | text,类似这样: 15 | ```text 16 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000000000_0000152640.wav|ZH|嘿嘿,最近我看了寄生虫,真的很推荐哦。 17 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000152640_0000323520.wav|ZH|这部电影剧情紧凑,拍摄手法也很独特。 18 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000323520_0000474880.wav|ZH|还得了很多奖项,你有喜欢的电影类型吗? 19 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_2.WAV_10.wav_0000000000_0000114560.wav|ZH|我喜欢悬疑片,有其他推荐吗?
20 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_3.WAV_10.wav_0000000000_0000133760.wav|ZH|汪汪汪,那你一定要看无人生还。 21 | ``` 22 | 23 | ### 模型训练: 24 | #### lora训练(推荐): 25 | * 修改`configs/train/train_voice_clone_lora.yaml`里 `DATA/meta_infos`为上一个处理到的`.list`文件 26 | * 修改`configs/train/train_voice_clone_lora.yaml`里 `exp_name`为实验的名称,最好用speaker_name做区分识别。 27 | * 然后运行`accelerate launch train_lora.py --config configs/train/train_voice_clone_lora.yaml`开始训练。 28 | * 训练的模型会保存在: `outputs` 文件夹下,比如`outputs/xionger_lora-1732809910.2932503/checkpoints/step-900` 29 | * 你可以使用tensorboard可视化训练的log, 比如`tensorboad --logdir=outputs/xionger_lora-1732809910.2932503/tf_logs` 30 | 31 | #### speaker embedding训练(不推荐,很难收敛): 32 | * 修改`configs/train/train_speaker_embedding.yaml`里 `DATA/meta_infos`为上一个处理到的`.list`文件 33 | * 修改`configs/train/train_speaker_embedding.yaml`里 `exp_name`为实验的名称,最好用speaker_name做区分识别。 34 | * 然后运行`accelerate launch train_lora.py --config configs/train/train_speaker_embedding.yaml`开始训练。 35 | * 训练的speaker embedding 会保存在: `outputs` 文件夹下,比如`outputs/xionger_speaker_emb-1732931630.7137222/checkpoints/step-1/xionger.pt` 36 | * 你可以使用tensorboard可视化训练的log, 比如`tensorboad --logdir=outputs/xionger_speaker_emb-1732931630.7137222/tf_logs` 37 | 38 | #### 一些Tips 39 | * 如果要效果好的话,最好准备1小时以上的音频。我试过1分钟训练lora,但是比较容易过拟合,效果一般。 40 | * 不要训练太久,不然容易过拟合,我用1小时的雷军音频训练,2000到3000 step就收敛了。 41 | * 如果你懂lora训练的话,你可以尝试调整config里面的参数。 42 | 43 | ### 模型推理 44 | * 启动webui: ` python webui.py --cfg configs/infer/chattts_plus.yaml` 45 | * 参考以下视频教程使用: 46 | 47 | 48 | -------------------------------------------------------------------------------- /assets/speakers/2222.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/assets/speakers/2222.pt -------------------------------------------------------------------------------- /chattts_plus/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/27 22:47 3 | # @Project : ChatTTSPlus 4 | # @FileName: __init__.py.py 5 | -------------------------------------------------------------------------------- /chattts_plus/commons/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/9/22 11:48 3 | # @Project : ChatTTSPlus 4 | # @FileName: __init__.py.py 5 | -------------------------------------------------------------------------------- /chattts_plus/commons/constants.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/10/8 10:37 3 | # @Author : wenshao 4 | # @ProjectName: ChatTTSPlus 5 | # @FileName: constants.py 6 | 7 | import os 8 | 9 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 10 | PROJECT_DIR = os.environ.get("CHATTTS_PLUS_PROJECT_DIR", os.path.join(CURRENT_DIR, '..', '..')) 11 | PROJECT_DIR = os.path.abspath(PROJECT_DIR) 12 | CHECKPOINT_DIR = os.environ.get("CHATTTS_PLUS_CHECKPOINT_DIR", 13 | os.path.abspath(os.path.join(CURRENT_DIR, '..', 'checkpoints'))) 14 | CHECKPOINT_DIR = os.path.abspath(CHECKPOINT_DIR) 15 | LOG_DIR = os.environ.get("CHATTTS_PLUS_LOG_DIR", os.path.join(PROJECT_DIR, 'logs')) 16 | LOG_DIR = os.path.abspath(LOG_DIR) 17 | -------------------------------------------------------------------------------- /chattts_plus/commons/logger.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import logging 4 | from datetime import datetime, timezone 5 | from pathlib import Path 6 | from .constants import LOG_DIR 7 | 8 | logging.getLogger("numba").setLevel(logging.WARNING) 9 | logging.getLogger("httpx").setLevel(logging.WARNING) 10 | logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING) 11 | logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING) 12 | 13 | # from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96 14 | colorCodePanic = "\x1b[1;31m" 15 | colorCodeFatal = "\x1b[1;31m" 16 | colorCodeError = "\x1b[31m" 17 | colorCodeWarn = "\x1b[33m" 18 | colorCodeInfo = "\x1b[37m" 19 | colorCodeDebug = "\x1b[32m" 20 | colorCodeTrace = "\x1b[36m" 21 | colorReset = "\x1b[0m" 22 | 23 | log_level_color_code = { 24 | logging.DEBUG: colorCodeDebug, 25 | logging.INFO: colorCodeInfo, 26 | logging.WARN: colorCodeWarn, 27 | logging.ERROR: colorCodeError, 28 | logging.FATAL: colorCodeFatal, 29 | } 30 | 31 | log_level_msg_str = { 32 | logging.DEBUG: "DEBU", 33 | logging.INFO: "INFO", 34 | logging.WARN: "WARN", 35 | logging.ERROR: "ERRO", 36 | logging.FATAL: "FATL", 37 | } 38 | 39 | 40 | class Formatter(logging.Formatter): 41 | def __init__(self, color=platform.system().lower() != "windows"): 42 | # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone 43 | self.tz = datetime.now(timezone.utc).astimezone().tzinfo 44 | self.color = color 45 | 46 | def format(self, record: logging.LogRecord): 47 | logstr = "[" + datetime.now(self.tz).strftime("%Y%m%d %H:%M:%S") + "] [" 48 | if self.color: 49 | logstr += log_level_color_code.get(record.levelno, colorCodeInfo) 50 | logstr += log_level_msg_str.get(record.levelno, record.levelname) 51 | if self.color: 52 | logstr += colorReset 53 | fn = record.filename.removesuffix(".py") 54 | logstr += f"] {str(record.name)} | {fn} | {str(record.msg) % record.args}" 55 | return logstr 56 | 57 | 58 | def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_root=False, log_file=None): 59 | """ 60 | Configure and return a logger with specified settings. 61 | 62 | Args: 63 | name (str): The name of the logger. 64 | lv (int): The logging level (default: logging.INFO). 65 | remove_exist (bool): Whether to remove existing handlers (default: False). 66 | format_root (bool): Whether to format the root logger as well (default: False). 67 | log_file (str | Path | None): Path to the log file. If provided, logs will also be written to this file. 68 | 69 | Returns: 70 | logging.Logger: Configured logger instance. 
71 | """ 72 | logger = logging.getLogger(name) 73 | logger.setLevel(lv) 74 | 75 | if remove_exist and logger.hasHandlers(): 76 | logger.handlers.clear() 77 | 78 | formatter = Formatter() 79 | 80 | # Add console handler if no handlers exist 81 | if not logger.hasHandlers(): 82 | console_handler = logging.StreamHandler() 83 | console_handler.setFormatter(formatter) 84 | logger.addHandler(console_handler) 85 | else: 86 | # Update formatter for existing handlers 87 | for h in logger.handlers: 88 | h.setFormatter(formatter) 89 | 90 | # Add file handler if log_file is specified 91 | if log_file is None: 92 | os.makedirs(LOG_DIR, exist_ok=True) 93 | date_str = datetime.now().strftime("%y%m%d") 94 | log_file = os.path.join(LOG_DIR, f"chattts_plus_{date_str}.log") 95 | 96 | log_file = Path(log_file) 97 | # Create directory if it doesn't exist 98 | log_file.parent.mkdir(parents=True, exist_ok=True) 99 | 100 | file_handler = logging.FileHandler(str(log_file), encoding='utf-8') 101 | file_handler.setFormatter(formatter) 102 | logger.addHandler(file_handler) 103 | 104 | if format_root: 105 | for h in logger.root.handlers: 106 | h.setFormatter(formatter) 107 | 108 | return logger 109 | -------------------------------------------------------------------------------- /chattts_plus/commons/norm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import re 4 | from typing import Dict, Tuple, List, Literal, Callable, Optional 5 | import sys 6 | 7 | from numba import jit 8 | import numpy as np 9 | from . import logger as logger_ 10 | 11 | 12 | @jit 13 | def _find_index(table: np.ndarray, val: np.uint16): 14 | for i in range(table.size): 15 | if table[i] == val: 16 | return i 17 | return -1 18 | 19 | 20 | @jit 21 | def _fast_replace( 22 | table: np.ndarray, text: bytes 23 | ) -> Tuple[np.ndarray, List[Tuple[str, str]]]: 24 | result = np.frombuffer(text, dtype=np.uint16).copy() 25 | replaced_words = [] 26 | for i in range(result.size): 27 | ch = result[i] 28 | p = _find_index(table[0], ch) 29 | if p >= 0: 30 | repl_char = table[1][p] 31 | result[i] = repl_char 32 | replaced_words.append((chr(ch), chr(repl_char))) 33 | return result, replaced_words 34 | 35 | 36 | class Normalizer: 37 | def __init__(self, map_file_path: str, logger=None): 38 | if logger is None: 39 | logger = logger_.get_logger(self.__class__.__name__) 40 | self.logger = logger 41 | self.normalizers: Dict[str, Callable[[str], str]] = {} 42 | self.homophones_map = self._load_homophones_map(map_file_path) 43 | """ 44 | homophones_map 45 | 46 | Replace the mispronounced characters with correctly pronounced ones. 47 | 48 | Creation process of homophones_map.json: 49 | 50 | 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text. 51 | 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words. 52 | 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS. 53 | 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping. 
54 | 55 | Thanks to: 56 | [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html) 57 | [python-pinyin](https://github.com/mozillazg/python-pinyin) 58 | 59 | """ 60 | self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be" 61 | self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]") 62 | self.sub_pattern = re.compile(r"\[uv_break\]|\[laugh\]|\[lbreak\]") 63 | self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]") 64 | self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b") 65 | self.character_simplifier = str.maketrans( 66 | { 67 | ":": ",", 68 | ";": ",", 69 | "!": "。", 70 | "(": ",", 71 | ")": ",", 72 | "【": ",", 73 | "】": ",", 74 | "『": ",", 75 | "』": ",", 76 | "「": ",", 77 | "」": ",", 78 | "《": ",", 79 | "》": ",", 80 | "-": ",", 81 | ":": ",", 82 | ";": ",", 83 | "!": ".", 84 | "(": ",", 85 | ")": ",", 86 | # "[": ",", 87 | # "]": ",", 88 | ">": ",", 89 | "<": ",", 90 | "-": ",", 91 | } 92 | ) 93 | self.halfwidth_2_fullwidth = str.maketrans( 94 | { 95 | "!": "!", 96 | '"': "“", 97 | "'": "‘", 98 | "#": "#", 99 | "$": "$", 100 | "%": "%", 101 | "&": "&", 102 | "(": "(", 103 | ")": ")", 104 | ",": ",", 105 | "-": "-", 106 | "*": "*", 107 | "+": "+", 108 | ".": "。", 109 | "/": "/", 110 | ":": ":", 111 | ";": ";", 112 | "<": "<", 113 | "=": "=", 114 | ">": ">", 115 | "?": "?", 116 | "@": "@", 117 | # '[': '[', 118 | "\\": "\", 119 | # ']': ']', 120 | "^": "^", 121 | # '_': '_', 122 | "`": "`", 123 | "{": "{", 124 | "|": "|", 125 | "}": "}", 126 | "~": "~", 127 | } 128 | ) 129 | 130 | def __call__( 131 | self, 132 | text: str, 133 | do_text_normalization=True, 134 | do_homophone_replacement=True, 135 | lang: Optional[Literal["zh", "en"]] = None, 136 | ) -> str: 137 | if do_text_normalization: 138 | _lang = self._detect_language(text) if lang is None else lang 139 | if _lang in self.normalizers: 140 | text = self.normalizers[_lang](text) 141 | if _lang == "zh": 142 | text = self._apply_half2full_map(text) 143 | invalid_characters = self._count_invalid_characters(text) 144 | if len(invalid_characters): 145 | self.logger.debug(f"found invalid characters: {invalid_characters}") 146 | text = self._apply_character_map(text) 147 | if do_homophone_replacement: 148 | arr, replaced_words = _fast_replace( 149 | self.homophones_map, 150 | text.encode(self.coding), 151 | ) 152 | if replaced_words: 153 | text = arr.tobytes().decode(self.coding) 154 | repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words]) 155 | self.logger.debug(f"replace homophones: {repl_res}") 156 | if len(invalid_characters): 157 | text = self.reject_pattern.sub("", text) 158 | return text 159 | 160 | def register(self, name: str, normalizer: Callable[[str], str]) -> bool: 161 | if name in self.normalizers: 162 | self.logger.warning(f"name {name} has been registered") 163 | return False 164 | try: 165 | val = normalizer("test string 测试字符串") 166 | if not isinstance(val, str): 167 | self.logger.warning("normalizer must have caller type (str) -> str") 168 | return False 169 | except Exception as e: 170 | self.logger.warning(e) 171 | return False 172 | self.normalizers[name] = normalizer 173 | return True 174 | 175 | def unregister(self, name: str): 176 | if name in self.normalizers: 177 | del self.normalizers[name] 178 | 179 | def destroy(self): 180 | del self.homophones_map 181 | 182 | def _load_homophones_map(self, map_file_path: str) -> np.ndarray: 183 | with open(map_file_path, "r", encoding="utf-8") as f: 184 | 
homophones_map: Dict[str, str] = json.load(f) 185 | map = np.empty((2, len(homophones_map)), dtype=np.uint32) 186 | for i, k in enumerate(homophones_map.keys()): 187 | map[:, i] = (ord(k), ord(homophones_map[k])) 188 | del homophones_map 189 | return map 190 | 191 | def _count_invalid_characters(self, s: str): 192 | s = self.sub_pattern.sub("", s) 193 | non_alphabetic_chinese_chars = self.reject_pattern.findall(s) 194 | return set(non_alphabetic_chinese_chars) 195 | 196 | def _apply_half2full_map(self, text: str) -> str: 197 | return text.translate(self.halfwidth_2_fullwidth) 198 | 199 | def _apply_character_map(self, text: str) -> str: 200 | return text.translate(self.character_simplifier) 201 | 202 | def _detect_language(self, sentence: str) -> Literal["zh", "en"]: 203 | chinese_chars = self.chinese_char_pattern.findall(sentence) 204 | english_words = self.english_word_pattern.findall(sentence) 205 | 206 | if len(chinese_chars) > len(english_words): 207 | return "zh" 208 | else: 209 | return "en" 210 | -------------------------------------------------------------------------------- /chattts_plus/commons/onnx2trt.py: -------------------------------------------------------------------------------- 1 | # 2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 3 | # SPDX-License-Identifier: Apache-2.0 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | 18 | import os 19 | import pdb 20 | import sys 21 | import logging 22 | import argparse 23 | 24 | import tensorrt as trt 25 | from polygraphy.backend.trt import ( 26 | engine_from_bytes, 27 | engine_from_network, 28 | network_from_onnx_path, 29 | save_engine, 30 | ) 31 | 32 | logging.basicConfig(level=logging.INFO) 33 | logging.getLogger("EngineBuilder").setLevel(logging.INFO) 34 | log = logging.getLogger("EngineBuilder") 35 | 36 | 37 | class EngineBuilder: 38 | """ 39 | Parses an ONNX graph and builds a TensorRT engine from it. 40 | """ 41 | 42 | def __init__(self, verbose=False): 43 | """ 44 | :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger. 
45 | """ 46 | self.trt_logger = trt.Logger(trt.Logger.INFO) 47 | if verbose: 48 | self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE 49 | 50 | trt.init_libnvinfer_plugins(self.trt_logger, namespace="") 51 | 52 | self.builder = trt.Builder(self.trt_logger) 53 | self.config = self.builder.create_builder_config() 54 | # self.config.max_workspace_size = 24 * (2 ** 30) # 12 GB 55 | 56 | profile = self.builder.create_optimization_profile() 57 | 58 | # GPT 59 | profile.set_shape("inputs_embeds", min=[1, 1, 768], opt=[2, 1, 768], max=[4, 2048, 768]) 60 | profile.set_shape("attention_mask", min=[1, 1], opt=[2, 256], max=[4, 2048]) 61 | profile.set_shape("position_ids", min=[1, 1], opt=[2, 1], max=[4, 2048]) 62 | profile.set_shape("past_key_values", min=[20, 2, 1, 12, 1, 64], opt=[20, 2, 2, 12, 256, 64], 63 | max=[20, 2, 4, 12, 2048, 64]) 64 | 65 | self.config.add_optimization_profile(profile) 66 | 67 | self.batch_size = None 68 | self.network = None 69 | self.parser = None 70 | 71 | def create_network(self, onnx_path): 72 | """ 73 | Parse the ONNX graph and create the corresponding TensorRT network definition. 74 | :param onnx_path: The path to the ONNX graph to load. 75 | """ 76 | 77 | onnx_path = os.path.realpath(onnx_path) 78 | self.network_infos = self.builder, self.network, _ = network_from_onnx_path( 79 | onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM] 80 | ) 81 | 82 | inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)] 83 | outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)] 84 | log.info("Network Description") 85 | for input in inputs: 86 | self.batch_size = input.shape[0] 87 | log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype)) 88 | for output in outputs: 89 | log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype)) 90 | 91 | def create_engine( 92 | self, 93 | engine_path, 94 | precision 95 | ): 96 | """ 97 | Build the TensorRT engine and serialize it to disk. 98 | :param engine_path: The path where to serialize the engine to. 99 | :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'. 
100 | """ 101 | engine_path = os.path.realpath(engine_path) 102 | engine_dir = os.path.dirname(engine_path) 103 | os.makedirs(engine_dir, exist_ok=True) 104 | log.info("Building {} Engine in {}".format(precision, engine_path)) 105 | 106 | if precision == "fp16": 107 | if not self.builder.platform_has_fast_fp16: 108 | log.warning("FP16 is not supported natively on this platform/device") 109 | else: 110 | self.config.set_flag(trt.BuilderFlag.FP16) 111 | elif precision == "int8": 112 | if not self.builder.platform_has_fast_int8: 113 | log.warning("Int8 is not supported natively on this platform/device") 114 | else: 115 | self.config.set_flag(trt.BuilderFlag.INT8) 116 | 117 | try: 118 | engine = engine_from_network( 119 | self.network_infos, 120 | self.config 121 | ) 122 | except Exception as e: 123 | print(f"Failed to build engine: {e}") 124 | return 1 125 | try: 126 | save_engine(engine, path=engine_path) 127 | except Exception as e: 128 | print(f"Failed to save engine: {e}") 129 | return 1 130 | return 0 131 | 132 | 133 | def convert_onnx_to_trt(onnx_path, trt_path, verbose=False, precision="fp16"): 134 | builder = EngineBuilder(verbose) 135 | builder.create_network(onnx_path) 136 | builder.create_engine( 137 | trt_path, 138 | precision 139 | ) 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("-o", "--onnx", required=True, help="The input ONNX model file to load") 145 | parser.add_argument("-e", "--engine", help="The output path for the TRT engine") 146 | parser.add_argument( 147 | "-p", 148 | "--precision", 149 | default="fp16", 150 | choices=["fp32", "fp16", "int8"], 151 | help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'", 152 | ) 153 | parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output") 154 | args = parser.parse_args() 155 | if args.engine is None: 156 | args.engine = args.onnx.replace(".onnx", ".trt") 157 | convert_onnx_to_trt(args.onnx, args.engine, args.verbose, args.precision) 158 | -------------------------------------------------------------------------------- /chattts_plus/commons/text_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | # ref: https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization 3 | from zh_normalization import TextNormalizer 4 | from functools import partial 5 | 6 | 7 | # 数字转为英文读法 8 | def num_to_english(num): 9 | num_str = str(num) 10 | # English representations for numbers 0-9 11 | english_digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"] 12 | units = ["", "ten", "hundred", "thousand"] 13 | big_units = ["", "thousand", "million", "billion", "trillion"] 14 | result = "" 15 | need_and = False # Indicates whether 'and' needs to be added 16 | part = [] # Stores each group of 4 digits 17 | is_first_part = True # Indicates if it is the first part for not adding 'and' at the beginning 18 | 19 | # Split the number into 3-digit groups 20 | while num_str: 21 | part.append(num_str[-3:]) 22 | num_str = num_str[:-3] 23 | 24 | part.reverse() 25 | 26 | for i, p in enumerate(part): 27 | p_str = "" 28 | digit_len = len(p) 29 | if int(p) == 0 and i < len(part) - 1: 30 | continue 31 | 32 | hundreds_digit = int(p) // 100 if digit_len == 3 else None 33 | tens_digit = int(p) % 100 if digit_len >= 2 else int(p[0] if digit_len == 1 else p[1]) 34 | 35 | # Process hundreds 36 | if hundreds_digit is 
not None and hundreds_digit != 0: 37 | p_str += english_digits[hundreds_digit] + " hundred" 38 | if tens_digit != 0: 39 | p_str += " and " 40 | 41 | # Process tens and ones 42 | if 10 < tens_digit < 20: # Teens exception 43 | teen_map = { 44 | 11: "eleven", 12: "twelve", 13: "thirteen", 14: "fourteen", 15: "fifteen", 45 | 16: "sixteen", 17: "seventeen", 18: "eighteen", 19: "nineteen" 46 | } 47 | p_str += teen_map[tens_digit] 48 | else: 49 | tens_map = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"] 50 | tens_val = tens_digit // 10 51 | ones_val = tens_digit % 10 52 | if tens_val >= 2: 53 | p_str += tens_map[tens_val] + (" " + english_digits[ones_val] if ones_val != 0 else "") 54 | elif tens_digit != 0 and tens_val < 2: # When tens_digit is in [1, 9] 55 | p_str += english_digits[tens_digit] 56 | 57 | if p_str and not is_first_part and need_and: 58 | result += " and " 59 | result += p_str 60 | if i < len(part) - 1 and int(p) != 0: 61 | result += " " + big_units[len(part) - i - 1] + ", " 62 | 63 | is_first_part = False 64 | if int(p) != 0: 65 | need_and = True 66 | 67 | return result.capitalize() 68 | 69 | 70 | def get_lang(text): 71 | # 定义中文标点符号的模式 72 | chinese_punctuation = "[。?!,、;:‘’“”()《》【】…—\u3000]" 73 | # 使用正则表达式替换所有中文标点为"" 74 | cleaned_text = re.sub(chinese_punctuation, "", text) 75 | # 使用正则表达式来匹配中文字符范围 76 | return "zh" if re.search('[\u4e00-\u9fff]', cleaned_text) is not None else "en" 77 | 78 | 79 | def fraction_to_words(match): 80 | numerator, denominator = match.groups() 81 | # 这里只是把数字直接拼接成了英文分数的形式, 实际上应该使用某种方式将数字转换为英文单词 82 | # 例如: "1/2" -> "one half", 这里仅为展示目的而直接返回了 "numerator/denominator" 83 | return numerator + " over " + denominator 84 | 85 | 86 | # 数字转为英文读法 87 | def num2text(text): 88 | numtext = [' zero ', ' one ', ' two ', ' three ', ' four ', ' five ', ' six ', ' seven ', ' eight ', ' nine '] 89 | point = ' point ' 90 | text = re.sub(r'(\d)\,(\d)', r'\1\2', text) 91 | text = re.sub(r'(\d+)\s*\+', r'\1 plus ', text) 92 | text = re.sub(r'(\d+)\s*\-', r'\1 minus ', text) 93 | text = re.sub(r'(\d+)\s*[\*x]', r'\1 times ', text) 94 | text = re.sub(r'((?:\d+\.)?\d+)\s*/\s*(\d+)', fraction_to_words, text) 95 | 96 | # 取出数字 number_list= [('1000200030004000.123', '1000200030004000', '123'), ('23425', '23425', '')] 97 | number_list = re.findall(r'((\d+)(?:\.(\d+))?%?)', text) 98 | if len(number_list) > 0: 99 | # dc= ('1000200030004000.123', '1000200030004000', '123','') 100 | for m, dc in enumerate(number_list): 101 | if len(dc[1]) > 16: 102 | continue 103 | int_text = num_to_english(dc[1]) 104 | if len(dc) > 2 and dc[2]: 105 | int_text += point + "".join([numtext[int(i)] for i in dc[2]]) 106 | if dc[0][-1] == '%': 107 | int_text = f' the pronunciation of {int_text}' 108 | text = text.replace(dc[0], int_text) 109 | 110 | return text.replace('1', ' one ').replace('2', ' two ').replace('3', ' three ').replace('4', ' four ').replace('5', 111 | ' five ').replace( 112 | '6', ' six ').replace('7', 'seven').replace('8', ' eight ').replace('9', ' nine ').replace('0', 113 | ' zero ').replace( 114 | '=', ' equals ') 115 | 116 | 117 | def remove_brackets(text): 118 | # 正则表达式 119 | text = re.sub(r'\[(uv_break|laugh|lbreak|break)\]', r' \1 ', text, re.I | re.S | re.M) 120 | 121 | # 使用 re.sub 替换掉 [ ] 对 122 | newt = re.sub(r'\[|\]|!|:|{|}', '', text) 123 | return re.sub(r'\s(uv_break|laugh|lbreak|break)(?=\s|$)', r' [\1] ', newt) 124 | 125 | 126 | # 中英文数字转换为文字,特殊符号处理 127 | def split_text(text_list): 128 | tx = TextNormalizer() 129 | haserror = False 130 | 
result = [] 131 | for i, text in enumerate(text_list): 132 | text = remove_brackets(text) 133 | if get_lang(text) == 'zh': 134 | tmp = "".join(tx.normalize(text)) 135 | elif haserror: 136 | tmp = num2text(text) 137 | else: 138 | try: 139 | # 先尝试使用 nemo_text_processing 处理英文 140 | from nemo_text_processing.text_normalization.normalize import Normalizer 141 | fun = partial(Normalizer(input_case='cased', lang="en").normalize, verbose=False, 142 | punct_post_process=True) 143 | tmp = fun(text) 144 | print(f'使用nemo处理英文ok') 145 | except Exception as e: 146 | print(f"nemo处理英文失败,改用自定义预处理") 147 | print(e) 148 | haserror = True 149 | tmp = num2text(text) 150 | 151 | if len(tmp) > 200: 152 | tmp_res = split_text_by_punctuation(tmp) 153 | result = result + tmp_res 154 | else: 155 | result.append(tmp) 156 | return result 157 | 158 | 159 | # 切分长行 200 150 160 | def split_text_by_punctuation(text): 161 | # 定义长度限制 162 | min_length = 150 163 | punctuation_marks = "。?!,、;:”’》」』)】…—" 164 | english_punctuation = ".?!,:;)}…" 165 | 166 | # 结果列表 167 | result = [] 168 | # 起始位置 169 | pos = 0 170 | 171 | # 遍历文本中的每个字符 172 | text_length = len(text) 173 | for i, char in enumerate(text): 174 | if char in punctuation_marks or char in english_punctuation: 175 | if char == '.' and i < text_length - 1 and re.match(r'\d', text[i + 1]): 176 | continue 177 | # 当遇到标点时,判断当前分段长度是否超过120 178 | if i - pos > min_length: 179 | # 如果长度超过120,将当前分段添加到结果列表中 180 | result.append(text[pos:i + 1]) 181 | # 更新起始位置到当前标点的下一个字符 182 | pos = i + 1 183 | # print(f'{pos=},{len(text)=}') 184 | 185 | # 如果剩余文本长度超过120或没有更多标点符号可以进行分割,将剩余的文本作为一个分段添加到结果列表 186 | if pos < len(text): 187 | result.append(text[pos:]) 188 | 189 | return result 190 | -------------------------------------------------------------------------------- /chattts_plus/commons/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/10/8 10:36 3 | # @Author : wenshao 4 | # @ProjectName: ChatTTSPlus 5 | # @FileName: utils.py 6 | 7 | import torch 8 | from dataclasses import dataclass, asdict 9 | from typing import Literal, Optional, List, Tuple, Dict, Union 10 | 11 | 12 | @dataclass(repr=False, eq=False) 13 | class RefineTextParams: 14 | prompt: str = "" 15 | top_P: float = 0.7 16 | top_K: int = 20 17 | temperature: float = 0.7 18 | repetition_penalty: float = 1.0 19 | max_new_token: int = 384 20 | min_new_token: int = 0 21 | show_tqdm: bool = True 22 | ensure_non_empty: bool = True 23 | 24 | 25 | @dataclass(repr=False, eq=False) 26 | class InferCodeParams(RefineTextParams): 27 | prompt: str = "[speed_5]" 28 | spk_emb: Optional[str] = None 29 | spk_smp: Optional[str] = None 30 | txt_smp: Optional[str] = None 31 | temperature: float = 0.3 32 | repetition_penalty: float = 1.05 33 | max_new_token: int = 2048 34 | stream_batch: int = 24 35 | stream_speed: int = 12000 36 | pass_first_n_batches: int = 2 37 | 38 | 39 | def get_inference_device(): 40 | if torch.cuda.is_available(): 41 | return torch.device("cuda") 42 | elif torch.backends.mps.is_available(): 43 | return torch.device("mps") 44 | else: 45 | return torch.device("cpu") 46 | 47 | 48 | class TorchSeedContext: 49 | def __init__(self, seed): 50 | self.seed = seed 51 | self.state = None 52 | 53 | def __enter__(self): 54 | self.state = torch.random.get_rng_state() 55 | torch.manual_seed(self.seed) 56 | 57 | def __exit__(self, type, value, traceback): 58 | torch.random.set_rng_state(self.state) 59 | 
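The helpers defined in `chattts_plus/commons/utils.py` above (`RefineTextParams`, `InferCodeParams`, `get_inference_device`, `TorchSeedContext`) are the knobs most inference code touches. A minimal usage sketch (illustrative only, not part of the repository): it assumes the package is importable as `chattts_plus` (e.g. after `pip install -e .`) and uses a randomly initialized 768-dimensional tensor purely as a stand-in for a real speaker embedding such as `assets/speakers/2222.pt`.

```python
import torch

from chattts_plus.commons.utils import (
    InferCodeParams,
    RefineTextParams,
    TorchSeedContext,
    get_inference_device,
)

device = get_inference_device()  # prefers CUDA, then MPS, falls back to CPU

# Sampling settings for the two generation stages: text refinement and audio-code inference.
refine_params = RefineTextParams(temperature=0.7, top_P=0.7, top_K=20)
infer_params = InferCodeParams(prompt="[speed_5]", temperature=0.3, repetition_penalty=1.05)

# TorchSeedContext pins the global torch RNG inside the block and restores the
# previous state on exit, so any random draw here is reproducible.
with TorchSeedContext(seed=2222):
    # Placeholder speaker embedding; 768 matches the GPT hidden size assumed by
    # the TensorRT profiles in onnx2trt.py above.
    spk_emb = torch.randn(768, device=device)

print(device, refine_params.top_K, infer_params.prompt, tuple(spk_emb.shape))
```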
-------------------------------------------------------------------------------- /chattts_plus/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/10/7 12:14 3 | # @Author : wenshao 4 | # @ProjectName: ChatTTSPlus 5 | # @FileName: __init__.py.py 6 | -------------------------------------------------------------------------------- /chattts_plus/datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/11/23 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: base_dataset.py 7 | 8 | import os.path 9 | import pdb 10 | 11 | import torch 12 | import torchaudio 13 | from torch.utils.data import Dataset 14 | 15 | 16 | class BaseDataset(Dataset): 17 | """ 18 | 设置一个通用的 train dataset loader 19 | """ 20 | 21 | def __init__(self, meta_infos=[], 22 | tokenizer=None, 23 | normalizer=None, 24 | sample_rate=24_000, 25 | num_vq=4, 26 | use_empty_speaker=False, 27 | **kwargs 28 | ): 29 | super(BaseDataset, self).__init__() 30 | self.meta_infos = meta_infos 31 | self.tokenizer = tokenizer 32 | self.normalizer = normalizer 33 | self.sample_rate = sample_rate 34 | self.num_vq = num_vq 35 | self.use_empty_speaker = use_empty_speaker 36 | self.data_infos, self.speakers = self.load_data() 37 | 38 | def load_data(self, **kwargs): 39 | data_infos = [] 40 | speakers = set() 41 | for info_path in self.meta_infos: 42 | try: 43 | with open(info_path, "r", encoding='UTF-8') as fin: 44 | for line_num, line in enumerate(fin.readlines(), 1): 45 | line = line.strip().replace('\n', '') 46 | if not line: # 跳过空行 47 | continue 48 | 49 | line_splits = line.split("|") 50 | if len(line_splits) < 4: # 确保有足够的字段 51 | print(f"警告: 第 {line_num} 行格式不正确: {line}") 52 | continue 53 | 54 | try: 55 | data_info = { 56 | "speaker": line_splits[0], 57 | "audio_path": line_splits[1], 58 | "text": line_splits[-1], 59 | "lang": line_splits[-2].lower() 60 | } 61 | # 验证音频文件是否存在 62 | if not os.path.exists(data_info["audio_path"]): 63 | print(f"警告: 音频文件不存在: {data_info['audio_path']}") 64 | continue 65 | 66 | speakers.add(data_info["speaker"]) 67 | data_infos.append(data_info) 68 | except Exception as e: 69 | print(f"错误: 处理第 {line_num} 行时出错: {str(e)}") 70 | print(f"行内容: {line}") 71 | continue 72 | 73 | except Exception as e: 74 | print(f"错误: 读取文件 {info_path} 时出错: {str(e)}") 75 | continue 76 | 77 | if not data_infos: 78 | raise ValueError(f"没有找到有效的训练数据。请检查数据文件: {self.meta_infos}") 79 | 80 | return data_infos, speakers 81 | 82 | def __len__(self): 83 | return len(self.data_infos) 84 | 85 | def __getitem__(self, i): 86 | data_info_ = self.data_infos[i] 87 | audio_wavs, audio_mask = self.preprocess_audio(data_info_["audio_path"]) 88 | text_input_ids, text_mask = self.preprocess_text(data_info_["text"], data_info_["lang"]) 89 | return { 90 | "speaker": data_info_["speaker"], 91 | "text": data_info_["text"], 92 | "audio_wavs": audio_wavs, 93 | "audio_mask": audio_mask, 94 | "text_input_ids": text_input_ids, 95 | "text_mask": text_mask 96 | } 97 | 98 | def preprocess_text( 99 | self, 100 | text, 101 | lang="zh", 102 | do_text_normalization=True, 103 | do_homophone_replacement=True, 104 | ): 105 | 106 | text = self.normalizer( 107 | text, 108 | do_text_normalization, 109 | do_homophone_replacement, 110 | lang, 111 | ) 112 | if self.use_empty_speaker: 113 | text = f'[Stts][empty_spk]{text}[Ptts]' 114 | else: 115 | 
text = f'[Stts][spk_emb]{text}[Ptts]' 116 | input_ids, attention_mask, text_mask = self.tokenizer.encode([text], num_vq=self.num_vq) 117 | return input_ids.squeeze(0), text_mask.squeeze(0) 118 | 119 | def preprocess_audio(self, audio_path): 120 | # 如果是相对路径,转换为绝对路径 121 | if not os.path.isabs(audio_path): 122 | audio_path = os.path.join(os.path.dirname(self.meta_info), audio_path) 123 | 124 | # 确保文件存在 125 | if not os.path.exists(audio_path): 126 | raise FileNotFoundError(f"Audio file not found: {audio_path}") 127 | 128 | audio_wavs, sample_rate = torchaudio.load(audio_path) 129 | if sample_rate != self.sample_rate: 130 | audio_wavs = torchaudio.functional.resample( 131 | audio_wavs, 132 | orig_freq=sample_rate, 133 | new_freq=self.sample_rate, 134 | ) 135 | audio_wavs = audio_wavs.mean(0) 136 | audio_mask = torch.ones(len(audio_wavs)) 137 | return audio_wavs, audio_mask 138 | -------------------------------------------------------------------------------- /chattts_plus/datasets/collator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/11/23 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: collator.py 7 | import pdb 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | class BaseCollator: 14 | def __init__(self, text_pad: int = 0, audio_pad: int = 0): 15 | self.text_pad = text_pad 16 | self.audio_pad = audio_pad 17 | 18 | def __call__(self, batch): 19 | batch = [x for x in batch if x is not None] 20 | 21 | audio_maxlen = max(len(item["audio_wavs"]) for item in batch) 22 | text_maxlen = max(len(item["text_input_ids"]) for item in batch) 23 | 24 | speaker = [] 25 | text = [] 26 | text_input_ids = [] 27 | text_mask = [] 28 | audio_wavs = [] 29 | audio_mask = [] 30 | for x in batch: 31 | text.append(x["text"]) 32 | speaker.append(x["speaker"]) 33 | text_input_ids.append( 34 | F.pad( 35 | x["text_input_ids"], 36 | (0, 0, text_maxlen - len(x["text_input_ids"]), 0), 37 | value=self.text_pad, 38 | ) 39 | ) 40 | text_mask.append( 41 | F.pad( 42 | x["text_mask"], 43 | (text_maxlen - len(x["text_mask"]), 0), 44 | value=0, 45 | ) 46 | ) 47 | audio_wavs.append( 48 | F.pad( 49 | x["audio_wavs"], 50 | (0, audio_maxlen - len(x["audio_wavs"])), 51 | value=self.audio_pad, 52 | ) 53 | ) 54 | audio_mask.append( 55 | F.pad( 56 | x["audio_mask"], 57 | (0, audio_maxlen - len(x["audio_mask"])), 58 | value=0, 59 | ) 60 | ) 61 | return { 62 | "speaker": speaker, 63 | "text": text, 64 | "text_input_ids": torch.stack(text_input_ids), 65 | "text_mask": torch.stack(text_mask), 66 | "audio_wavs": torch.stack(audio_wavs), 67 | "audio_mask": torch.stack(audio_mask), 68 | } 69 | -------------------------------------------------------------------------------- /chattts_plus/models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/27 22:53 3 | # @Project : ChatTTSPlus 4 | # @FileName: __init__.py.py 5 | 6 | from .tokenizer import Tokenizer 7 | from .gpt import GPT 8 | from .dvae import DVAE 9 | -------------------------------------------------------------------------------- /chattts_plus/models/dvae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pdb 3 | from typing import List, Optional, Literal, Tuple 4 | 5 | import numpy as np 6 | import pybase16384 as b14 7 | import torch 8 | import torch.nn as nn 9 | import 
torch.nn.functional as F 10 | import torchaudio 11 | from vector_quantize_pytorch import GroupedResidualFSQ 12 | 13 | from ..commons import logger 14 | 15 | 16 | class ConvNeXtBlock(nn.Module): 17 | def __init__( 18 | self, 19 | dim: int, 20 | intermediate_dim: int, 21 | kernel: int, 22 | dilation: int, 23 | layer_scale_init_value: float = 1e-6, 24 | ): 25 | # ConvNeXt Block copied from Vocos. 26 | super().__init__() 27 | self.dwconv = nn.Conv1d( 28 | dim, 29 | dim, 30 | kernel_size=kernel, 31 | padding=dilation * (kernel // 2), 32 | dilation=dilation, 33 | groups=dim, 34 | ) # depthwise conv 35 | 36 | self.norm = nn.LayerNorm(dim, eps=1e-6) 37 | self.pwconv1 = nn.Linear( 38 | dim, intermediate_dim 39 | ) # pointwise/1x1 convs, implemented with linear layers 40 | self.act = nn.GELU() 41 | self.pwconv2 = nn.Linear(intermediate_dim, dim) 42 | self.gamma = ( 43 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) 44 | if layer_scale_init_value > 0 45 | else None 46 | ) 47 | 48 | def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: 49 | residual = x 50 | 51 | y = self.dwconv(x) 52 | y.transpose_(1, 2) # (B, C, T) -> (B, T, C) 53 | x = self.norm(y) 54 | y = self.pwconv1(x) 55 | x = self.act(y) 56 | y = self.pwconv2(x) 57 | if self.gamma is not None: 58 | y *= self.gamma 59 | y.transpose_(1, 2) # (B, T, C) -> (B, C, T) 60 | 61 | x = y + residual 62 | 63 | return x 64 | 65 | 66 | class GFSQ(nn.Module): 67 | 68 | def __init__( 69 | self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True 70 | ): 71 | super(GFSQ, self).__init__() 72 | self.quantizer = GroupedResidualFSQ( 73 | dim=dim, 74 | levels=list(levels), 75 | num_quantizers=R, 76 | groups=G, 77 | ) 78 | self.n_ind = math.prod(levels) 79 | self.eps = eps 80 | self.transpose = transpose 81 | self.G = G 82 | self.R = R 83 | 84 | def _embed(self, x: torch.Tensor): 85 | if self.transpose: 86 | x = x.transpose(1, 2) 87 | """ 88 | x = rearrange( 89 | x, "b t (g r) -> g b t r", g = self.G, r = self.R, 90 | ) 91 | """ 92 | x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3) 93 | feat = self.quantizer.get_output_from_indices(x) 94 | return feat.transpose_(1, 2) if self.transpose else feat 95 | 96 | def __call__(self, x: torch.Tensor) -> torch.Tensor: 97 | return super().__call__(x) 98 | 99 | def forward(self, x: torch.Tensor) -> torch.Tensor: 100 | if self.transpose: 101 | x.transpose_(1, 2) 102 | # feat, ind = self.quantizer(x) 103 | with torch.autocast(device_type=str(x.device), dtype=torch.float32): 104 | _, ind = self.quantizer(x) 105 | """ 106 | ind = rearrange( 107 | ind, "g b t r ->b t (g r)", 108 | ) 109 | """ 110 | ind = ind.permute(1, 2, 0, 3).contiguous() 111 | ind = ind.view(ind.size(0), ind.size(1), -1) 112 | """ 113 | embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind) 114 | embed_onehot = embed_onehot_tmp.to(x.dtype) 115 | del embed_onehot_tmp 116 | e_mean = torch.mean(embed_onehot, dim=[0, 1]) 117 | # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) 118 | torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean) 119 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1)) 120 | 121 | return 122 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device), 123 | feat.transpose_(1, 2) if self.transpose else feat, 124 | perplexity, 125 | """ 126 | return ind.transpose_(1, 2) if self.transpose else ind 127 | 128 | 129 | class DVAEDecoder(nn.Module): 130 | def __init__( 131 | self, 132 | idim: int, 133 | odim: 
int, 134 | n_layer=12, 135 | bn_dim=64, 136 | hidden=256, 137 | kernel=7, 138 | dilation=2, 139 | up=False, 140 | ): 141 | super().__init__() 142 | self.up = up 143 | self.conv_in = nn.Sequential( 144 | nn.Conv1d(idim, bn_dim, 3, 1, 1), 145 | nn.GELU(), 146 | nn.Conv1d(bn_dim, hidden, 3, 1, 1), 147 | ) 148 | self.decoder_block = nn.ModuleList( 149 | [ 150 | ConvNeXtBlock( 151 | hidden, 152 | hidden * 4, 153 | kernel, 154 | dilation, 155 | ) 156 | for _ in range(n_layer) 157 | ] 158 | ) 159 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) 160 | 161 | def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor: 162 | # B, C, T 163 | y = self.conv_in(x) 164 | for f in self.decoder_block: 165 | y = f(y, conditioning) 166 | 167 | x = self.conv_out(y) 168 | return x 169 | 170 | 171 | class MelSpectrogramFeatures(torch.nn.Module): 172 | def __init__( 173 | self, 174 | sample_rate=24000, 175 | n_fft=1024, 176 | hop_length=256, 177 | n_mels=100, 178 | padding: Literal["center", "same"] = "center", 179 | ): 180 | super().__init__() 181 | if padding not in ["center", "same"]: 182 | raise ValueError("Padding must be 'center' or 'same'.") 183 | self.padding = padding 184 | self.mel_spec = torchaudio.transforms.MelSpectrogram( 185 | sample_rate=sample_rate, 186 | n_fft=n_fft, 187 | hop_length=hop_length, 188 | n_mels=n_mels, 189 | center=padding == "center", 190 | power=1, 191 | ) 192 | 193 | def __call__(self, audio: torch.Tensor) -> torch.Tensor: 194 | return super().__call__(audio) 195 | 196 | def forward(self, audio: torch.Tensor) -> torch.Tensor: 197 | mel: torch.Tensor = self.mel_spec(audio) 198 | features = torch.log(torch.clip(mel, min=1e-5)) 199 | return features 200 | 201 | 202 | class DVAE(nn.Module): 203 | def __init__( 204 | self, 205 | decoder_config: dict, 206 | encoder_config: Optional[dict] = None, 207 | vq_config: Optional[dict] = None, 208 | dim=512, 209 | coef: Optional[str] = None, 210 | **kwargs 211 | ): 212 | super().__init__() 213 | self.logger = logger.get_logger(self.__class__.__name__) 214 | 215 | if coef is None: 216 | coef = torch.rand(100) 217 | else: 218 | coef = torch.from_numpy( 219 | np.copy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32)) 220 | ) 221 | 222 | self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2)) 223 | 224 | if encoder_config is not None: 225 | self.downsample_conv = nn.Sequential( 226 | nn.Conv1d(100, dim, 3, 1, 1), 227 | nn.GELU(), 228 | nn.Conv1d(dim, dim, 4, 2, 1), 229 | nn.GELU(), 230 | ) 231 | self.preprocessor_mel = MelSpectrogramFeatures() 232 | self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config) 233 | 234 | self.decoder = DVAEDecoder(**decoder_config) 235 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False) 236 | if vq_config is not None: 237 | self.vq_layer = GFSQ(**vq_config) 238 | else: 239 | self.vq_layer = None 240 | 241 | self.model_path = kwargs.get("model_path", None) 242 | if self.model_path: 243 | self.logger.info(f"loading DVAE pretrained model: {self.model_path}") 244 | self.from_pretrained(self.model_path) 245 | 246 | def from_pretrained(self, file_path: str): 247 | self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True)) 248 | 249 | def __repr__(self) -> str: 250 | return b14.encode_to_string( 251 | self.coef.cpu().numpy().astype(np.float32).tobytes() 252 | ) 253 | 254 | def __call__( 255 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode" 256 | ) -> torch.Tensor: 257 | return super().__call__(inp, mode) 258 | 259 | 
@torch.inference_mode() 260 | def forward( 261 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode" 262 | ) -> torch.Tensor: 263 | if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None: 264 | mel = self.preprocessor_mel(inp) 265 | x: torch.Tensor = self.downsample_conv( 266 | torch.div(mel, self.coef.view(1, 100, 1).expand(mel.shape), out=mel), 267 | ) 268 | x = self.encoder(x) 269 | ind = self.vq_layer(x) 270 | return ind 271 | 272 | if self.vq_layer is not None: 273 | vq_feats = self.vq_layer._embed(inp) 274 | else: 275 | vq_feats = inp 276 | 277 | vq_feats = ( 278 | vq_feats.view( 279 | (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)), 280 | ) 281 | .permute(0, 2, 3, 1) 282 | .flatten(2) 283 | ) 284 | 285 | dec_out = self.out_conv( 286 | self.decoder( 287 | x=vq_feats, 288 | ), 289 | ) 290 | 291 | return torch.mul(dec_out, self.coef, out=dec_out) 292 | -------------------------------------------------------------------------------- /chattts_plus/models/gpt.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import platform 3 | from dataclasses import dataclass 4 | import logging 5 | from typing import Union, List, Optional, Tuple 6 | import gc 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.utils.parametrize as P 12 | from torch.nn.utils.parametrizations import weight_norm 13 | from tqdm import tqdm 14 | from transformers import LlamaConfig, LogitsWarper 15 | from transformers.cache_utils import Cache 16 | from transformers.modeling_outputs import BaseModelOutputWithPast 17 | from transformers.utils import is_flash_attn_2_available 18 | 19 | from .llama import LlamaModel 20 | from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat 21 | from ..commons import logger 22 | 23 | 24 | class GPT(nn.Module): 25 | def __init__( 26 | self, 27 | gpt_config: dict, 28 | num_audio_tokens: int = 626, 29 | num_text_tokens: int = 21178, 30 | num_vq=4, 31 | use_flash_attn=False, 32 | **kwargs 33 | ): 34 | super().__init__() 35 | self.logger = logger.get_logger(self.__class__.__name__) 36 | self.num_vq = num_vq 37 | self.num_audio_tokens = num_audio_tokens 38 | 39 | self.use_flash_attn = use_flash_attn 40 | 41 | self.gpt, self.llama_config = self._build_llama(gpt_config) 42 | self.is_te_llama = False 43 | self.model_dim = int(self.gpt.config.hidden_size) 44 | self.emb_code = nn.ModuleList( 45 | [ 46 | nn.Embedding( 47 | num_audio_tokens, 48 | self.model_dim 49 | ) 50 | for _ in range(num_vq) 51 | ], 52 | ) 53 | self.emb_text = nn.Embedding( 54 | num_text_tokens, self.model_dim 55 | ) 56 | 57 | self.head_text = weight_norm( 58 | nn.Linear( 59 | self.model_dim, 60 | num_text_tokens, 61 | bias=False, 62 | ), 63 | name="weight", 64 | ) 65 | self.head_code = nn.ModuleList( 66 | [ 67 | weight_norm( 68 | nn.Linear( 69 | self.model_dim, 70 | num_audio_tokens, 71 | bias=False, 72 | ), 73 | name="weight", 74 | ) 75 | for _ in range(self.num_vq) 76 | ], 77 | ) 78 | 79 | self.model_path = kwargs.get("model_path", None) 80 | if self.model_path: 81 | self.logger.info(f"loading GPT pretrained model: {self.model_path}") 82 | self.from_pretrained(self.model_path) 83 | 84 | def from_pretrained(self, file_path: str): 85 | self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True)) 86 | 87 | class Context: 88 | def __init__(self): 89 | self._interrupt = False 90 | 91 | def set(self, v: bool): 92 | self._interrupt = v 93 | 94 | def 
get(self) -> bool: 95 | return self._interrupt 96 | 97 | def _build_llama( 98 | self, 99 | config: dict, 100 | ) -> Tuple[LlamaModel, LlamaConfig]: 101 | 102 | if self.use_flash_attn and is_flash_attn_2_available(): 103 | llama_config = LlamaConfig( 104 | **config, 105 | attn_implementation="flash_attention_2", 106 | ) 107 | self.logger.info( 108 | "enabling flash_attention_2 may make gpt be even slower" 109 | ) 110 | else: 111 | llama_config = LlamaConfig(**config) 112 | 113 | model = LlamaModel(llama_config) 114 | del model.embed_tokens 115 | return model, llama_config 116 | 117 | def __call__( 118 | self, input_ids: torch.Tensor, text_mask: torch.Tensor 119 | ) -> torch.Tensor: 120 | """ 121 | get_emb 122 | """ 123 | return super().__call__(input_ids, text_mask) 124 | 125 | def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor: 126 | """ 127 | get_emb 128 | """ 129 | 130 | emb_text: torch.Tensor = self.emb_text( 131 | input_ids[text_mask].narrow(1, 0, 1).squeeze_(1) 132 | ) 133 | 134 | text_mask_inv = text_mask.logical_not() 135 | masked_input_ids: torch.Tensor = input_ids[text_mask_inv] 136 | 137 | emb_code = [ 138 | self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq) 139 | ] 140 | emb_code = torch.stack(emb_code, 2).sum(2) 141 | 142 | emb = torch.zeros( 143 | (input_ids.shape[:-1]) + (emb_text.shape[-1],), 144 | device=emb_text.device, 145 | dtype=emb_text.dtype, 146 | ) 147 | emb[text_mask] = emb_text 148 | emb[text_mask_inv] = emb_code.to(emb.dtype) 149 | return emb 150 | 151 | @dataclass(repr=False, eq=False) 152 | class _GenerationInputs: 153 | position_ids: torch.Tensor 154 | cache_position: torch.Tensor 155 | use_cache: bool 156 | input_ids: Optional[torch.Tensor] = None 157 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 158 | attention_mask: Optional[torch.Tensor] = None 159 | inputs_embeds: Optional[torch.Tensor] = None 160 | 161 | def to(self, device: torch.device, dtype: torch.dtype): 162 | if self.attention_mask is not None: 163 | self.attention_mask = self.attention_mask.to(device, dtype=dtype) 164 | if self.position_ids is not None: 165 | self.position_ids = self.position_ids.to(device, dtype=dtype) 166 | if self.inputs_embeds is not None: 167 | self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype) 168 | if self.cache_position is not None: 169 | self.cache_position = self.cache_position.to(device, dtype=dtype) 170 | 171 | def _prepare_generation_inputs( 172 | self, 173 | input_ids: torch.Tensor, 174 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 175 | attention_mask: Optional[torch.Tensor] = None, 176 | inputs_embeds: Optional[torch.Tensor] = None, 177 | cache_position: Optional[torch.Tensor] = None, 178 | position_ids: Optional[torch.Tensor] = None, 179 | use_cache=True, 180 | ) -> _GenerationInputs: 181 | # With static cache, the `past_key_values` is None 182 | # TODO joao: standardize interface for the different Cache classes and remove of this if 183 | has_static_cache = False 184 | if past_key_values is None: 185 | if hasattr(self.gpt.layers[0], "self_attn"): 186 | past_key_values = getattr( 187 | self.gpt.layers[0].self_attn, "past_key_value", None 188 | ) 189 | has_static_cache = past_key_values is not None 190 | 191 | past_length = 0 192 | if past_key_values is not None: 193 | if isinstance(past_key_values, Cache): 194 | past_length = ( 195 | int(cache_position[0]) 196 | if cache_position is not None 197 | else past_key_values.get_seq_length() 198 | ) 199 | 
max_cache_length = past_key_values.get_max_length() 200 | cache_length = ( 201 | past_length 202 | if max_cache_length is None 203 | else min(max_cache_length, past_length) 204 | ) 205 | # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects 206 | else: 207 | cache_length = past_length = past_key_values[0][0].shape[2] 208 | max_cache_length = None 209 | 210 | # Keep only the unprocessed tokens: 211 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 212 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 213 | # input) 214 | if ( 215 | attention_mask is not None 216 | and attention_mask.shape[1] > input_ids.shape[1] 217 | ): 218 | start = attention_mask.shape[1] - past_length 219 | input_ids = input_ids.narrow(1, -start, start) 220 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 221 | # input_ids based on the past_length. 222 | elif past_length < input_ids.shape[1]: 223 | input_ids = input_ids.narrow( 224 | 1, past_length, input_ids.size(1) - past_length 225 | ) 226 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 227 | 228 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 229 | if ( 230 | max_cache_length is not None 231 | and attention_mask is not None 232 | and cache_length + input_ids.shape[1] > max_cache_length 233 | ): 234 | attention_mask = attention_mask.narrow( 235 | 1, -max_cache_length, max_cache_length 236 | ) 237 | 238 | if attention_mask is not None and position_ids is None: 239 | # create position_ids on the fly for batch generation 240 | position_ids = attention_mask.long().cumsum(-1) - 1 241 | position_ids.masked_fill_(attention_mask.eq(0), 1) 242 | if past_key_values: 243 | position_ids = position_ids.narrow( 244 | 1, -input_ids.shape[1], input_ids.shape[1] 245 | ) 246 | 247 | input_length = ( 248 | position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] 249 | ) 250 | if cache_position is None: 251 | cache_position = torch.arange( 252 | past_length, past_length + input_length, device=input_ids.device 253 | ) 254 | else: 255 | cache_position = cache_position.narrow(0, -input_length, input_length) 256 | 257 | if has_static_cache: 258 | past_key_values = None 259 | 260 | model_inputs = self._GenerationInputs( 261 | position_ids=position_ids, 262 | cache_position=cache_position, 263 | use_cache=use_cache, 264 | ) 265 | 266 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 267 | if inputs_embeds is not None and past_key_values is None: 268 | model_inputs.inputs_embeds = inputs_embeds 269 | else: 270 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise 271 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 272 | # TODO: use `next_tokens` directly instead. 
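            # After the first step, generate() rebuilds inputs_embeds itself from these ids
            # (emb_text when decoding text, summed emb_code when decoding audio codes), so only
            # the newly sampled token ids need to be carried through this branch.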
273 | model_inputs.input_ids = input_ids.contiguous() 274 | 275 | model_inputs.past_key_values = past_key_values 276 | model_inputs.attention_mask = attention_mask 277 | 278 | return model_inputs 279 | 280 | @dataclass(repr=False, eq=False) 281 | class GenerationOutputs: 282 | ids: List[torch.Tensor] 283 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] 284 | hiddens: List[torch.Tensor] 285 | 286 | def _prepare_generation_outputs( 287 | self, 288 | inputs_ids: torch.Tensor, 289 | start_idx: int, 290 | end_idx: torch.Tensor, 291 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]], 292 | hiddens: List[torch.Tensor], 293 | infer_text: bool, 294 | ) -> GenerationOutputs: 295 | inputs_ids = [ 296 | inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx) 297 | ] 298 | if infer_text: 299 | inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids] 300 | 301 | if len(hiddens) > 0: 302 | hiddens = torch.stack(hiddens, 1) 303 | hiddens = [ 304 | hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int()) 305 | ] 306 | 307 | return self.GenerationOutputs( 308 | ids=inputs_ids, 309 | attentions=attentions, 310 | hiddens=hiddens, 311 | ) 312 | 313 | @torch.no_grad() 314 | def generate( 315 | self, 316 | emb: torch.Tensor, 317 | inputs_ids: torch.Tensor, 318 | temperature: torch.Tensor, 319 | eos_token: Union[int, torch.Tensor], 320 | attention_mask: Optional[torch.Tensor] = None, 321 | max_new_token=2048, 322 | min_new_token=0, 323 | logits_warpers: List[LogitsWarper] = [], 324 | logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [], 325 | infer_text=False, 326 | return_attn=False, 327 | return_hidden=False, 328 | stream=False, 329 | show_tqdm=True, 330 | ensure_non_empty=True, 331 | stream_batch=24, 332 | context=Context(), 333 | ): 334 | 335 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = [] 336 | hiddens = [] 337 | stream_iter = 0 338 | 339 | start_idx, end_idx = inputs_ids.shape[1], torch.zeros( 340 | inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long 341 | ) 342 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() 343 | 344 | old_temperature = temperature 345 | 346 | temperature = ( 347 | temperature.unsqueeze(0) 348 | .expand(inputs_ids.shape[0], -1) 349 | .contiguous() 350 | .view(-1, 1) 351 | ) 352 | 353 | attention_mask_cache = torch.ones( 354 | ( 355 | inputs_ids.shape[0], 356 | inputs_ids.shape[1] + max_new_token, 357 | ), 358 | dtype=torch.bool, 359 | device=inputs_ids.device, 360 | ) 361 | if attention_mask is not None: 362 | attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_( 363 | attention_mask 364 | ) 365 | 366 | progress = inputs_ids.size(1) 367 | # pre-allocate inputs_ids 368 | inputs_ids_buf = torch.zeros( 369 | inputs_ids.size(0), 370 | progress + max_new_token, 371 | inputs_ids.size(2), 372 | dtype=inputs_ids.dtype, 373 | device=inputs_ids.device, 374 | ) 375 | inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids) 376 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress) 377 | 378 | pbar: Optional[tqdm] = None 379 | 380 | if show_tqdm: 381 | pbar = tqdm( 382 | total=max_new_token, 383 | desc="text" if infer_text else "code", 384 | bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", 385 | ) 386 | 387 | past_key_values = None 388 | 389 | for i in range(max_new_token): 390 | 391 | model_input = self._prepare_generation_inputs( 392 | inputs_ids, 393 | past_key_values, 394 | attention_mask_cache.narrow(1, 0, 
inputs_ids.shape[1]), 395 | use_cache=not self.is_te_llama, 396 | ) 397 | 398 | if i > 0: 399 | inputs_ids_emb = model_input.input_ids.to(emb.device) 400 | if infer_text: 401 | emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0]) 402 | else: 403 | code_emb = [ 404 | self.emb_code[i](inputs_ids_emb[:, :, i]) 405 | for i in range(self.num_vq) 406 | ] 407 | emb = torch.stack(code_emb, 3).sum(3) 408 | model_input.inputs_embeds = emb 409 | model_input.to(emb.device, emb.dtype) 410 | outputs: BaseModelOutputWithPast = self.gpt( 411 | attention_mask=model_input.attention_mask, 412 | position_ids=model_input.position_ids, 413 | past_key_values=model_input.past_key_values, 414 | inputs_embeds=model_input.inputs_embeds, 415 | use_cache=model_input.use_cache, 416 | output_attentions=return_attn, 417 | cache_position=model_input.cache_position, 418 | ) 419 | attentions.append(outputs.attentions) 420 | hidden_states = outputs.last_hidden_state 421 | past_key_values = outputs.past_key_values 422 | if return_hidden: 423 | hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1)) 424 | with P.cached(): 425 | if infer_text: 426 | logits: torch.Tensor = self.head_text(hidden_states) 427 | else: 428 | # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) 429 | logits = torch.empty( 430 | hidden_states.size(0), 431 | hidden_states.size(1), 432 | self.num_audio_tokens, 433 | self.num_vq, 434 | dtype=emb.dtype, 435 | device=emb.device, 436 | ) 437 | for num_vq_iter in range(self.num_vq): 438 | x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) 439 | logits[..., num_vq_iter] = x 440 | 441 | # logits = logits[:, -1].float() 442 | logits = logits.narrow(1, -1, 1).squeeze_(1) 443 | 444 | if not infer_text: 445 | # logits = rearrange(logits, "b c n -> (b n) c") 446 | logits = logits.permute(0, 2, 1) 447 | logits = logits.reshape(-1, logits.size(2)) 448 | # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") 449 | inputs_ids_sliced = inputs_ids.narrow( 450 | 1, 451 | start_idx, 452 | inputs_ids.size(1) - start_idx, 453 | ).permute(0, 2, 1) 454 | logits_token = inputs_ids_sliced.reshape( 455 | inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), 456 | -1, 457 | ).to(emb.device) 458 | else: 459 | logits_token = ( 460 | inputs_ids.narrow( 461 | 1, 462 | start_idx, 463 | inputs_ids.size(1) - start_idx, 464 | ) 465 | .narrow(2, 0, 1) 466 | .to(emb.device) 467 | ) 468 | 469 | logits /= temperature 470 | 471 | for logitsProcessors in logits_processors: 472 | logits = logitsProcessors(logits_token, logits) 473 | 474 | for logitsWarpers in logits_warpers: 475 | logits = logitsWarpers(logits_token, logits) 476 | 477 | if i < min_new_token: 478 | logits[:, eos_token] = -torch.inf 479 | 480 | scores = F.softmax(logits, dim=-1) 481 | idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) 482 | 483 | if not infer_text: 484 | # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) 485 | idx_next = idx_next.view(-1, self.num_vq) 486 | finish_or = idx_next.eq(eos_token).any(1) 487 | finish.logical_or_(finish_or) 488 | inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1)) 489 | else: 490 | finish_or = idx_next.eq(eos_token).any(1) 491 | finish.logical_or_(finish_or) 492 | inputs_ids_buf.narrow(1, progress, 1).copy_( 493 | idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq), 494 | ) 495 | 496 | if i == 0 and finish.any(): 497 | self.logger.info( 498 | "unexpected end at index %s" % str([unexpected_idx.item() for 
unexpected_idx in finish.nonzero()]), 499 | ) 500 | if ensure_non_empty: 501 | if show_tqdm: 502 | pbar.close() 503 | self.logger.info("regenerate in order to ensure non-empty") 504 | new_gen = self.generate( 505 | emb, 506 | inputs_ids, 507 | old_temperature, 508 | eos_token, 509 | attention_mask, 510 | max_new_token, 511 | min_new_token, 512 | logits_warpers, 513 | logits_processors, 514 | infer_text, 515 | return_attn, 516 | return_hidden, 517 | stream, 518 | show_tqdm, 519 | ensure_non_empty, 520 | stream_batch, 521 | context, 522 | ) 523 | for result in new_gen: 524 | yield result 525 | return 526 | 527 | progress += 1 528 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress) 529 | 530 | not_finished = finish.logical_not().to(end_idx.device) 531 | end_idx.add_(not_finished.int()) 532 | stream_iter += not_finished.any().int() 533 | if stream: 534 | if stream_iter > 0 and stream_iter % stream_batch == 0: 535 | self.logger.info("yield stream result, end: %d", end_idx) 536 | yield self._prepare_generation_outputs( 537 | inputs_ids, 538 | start_idx, 539 | end_idx, 540 | attentions, 541 | hiddens, 542 | infer_text, 543 | ) 544 | 545 | if finish.all() or context.get(): 546 | break 547 | 548 | if pbar is not None: 549 | pbar.update(1) 550 | 551 | if pbar is not None: 552 | pbar.close() 553 | 554 | if not finish.all(): 555 | if context.get(): 556 | self.logger.info("generation is interrupted") 557 | else: 558 | self.logger.info( 559 | f"incomplete result. hit max_new_token: {max_new_token}" 560 | ) 561 | 562 | yield self._prepare_generation_outputs( 563 | inputs_ids, 564 | start_idx, 565 | end_idx, 566 | attentions, 567 | hiddens, 568 | infer_text, 569 | ) 570 | -------------------------------------------------------------------------------- /chattts_plus/models/processors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper 4 | 5 | 6 | class CustomRepetitionPenaltyLogitsProcessorRepeat: 7 | 8 | def __init__(self, penalty: float, max_input_ids: int, past_window: int): 9 | if not isinstance(penalty, float) or not (penalty > 0): 10 | raise ValueError( 11 | f"`penalty` has to be a strictly positive float, but is {penalty}" 12 | ) 13 | 14 | self.penalty = penalty 15 | self.max_input_ids = max_input_ids 16 | self.past_window = past_window 17 | 18 | def __call__( 19 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 20 | ) -> torch.FloatTensor: 21 | if input_ids.size(1) > self.past_window: 22 | input_ids = input_ids.narrow(1, -self.past_window, self.past_window) 23 | freq = F.one_hot(input_ids, scores.size(1)).sum(1) 24 | if freq.size(0) > self.max_input_ids: 25 | freq.narrow( 26 | 0, self.max_input_ids, freq.size(0) - self.max_input_ids 27 | ).zero_() 28 | alpha = torch.pow(self.penalty, freq) 29 | scores = scores.contiguous() 30 | inp = scores.multiply(alpha) 31 | oth = scores.divide(alpha) 32 | con = scores < 0 33 | out = torch.where(con, inp, oth) 34 | return out 35 | 36 | 37 | def gen_logits( 38 | num_code: int, 39 | top_P=0.7, 40 | top_K=20, 41 | repetition_penalty=1.0, 42 | ): 43 | logits_warpers = [] 44 | if top_P is not None: 45 | logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3)) 46 | if top_K is not None: 47 | logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3)) 48 | 49 | logits_processors = [] 50 | if repetition_penalty is not None and repetition_penalty != 1: 51 | 
logits_processors.append( 52 | CustomRepetitionPenaltyLogitsProcessorRepeat( 53 | repetition_penalty, num_code, 16 54 | ) 55 | ) 56 | 57 | return logits_warpers, logits_processors 58 | -------------------------------------------------------------------------------- /chattts_plus/models/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 4 | """ 5 | https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning 6 | """ 7 | 8 | from typing import List, Tuple, Optional 9 | import lzma 10 | 11 | import numpy as np 12 | import pybase16384 as b14 13 | import torch 14 | import torch.nn.functional as F 15 | from transformers import BertTokenizerFast 16 | 17 | from ..commons import logger 18 | 19 | 20 | class Tokenizer: 21 | def __init__( 22 | self, model_path, **kwargs 23 | ): 24 | self.logger = logger.get_logger(self.__class__.__name__) 25 | self.logger.info(f"loading Tokenizer pretrained model: {model_path}") 26 | 27 | tokenizer: BertTokenizerFast = torch.load( 28 | model_path, map_location="cpu", mmap=True 29 | ) 30 | self._tokenizer = tokenizer 31 | 32 | # 设置特殊token 33 | self._tokenizer.eos_token = '[SEP]' # 使用BERT的默认结束符 34 | self._tokenizer.pad_token = '[PAD]' # 使用BERT的默认填充符 35 | 36 | # 获取或设置对应的token id 37 | if not hasattr(self._tokenizer, 'pad_token_id'): 38 | self._tokenizer.pad_token_id = self._tokenizer.convert_tokens_to_ids('[PAD]') 39 | if not hasattr(self._tokenizer, 'eos_token_id'): 40 | self._tokenizer.eos_token_id = self._tokenizer.convert_tokens_to_ids('[SEP]') 41 | 42 | self.len = len(tokenizer) 43 | self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]") 44 | self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]") 45 | self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]") 46 | 47 | self.decode = self._tokenizer.batch_decode 48 | 49 | @torch.inference_mode() 50 | def encode( 51 | self, 52 | text: List[str], 53 | num_vq: int, 54 | prompt_str: Optional[str] = None, 55 | device="cpu", 56 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 57 | 58 | input_ids_lst = [] 59 | attention_mask_lst = [] 60 | max_input_ids_len = -1 61 | max_attention_mask_len = -1 62 | prompt_size = 0 63 | 64 | prompt = self._decode_prompt(prompt_str) if prompt_str is not None else None 65 | 66 | if prompt is not None: 67 | assert prompt.size(0) == num_vq, "prompt dim 0 must equal to num_vq" 68 | prompt_size = prompt.size(1) 69 | 70 | # avoid random speaker embedding of tokenizer in the other dims 71 | for t in text: 72 | x = self._tokenizer.encode_plus( 73 | t, return_tensors="pt", add_special_tokens=False, padding=True 74 | ) 75 | input_ids_lst.append(x["input_ids"].squeeze_(0)) 76 | attention_mask_lst.append(x["attention_mask"].squeeze_(0)) 77 | ids_sz = input_ids_lst[-1].size(0) 78 | if ids_sz > max_input_ids_len: 79 | max_input_ids_len = ids_sz 80 | attn_sz = attention_mask_lst[-1].size(0) 81 | if attn_sz > max_attention_mask_len: 82 | max_attention_mask_len = attn_sz 83 | 84 | if prompt is not None: 85 | max_input_ids_len += prompt_size 86 | max_attention_mask_len += prompt_size 87 | 88 | input_ids = torch.zeros( 89 | len(input_ids_lst), 90 | max_input_ids_len, 91 | device=device, 92 | dtype=input_ids_lst[0].dtype, 93 | ) 94 | for i in range(len(input_ids_lst)): 95 | input_ids.narrow(0, i, 1).narrow( 96 | 1, 97 | max_input_ids_len - prompt_size - input_ids_lst[i].size(0), 98 | input_ids_lst[i].size(0), 99 | ).copy_( 100 | 
input_ids_lst[i] 101 | ) # left padding 102 | 103 | attention_mask = torch.zeros( 104 | len(attention_mask_lst), 105 | max_attention_mask_len, 106 | device=device, 107 | dtype=attention_mask_lst[0].dtype, 108 | ) 109 | for i in range(len(attention_mask_lst)): 110 | attn = attention_mask.narrow(0, i, 1) 111 | attn.narrow( 112 | 1, 113 | max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0), 114 | attention_mask_lst[i].size(0), 115 | ).copy_( 116 | attention_mask_lst[i] 117 | ) # left padding 118 | if prompt_size > 0: 119 | attn.narrow( 120 | 1, 121 | max_attention_mask_len - prompt_size, 122 | prompt_size, 123 | ).fill_(1) 124 | 125 | text_mask = attention_mask.bool() 126 | new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone() 127 | 128 | if prompt_size > 0: 129 | text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0) 130 | prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1) 131 | new_input_ids.narrow( 132 | 1, 133 | max_input_ids_len - prompt_size, 134 | prompt_size, 135 | ).copy_(prompt_t) 136 | 137 | return new_input_ids, attention_mask, text_mask 138 | 139 | @staticmethod 140 | def _decode_spk_emb(spk_emb: str) -> np.ndarray: 141 | return np.frombuffer( 142 | lzma.decompress( 143 | b14.decode_from_string(spk_emb), 144 | format=lzma.FORMAT_RAW, 145 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], 146 | ), 147 | dtype=np.float16, 148 | ).copy() 149 | 150 | @torch.no_grad() 151 | def apply_spk_emb( 152 | self, 153 | emb: torch.Tensor, 154 | spk_emb, 155 | input_ids: torch.Tensor, 156 | device: torch.device, 157 | ): 158 | if isinstance(spk_emb, str): 159 | spk_emb_tensor = torch.from_numpy(self._decode_spk_emb(spk_emb)) 160 | else: 161 | spk_emb_tensor = spk_emb 162 | 163 | n = ( 164 | F.normalize( 165 | spk_emb_tensor, 166 | p=2.0, 167 | dim=0, 168 | eps=1e-12, 169 | ) 170 | .to(emb.device, dtype=emb.dtype) 171 | .unsqueeze_(0) 172 | .expand(emb.size(0), -1) 173 | .unsqueeze_(1) 174 | .expand(emb.shape) 175 | ) 176 | cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape) 177 | torch.where(cond, n, emb, out=emb) 178 | return emb 179 | 180 | @staticmethod 181 | @torch.no_grad() 182 | def _decode_prompt(prompt: str) -> torch.Tensor: 183 | dec = b14.decode_from_string(prompt) 184 | shp = np.frombuffer(dec[:4], dtype=" str: 198 | arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy() 199 | shp = arr.shape 200 | assert len(shp) == 2, "prompt must be a 2D tensor" 201 | s = b14.encode_to_string( 202 | np.array(shp, dtype=" str: 214 | arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy() 215 | s = b14.encode_to_string( 216 | lzma.compress( 217 | arr.tobytes(), 218 | format=lzma.FORMAT_RAW, 219 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}], 220 | ), 221 | ) 222 | return s 223 | -------------------------------------------------------------------------------- /chattts_plus/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/27 22:53 3 | # @Project : ChatTTSPlus 4 | # @FileName: __init__.py.py 5 | -------------------------------------------------------------------------------- /chattts_plus/trt_models/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/10/6 19:59 3 | # @Author : wenshao 4 | # @ProjectName: ChatTTSPlus 5 | # @FileName: 
__init__.py.py 6 | 7 | from .gpt_trt import GPT -------------------------------------------------------------------------------- /chattts_plus/trt_models/base_model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from .predictor import get_predictor 4 | 5 | 6 | class BaseModel: 7 | """ 8 | 模型预测的基类 9 | """ 10 | 11 | def __init__(self, **kwargs): 12 | self.kwargs = copy.deepcopy(kwargs) 13 | self.predictor = get_predictor(**self.kwargs) 14 | self.device = torch.cuda.current_device() 15 | 16 | if self.predictor is not None: 17 | self.input_shapes = self.predictor.input_spec() 18 | self.output_shapes = self.predictor.output_spec() 19 | 20 | def input_process(self, *data): 21 | """ 22 | 输入预处理 23 | :return: 24 | """ 25 | pass 26 | 27 | def output_process(self, *data): 28 | """ 29 | 输出后处理 30 | :return: 31 | """ 32 | pass 33 | 34 | def predict(self, *data): 35 | """ 36 | 预测 37 | :return: 38 | """ 39 | pass 40 | 41 | def __del__(self): 42 | """ 43 | 删除实例 44 | :return: 45 | """ 46 | if self.predictor is not None: 47 | del self.predictor 48 | -------------------------------------------------------------------------------- /chattts_plus/trt_models/gpt_trt.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import platform 3 | from dataclasses import dataclass 4 | import logging 5 | from typing import Union, List, Optional, Tuple 6 | import gc 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.utils.parametrize as P 12 | from torch.nn.utils.parametrizations import weight_norm 13 | from tqdm import tqdm 14 | from transformers import LlamaConfig, LogitsWarper 15 | from transformers.cache_utils import Cache 16 | from transformers.modeling_outputs import BaseModelOutputWithPast 17 | from transformers.utils import is_flash_attn_2_available 18 | 19 | from .llama_trt_model import LlamaTRTModel 20 | from ..models.processors import CustomRepetitionPenaltyLogitsProcessorRepeat 21 | from ..commons import logger 22 | 23 | 24 | class GPT(nn.Module): 25 | def __init__( 26 | self, 27 | gpt_config: dict, 28 | num_audio_tokens: int = 626, 29 | num_text_tokens: int = 21178, 30 | num_vq=4, 31 | use_flash_attn=False, 32 | **kwargs 33 | ): 34 | super().__init__() 35 | self.logger = logger.get_logger(self.__class__.__name__) 36 | self.num_vq = num_vq 37 | self.num_audio_tokens = num_audio_tokens 38 | 39 | self.use_flash_attn = use_flash_attn 40 | self.gpt_config = gpt_config 41 | self.gpt, self.llama_config = self._build_llama(gpt_config, **kwargs) 42 | self.is_te_llama = False 43 | self.model_dim = int(gpt_config["hidden_size"]) 44 | self.emb_code = nn.ModuleList( 45 | [ 46 | nn.Embedding( 47 | num_audio_tokens, 48 | self.model_dim 49 | ) 50 | for _ in range(num_vq) 51 | ], 52 | ) 53 | self.emb_text = nn.Embedding( 54 | num_text_tokens, self.model_dim 55 | ) 56 | 57 | self.head_text = weight_norm( 58 | nn.Linear( 59 | self.model_dim, 60 | num_text_tokens, 61 | bias=False, 62 | ), 63 | name="weight", 64 | ) 65 | self.head_code = nn.ModuleList( 66 | [ 67 | weight_norm( 68 | nn.Linear( 69 | self.model_dim, 70 | num_audio_tokens, 71 | bias=False, 72 | ), 73 | name="weight", 74 | ) 75 | for _ in range(self.num_vq) 76 | ], 77 | ) 78 | self.model_path = kwargs.get("model_path", None) 79 | if self.model_path: 80 | self.logger.info(f"loading GPT pretrained model: {self.model_path}") 81 | self.from_pretrained(self.model_path) 82 | 83 | def 
from_pretrained(self, file_path: str): 84 | self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True), strict=False) 85 | 86 | class Context: 87 | def __init__(self): 88 | self._interrupt = False 89 | 90 | def set(self, v: bool): 91 | self._interrupt = v 92 | 93 | def get(self) -> bool: 94 | return self._interrupt 95 | 96 | def _build_llama( 97 | self, 98 | config: dict, 99 | **kwargs 100 | ): 101 | 102 | if self.use_flash_attn and is_flash_attn_2_available(): 103 | llama_config = LlamaConfig( 104 | **config, 105 | attn_implementation="flash_attention_2", 106 | ) 107 | self.logger.info( 108 | "enabling flash_attention_2 may make gpt be even slower" 109 | ) 110 | else: 111 | llama_config = LlamaConfig(**config) 112 | 113 | model = LlamaTRTModel(predict_type="trt", model_path=kwargs.get("trt_model_path"), 114 | output_max_shapes=kwargs.get("output_max_shapes")) 115 | 116 | return model, llama_config 117 | 118 | def __call__( 119 | self, input_ids: torch.Tensor, text_mask: torch.Tensor 120 | ) -> torch.Tensor: 121 | """ 122 | get_emb 123 | """ 124 | return super().__call__(input_ids, text_mask) 125 | 126 | def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor: 127 | """ 128 | get_emb 129 | """ 130 | 131 | emb_text: torch.Tensor = self.emb_text( 132 | input_ids[text_mask].narrow(1, 0, 1).squeeze_(1) 133 | ) 134 | 135 | text_mask_inv = text_mask.logical_not() 136 | masked_input_ids: torch.Tensor = input_ids[text_mask_inv] 137 | 138 | emb_code = [ 139 | self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq) 140 | ] 141 | emb_code = torch.stack(emb_code, 2).sum(2) 142 | 143 | emb = torch.zeros( 144 | (input_ids.shape[:-1]) + (emb_text.shape[-1],), 145 | device=emb_text.device, 146 | dtype=emb_text.dtype, 147 | ) 148 | emb[text_mask] = emb_text 149 | emb[text_mask_inv] = emb_code.to(emb.dtype) 150 | return emb 151 | 152 | @dataclass(repr=False, eq=False) 153 | class _GenerationInputs: 154 | position_ids: torch.Tensor 155 | cache_position: torch.Tensor 156 | use_cache: bool 157 | input_ids: Optional[torch.Tensor] = None 158 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 159 | attention_mask: Optional[torch.Tensor] = None 160 | inputs_embeds: Optional[torch.Tensor] = None 161 | 162 | def to(self, device: torch.device, dtype: torch.dtype): 163 | if self.attention_mask is not None: 164 | self.attention_mask = self.attention_mask.to(device, dtype=dtype) 165 | if self.position_ids is not None: 166 | self.position_ids = self.position_ids.to(device, dtype=dtype) 167 | if self.inputs_embeds is not None: 168 | self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype) 169 | if self.cache_position is not None: 170 | self.cache_position = self.cache_position.to(device, dtype=dtype) 171 | 172 | def _prepare_generation_inputs( 173 | self, 174 | input_ids: torch.Tensor, 175 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, 176 | attention_mask: Optional[torch.Tensor] = None, 177 | inputs_embeds: Optional[torch.Tensor] = None, 178 | cache_position: Optional[torch.Tensor] = None, 179 | position_ids: Optional[torch.Tensor] = None, 180 | use_cache=True, 181 | ) -> _GenerationInputs: 182 | # With static cache, the `past_key_values` is None 183 | # TODO joao: standardize interface for the different Cache classes and remove of this if 184 | has_static_cache = False 185 | 186 | past_length = 0 187 | if past_key_values is not None: 188 | if isinstance(past_key_values, Cache): 189 | past_length = ( 190 | 
int(cache_position[0]) 191 | if cache_position is not None 192 | else past_key_values.get_seq_length() 193 | ) 194 | max_cache_length = past_key_values.get_max_length() 195 | cache_length = ( 196 | past_length 197 | if max_cache_length is None 198 | else min(max_cache_length, past_length) 199 | ) 200 | # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects 201 | else: 202 | cache_length = past_length = past_key_values[0][0].shape[2] 203 | max_cache_length = None 204 | 205 | # Keep only the unprocessed tokens: 206 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 207 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 208 | # input) 209 | if ( 210 | attention_mask is not None 211 | and attention_mask.shape[1] > input_ids.shape[1] 212 | ): 213 | start = attention_mask.shape[1] - past_length 214 | input_ids = input_ids.narrow(1, -start, start) 215 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 216 | # input_ids based on the past_length. 217 | elif past_length < input_ids.shape[1]: 218 | input_ids = input_ids.narrow( 219 | 1, past_length, input_ids.size(1) - past_length 220 | ) 221 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 222 | 223 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 224 | if ( 225 | max_cache_length is not None 226 | and attention_mask is not None 227 | and cache_length + input_ids.shape[1] > max_cache_length 228 | ): 229 | attention_mask = attention_mask.narrow( 230 | 1, -max_cache_length, max_cache_length 231 | ) 232 | 233 | if attention_mask is not None and position_ids is None: 234 | # create position_ids on the fly for batch generation 235 | position_ids = attention_mask.long().cumsum(-1) - 1 236 | position_ids.masked_fill_(attention_mask.eq(0), 1) 237 | if past_key_values is not None: 238 | position_ids = position_ids.narrow( 239 | 1, -input_ids.shape[1], input_ids.shape[1] 240 | ) 241 | 242 | input_length = ( 243 | position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] 244 | ) 245 | if cache_position is None: 246 | cache_position = torch.arange( 247 | past_length, past_length + input_length, device=input_ids.device 248 | ) 249 | else: 250 | cache_position = cache_position.narrow(0, -input_length, input_length) 251 | 252 | if has_static_cache: 253 | past_key_values = None 254 | 255 | model_inputs = self._GenerationInputs( 256 | position_ids=position_ids, 257 | cache_position=cache_position, 258 | use_cache=use_cache, 259 | ) 260 | 261 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 262 | if inputs_embeds is not None and past_key_values is None: 263 | model_inputs.inputs_embeds = inputs_embeds 264 | else: 265 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise 266 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 267 | # TODO: use `next_tokens` directly instead. 
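            # In the TensorRT path, generate() re-embeds these ids (emb_text / emb_code) and feeds
            # the embeddings to LlamaTRTModel.predict(); the KV cache tensor itself is allocated and
            # updated inside LlamaTRTModel (create_kv_cache / get_cur_kv_caches), and `past_key_values`
            # here is only consulted to work out how many positions are already cached.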
268 | model_inputs.input_ids = input_ids.contiguous() 269 | 270 | model_inputs.past_key_values = past_key_values 271 | model_inputs.attention_mask = attention_mask 272 | 273 | return model_inputs 274 | 275 | @dataclass(repr=False, eq=False) 276 | class GenerationOutputs: 277 | ids: List[torch.Tensor] 278 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] 279 | hiddens: List[torch.Tensor] 280 | 281 | def _prepare_generation_outputs( 282 | self, 283 | inputs_ids: torch.Tensor, 284 | start_idx: int, 285 | end_idx: torch.Tensor, 286 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]], 287 | hiddens: List[torch.Tensor], 288 | infer_text: bool, 289 | ) -> GenerationOutputs: 290 | inputs_ids = [ 291 | inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx) 292 | ] 293 | if infer_text: 294 | inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids] 295 | 296 | if len(hiddens) > 0: 297 | hiddens = torch.stack(hiddens, 1) 298 | hiddens = [ 299 | hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int()) 300 | ] 301 | 302 | return self.GenerationOutputs( 303 | ids=inputs_ids, 304 | attentions=attentions, 305 | hiddens=hiddens, 306 | ) 307 | 308 | @torch.no_grad() 309 | def generate( 310 | self, 311 | emb: torch.Tensor, 312 | inputs_ids: torch.Tensor, 313 | temperature: torch.Tensor, 314 | eos_token: Union[int, torch.Tensor], 315 | attention_mask: Optional[torch.Tensor] = None, 316 | max_new_token=2048, 317 | min_new_token=0, 318 | logits_warpers: List[LogitsWarper] = [], 319 | logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [], 320 | infer_text=False, 321 | return_attn=False, 322 | return_hidden=False, 323 | stream=False, 324 | show_tqdm=True, 325 | ensure_non_empty=True, 326 | stream_batch=24, 327 | context=Context(), 328 | ): 329 | 330 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = [] 331 | hiddens = [] 332 | stream_iter = 0 333 | 334 | start_idx, end_idx = inputs_ids.shape[1], torch.zeros( 335 | inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long 336 | ) 337 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() 338 | 339 | old_temperature = temperature 340 | 341 | temperature = ( 342 | temperature.unsqueeze(0) 343 | .expand(inputs_ids.shape[0], -1) 344 | .contiguous() 345 | .view(-1, 1) 346 | ) 347 | 348 | attention_mask_cache = torch.ones( 349 | ( 350 | inputs_ids.shape[0], 351 | inputs_ids.shape[1] + max_new_token, 352 | ), 353 | dtype=torch.bool, 354 | device=inputs_ids.device, 355 | ) 356 | if attention_mask is not None: 357 | attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_( 358 | attention_mask 359 | ) 360 | 361 | progress = inputs_ids.size(1) 362 | # pre-allocate inputs_ids 363 | inputs_ids_buf = torch.zeros( 364 | inputs_ids.size(0), 365 | progress + max_new_token, 366 | inputs_ids.size(2), 367 | dtype=inputs_ids.dtype, 368 | device=inputs_ids.device, 369 | ) 370 | inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids) 371 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress) 372 | 373 | pbar: Optional[tqdm] = None 374 | 375 | if show_tqdm: 376 | pbar = tqdm( 377 | total=max_new_token, 378 | desc="text" if infer_text else "code", 379 | bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]", 380 | ) 381 | 382 | past_key_values = None 383 | # need to create kv cache first 384 | self.gpt.create_kv_cache(inputs_ids.size(0)) 385 | for i in range(max_new_token): 386 | 387 | model_input = self._prepare_generation_inputs( 388 | 
inputs_ids, 389 | past_key_values, 390 | attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]), 391 | use_cache=not self.is_te_llama, 392 | ) 393 | 394 | if i > 0: 395 | inputs_ids_emb = model_input.input_ids.to(emb.device) 396 | if infer_text: 397 | emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0]) 398 | else: 399 | code_emb = [ 400 | self.emb_code[i](inputs_ids_emb[:, :, i]) 401 | for i in range(self.num_vq) 402 | ] 403 | emb = torch.stack(code_emb, 3).sum(3) 404 | model_input.inputs_embeds = emb 405 | model_input.to(emb.device, dtype=emb.dtype) 406 | hidden_states = self.gpt.predict(model_input.inputs_embeds, model_input.attention_mask, 407 | model_input.position_ids) 408 | hidden_states = hidden_states.to(emb.device, dtype=emb.dtype) 409 | past_key_values = self.gpt.get_cur_kv_caches() 410 | attentions.append(None) 411 | if return_hidden: 412 | hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1)) 413 | 414 | with P.cached(): 415 | if infer_text: 416 | logits: torch.Tensor = self.head_text(hidden_states) 417 | else: 418 | # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) 419 | logits = torch.empty( 420 | hidden_states.size(0), 421 | hidden_states.size(1), 422 | self.num_audio_tokens, 423 | self.num_vq, 424 | dtype=emb.dtype, 425 | device=emb.device, 426 | ) 427 | for num_vq_iter in range(self.num_vq): 428 | x: torch.Tensor = self.head_code[num_vq_iter](hidden_states) 429 | logits[..., num_vq_iter] = x 430 | 431 | # logits = logits[:, -1].float() 432 | logits = logits.narrow(1, -1, 1).squeeze_(1) 433 | 434 | if not infer_text: 435 | # logits = rearrange(logits, "b c n -> (b n) c") 436 | logits = logits.permute(0, 2, 1) 437 | logits = logits.reshape(-1, logits.size(2)) 438 | # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") 439 | inputs_ids_sliced = inputs_ids.narrow( 440 | 1, 441 | start_idx, 442 | inputs_ids.size(1) - start_idx, 443 | ).permute(0, 2, 1) 444 | logits_token = inputs_ids_sliced.reshape( 445 | inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), 446 | -1, 447 | ).to(emb.device) 448 | else: 449 | logits_token = ( 450 | inputs_ids.narrow( 451 | 1, 452 | start_idx, 453 | inputs_ids.size(1) - start_idx, 454 | ) 455 | .narrow(2, 0, 1) 456 | .to(emb.device) 457 | ) 458 | 459 | logits /= temperature 460 | 461 | for logitsProcessors in logits_processors: 462 | logits = logitsProcessors(logits_token, logits) 463 | 464 | for logitsWarpers in logits_warpers: 465 | logits = logitsWarpers(logits_token, logits) 466 | 467 | if i < min_new_token: 468 | logits[:, eos_token] = -torch.inf 469 | 470 | scores = F.softmax(logits, dim=-1) 471 | idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) 472 | 473 | if not infer_text: 474 | # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) 475 | idx_next = idx_next.view(-1, self.num_vq) 476 | finish_or = idx_next.eq(eos_token).any(1) 477 | finish.logical_or_(finish_or) 478 | inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1)) 479 | else: 480 | finish_or = idx_next.eq(eos_token).any(1) 481 | finish.logical_or_(finish_or) 482 | inputs_ids_buf.narrow(1, progress, 1).copy_( 483 | idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq), 484 | ) 485 | 486 | if i == 0 and finish.any(): 487 | self.logger.info( 488 | "unexpected end at index %s" % str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]), 489 | ) 490 | if ensure_non_empty: 491 | if show_tqdm: 492 | pbar.close() 493 | self.logger.info("regenerate in order to 
ensure non-empty") 494 | new_gen = self.generate( 495 | emb, 496 | inputs_ids, 497 | old_temperature, 498 | eos_token, 499 | attention_mask, 500 | max_new_token, 501 | min_new_token, 502 | logits_warpers, 503 | logits_processors, 504 | infer_text, 505 | return_attn, 506 | return_hidden, 507 | stream, 508 | show_tqdm, 509 | ensure_non_empty, 510 | stream_batch, 511 | context, 512 | ) 513 | for result in new_gen: 514 | yield result 515 | return 516 | 517 | progress += 1 518 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress) 519 | 520 | not_finished = finish.logical_not().to(end_idx.device) 521 | end_idx.add_(not_finished.int()) 522 | stream_iter += not_finished.any().int() 523 | if stream: 524 | if stream_iter > 0 and stream_iter % stream_batch == 0: 525 | self.logger.info("yield stream result, end: %d", end_idx) 526 | yield self._prepare_generation_outputs( 527 | inputs_ids, 528 | start_idx, 529 | end_idx, 530 | attentions, 531 | hiddens, 532 | infer_text, 533 | ) 534 | 535 | if finish.all() or context.get(): 536 | break 537 | 538 | if pbar is not None: 539 | pbar.update(1) 540 | 541 | if pbar is not None: 542 | pbar.close() 543 | 544 | if not finish.all(): 545 | if context.get(): 546 | self.logger.info("generation is interrupted") 547 | else: 548 | self.logger.info( 549 | f"incomplete result. hit max_new_token: {max_new_token}" 550 | ) 551 | 552 | yield self._prepare_generation_outputs( 553 | inputs_ids, 554 | start_idx, 555 | end_idx, 556 | attentions, 557 | hiddens, 558 | infer_text, 559 | ) 560 | -------------------------------------------------------------------------------- /chattts_plus/trt_models/llama_trt_model.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import time 3 | import numpy as np 4 | import torch 5 | from torch.cuda import nvtx 6 | from .base_model import BaseModel 7 | from .predictor import numpy_to_torch_dtype_dict 8 | 9 | 10 | class LlamaTRTModel(BaseModel): 11 | """ 12 | Llama TensorRT Model 13 | """ 14 | 15 | def __init__(self, **kwargs): 16 | super(LlamaTRTModel, self).__init__(**kwargs) 17 | self.predict_type = kwargs.get("predict_type", "trt") 18 | self.cudaStream = torch.cuda.current_stream().cuda_stream 19 | self.max_seq_len = kwargs.get("max_seq_len", 2048) 20 | for i, inp in enumerate(self.predictor.inputs): 21 | if inp["name"] == "past_key_values": 22 | self.kv_cache_dtype = numpy_to_torch_dtype_dict[inp['dtype']] 23 | self.kv_cache_shape = inp['shape'] 24 | 25 | def create_kv_cache(self, batch_size=1): 26 | self.kv_caches = torch.empty(self.kv_cache_shape[0], self.kv_cache_shape[1], batch_size, self.kv_cache_shape[3], 27 | self.max_seq_len + 1, self.kv_cache_shape[5]).to(self.device, 28 | dtype=self.kv_cache_dtype) 29 | self.kv_ind = 1 30 | 31 | def clear_kv_caches(self): 32 | self.kv_ind = 1 33 | 34 | def get_cur_kv_caches(self): 35 | return self.kv_caches[:, :, :, :, 1:self.kv_ind] 36 | 37 | def input_process(self, *data, **kwargs): 38 | return data 39 | 40 | def output_process(self, *data, **kwargs): 41 | return data[0] 42 | 43 | def predict_trt(self, *data, **kwargs): 44 | nvtx.range_push("forward") 45 | feed_dict = {} 46 | cur_input_shape = None 47 | for i, inp in enumerate(self.predictor.inputs): 48 | if inp["name"] != "past_key_values": 49 | if inp["name"] == "inputs_embeds": 50 | cur_input_shape = data[i].shape 51 | if isinstance(data[i], torch.Tensor): 52 | feed_dict[inp['name']] = data[i].to(device=self.device, 53 | dtype=numpy_to_torch_dtype_dict[inp['dtype']]) 54 | else: 55 | 
feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device, 56 | dtype=numpy_to_torch_dtype_dict[inp['dtype']]) 57 | else: 58 | feed_dict[inp['name']] = self.kv_caches[:, :, :, :, :self.kv_ind] 59 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream) 60 | outs = [] 61 | for i, out in enumerate(self.predictor.outputs): 62 | if out["name"] == "cur_key_values": 63 | out_shape = self.kv_cache_shape[:] 64 | out_shape[2] = cur_input_shape[0] 65 | out_shape[4] = cur_input_shape[1] 66 | out_tensor = preds_dict[out["name"]][:np.prod(out_shape)].reshape(*out_shape) 67 | new_kv_len = out_tensor.shape[4] 68 | self.kv_caches[:, :, :, :, self.kv_ind:self.kv_ind + new_kv_len] = out_tensor.clone() 69 | self.kv_ind += new_kv_len 70 | else: 71 | out_shape = cur_input_shape[:] 72 | out_tensor = preds_dict[out["name"]][:np.prod(out_shape)].reshape(*out_shape) 73 | outs.append(out_tensor.clone()) 74 | nvtx.range_pop() 75 | return outs 76 | 77 | def predict(self, *data, **kwargs): 78 | data = self.input_process(*data, **kwargs) 79 | preds = self.predict_trt(*data, **kwargs) 80 | outputs = self.output_process(*preds, **kwargs) 81 | return outputs 82 | -------------------------------------------------------------------------------- /chattts_plus/trt_models/predictor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch.cuda import nvtx 5 | from collections import OrderedDict 6 | 7 | try: 8 | import tensorrt as trt 9 | import ctypes 10 | except ModuleNotFoundError: 11 | print("No TensorRT Found") 12 | 13 | numpy_to_torch_dtype_dict = { 14 | np.uint8: torch.uint8, 15 | np.int8: torch.int8, 16 | np.int16: torch.int16, 17 | np.int32: torch.int32, 18 | np.int64: torch.int64, 19 | np.float16: torch.float16, 20 | np.float32: torch.float32, 21 | np.float64: torch.float64, 22 | np.complex64: torch.complex64, 23 | np.complex128: torch.complex128, 24 | } 25 | if np.version.full_version >= "1.24.0": 26 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool 27 | else: 28 | numpy_to_torch_dtype_dict[np.bool] = torch.bool 29 | 30 | 31 | class TensorRTPredictor: 32 | """ 33 | Implements inference for the EfficientDet TensorRT engine. 34 | """ 35 | 36 | def __init__(self, **kwargs): 37 | """ 38 | :param engine_path: The path to the serialized engine to load from disk. 39 | """ 40 | # Load TRT engine 41 | self.logger = trt.Logger(trt.Logger.ERROR) 42 | trt.init_libnvinfer_plugins(self.logger, "") 43 | engine_path = kwargs.get("model_path", None) 44 | self.debug = kwargs.get("debug", False) 45 | assert engine_path, f"model:{engine_path} must exist!" 
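        # What follows: deserialize the serialized engine, record every I/O tensor
        # (name / shape / dtype) into self.inputs and self.outputs, pre-allocate flat GPU
        # buffers at their maximum profile shapes, then create the execution context.
        # Rough construction sketch, reusing the paths from configs/infer/chattts_plus_trt.yaml
        # (the output_max_shapes values there are only an example of the expected layout):
        #   predictor = get_predictor(predict_type="trt",
        #                             model_path="checkpoints/chattts_llama.trt",
        #                             output_max_shapes={"hidden_states": [4, 2048, 768],
        #                                                "cur_key_values": [20, 2, 4, 12, 2048, 64]})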
46 | with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime: 47 | assert runtime 48 | self.engine = runtime.deserialize_cuda_engine(f.read()) 49 | assert self.engine 50 | # self.context = self.engine.create_execution_context() 51 | # assert self.context 52 | 53 | # Setup I/O bindings 54 | self.inputs = [] 55 | self.outputs = [] 56 | self.tensors = OrderedDict() 57 | 58 | # TODO: 支持动态shape输入 59 | for idx in range(self.engine.num_io_tensors): 60 | name = self.engine[idx] 61 | is_input = self.engine.get_tensor_mode(name).name == "INPUT" 62 | shape = self.engine.get_tensor_shape(name) 63 | dtype = trt.nptype(self.engine.get_tensor_dtype(name)) 64 | binding = { 65 | "index": idx, 66 | "name": name, 67 | "dtype": dtype, 68 | "shape": list(shape) 69 | } 70 | if is_input: 71 | self.inputs.append(binding) 72 | else: 73 | self.outputs.append(binding) 74 | 75 | assert len(self.inputs) > 0 76 | assert len(self.outputs) > 0 77 | self.allocate_max_buffers(**kwargs) 78 | self.activate() 79 | 80 | def activate(self, reuse_device_memory=None): 81 | if reuse_device_memory: 82 | self.context = self.engine.create_execution_context_without_device_memory() 83 | else: 84 | self.context = self.engine.create_execution_context() 85 | 86 | def deactivate(self): 87 | if self.context: 88 | del self.context 89 | self.context = None 90 | 91 | def allocate_max_buffers(self, device="cuda", **kwargs): 92 | nvtx.range_push("allocate_max_buffers") 93 | # 目前仅支持 batch 维度的动态处理, 如果是其他维度的动态,请传入 output_max_shapes 94 | output_max_shapes = kwargs.get("output_max_shapes", {}) 95 | 96 | batch_size = 1 97 | for idx in range(self.engine.num_io_tensors): 98 | binding = self.engine[idx] 99 | shape = self.engine.get_tensor_shape(binding) 100 | is_input = self.engine.get_tensor_mode(binding).name == "INPUT" 101 | if binding in output_max_shapes: 102 | shape = output_max_shapes[binding] 103 | else: 104 | if -1 in shape: 105 | if is_input: 106 | shape = self.engine.get_tensor_profile_shape(binding, 0)[-1] 107 | batch_size = shape[0] 108 | else: 109 | shape[0] = batch_size 110 | dtype = trt.nptype(self.engine.get_tensor_dtype(binding)) 111 | tensor = torch.empty( 112 | np.prod(list(shape)), dtype=numpy_to_torch_dtype_dict[dtype] 113 | ).to(device=device) 114 | self.tensors[binding] = tensor 115 | nvtx.range_pop() 116 | 117 | def input_spec(self): 118 | """ 119 | Get the specs for the input tensor of the network. Useful to prepare memory allocations. 120 | :return: Two items, the shape of the input tensor and its (numpy) datatype. 121 | """ 122 | specs = [] 123 | for i, o in enumerate(self.inputs): 124 | specs.append((o["name"], o['shape'], o['dtype'])) 125 | if self.debug: 126 | print(f"trt input {i} -> {o['name']} -> {o['shape']}") 127 | return specs 128 | 129 | def output_spec(self): 130 | """ 131 | Get the specs for the output tensors of the network. Useful to prepare memory allocations. 132 | :return: A list with two items per element, the shape and (numpy) datatype of each output tensor. 
133 | """ 134 | specs = [] 135 | for i, o in enumerate(self.outputs): 136 | specs.append((o["name"], o['shape'], o['dtype'])) 137 | if self.debug: 138 | print(f"trt output {i} -> {o['name']} -> {o['shape']}") 139 | return specs 140 | 141 | def adjust_buffer(self, feed_dict): 142 | nvtx.range_push("adjust_buffer") 143 | for name, buf in feed_dict.items(): 144 | input_tensor = self.tensors[name] 145 | current_shape = list(buf.shape) 146 | if len(current_shape) == 0: 147 | current_shape = (1,) 148 | tensor_len = np.prod(current_shape) 149 | input_tensor[:tensor_len].copy_(buf.reshape(-1)) 150 | self.context.set_input_shape(name, current_shape) 151 | nvtx.range_pop() 152 | 153 | def predict(self, feed_dict, stream): 154 | """ 155 | Execute inference on a batch of images. 156 | :param data: A list of inputs as numpy arrays. 157 | :return A list of outputs as numpy arrays. 158 | """ 159 | nvtx.range_push("set_tensors") 160 | self.adjust_buffer(feed_dict) 161 | for name, tensor in self.tensors.items(): 162 | self.context.set_tensor_address(name, tensor.data_ptr()) 163 | nvtx.range_pop() 164 | nvtx.range_push("execute") 165 | noerror = self.context.execute_async_v3(stream) 166 | if not noerror: 167 | raise ValueError("ERROR: inference failed.") 168 | nvtx.range_pop() 169 | return self.tensors 170 | 171 | def __del__(self): 172 | if self.engine is not None: 173 | del self.engine 174 | self.engine = None 175 | if self.context is not None: 176 | del self.context 177 | self.context = None 178 | del self.inputs 179 | del self.outputs 180 | del self.tensors 181 | 182 | 183 | def get_predictor(**kwargs): 184 | predict_type = kwargs.get("predict_type", "trt") 185 | if predict_type == "trt": 186 | return TensorRTPredictor(**kwargs) 187 | else: 188 | raise NotImplementedError 189 | -------------------------------------------------------------------------------- /configs/accelerate/deepspeed_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | gradient_accumulation_steps: 2 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | enable_cpu_affinity: false 13 | machine_rank: 0 14 | main_training_function: main 15 | mixed_precision: bf16 16 | num_machines: 1 17 | num_processes: 1 18 | rdzv_backend: static 19 | same_network: true 20 | tpu_env: [] 21 | tpu_use_cluster: false 22 | tpu_use_sudo: false 23 | use_cpu: false -------------------------------------------------------------------------------- /configs/infer/chattts_plus.yaml: -------------------------------------------------------------------------------- 1 | MODELS: 2 | tokenizer: 3 | name: "Tokenizer" 4 | infer_type: "pytorch" 5 | kwargs: 6 | model_path: "checkpoints/asset/tokenizer.pt" 7 | dvae_encode: 8 | name: "DVAE" 9 | infer_type: "pytorch" 10 | kwargs: 11 | model_path: "checkpoints/asset/DVAE_full.pt" 12 | dim: 512 13 | decoder_config: 14 | idim: 512 15 | odim: 512 16 | hidden: 256 17 | n_layer: 12 18 | bn_dim: 128 19 | encoder_config: 20 | idim: 512 21 | odim: 1024 22 | hidden: 256 23 | n_layer: 12 24 | bn_dim: 128 25 | vq_config: 26 | dim: 1024 27 | levels: 28 | - 5 29 | - 5 30 | - 5 31 | - 5 32 | G: 2 33 | R: 2 34 | dvae_decode: 35 | name: "DVAE" 36 | infer_type: "pytorch" 37 | kwargs: 38 | model_path: "checkpoints/asset/Decoder.pt" 39 | dim: 384 40 | decoder_config: 41 | idim: 384 
42 | odim: 384 43 | hidden: 512 44 | n_layer: 12 45 | bn_dim: 128 46 | vocos: 47 | name: "Vocos" 48 | infer_type: "pytorch" 49 | kwargs: 50 | model_path: "checkpoints/asset/Vocos.pt" 51 | feature_extractor_config: 52 | sample_rate: 24000 53 | n_fft: 1024 54 | hop_length: 256 55 | n_mels: 100 56 | padding: "center" 57 | backbone_config: 58 | input_channels: 100 59 | dim: 512 60 | intermediate_dim: 1536 61 | num_layers: 8 62 | head_config: 63 | dim: 512 64 | n_fft: 1024 65 | hop_length: 256 66 | padding: "center" 67 | gpt: 68 | name: "GPT" 69 | infer_type: "pytorch" 70 | kwargs: 71 | model_path: "checkpoints/asset/GPT.pt" 72 | gpt_config: 73 | hidden_size: 768 74 | intermediate_size: 3072 75 | num_attention_heads: 12 76 | num_hidden_layers: 20 77 | use_cache: False 78 | max_position_embeddings: 4096 79 | spk_emb_dim: 192 80 | spk_KL: False 81 | num_audio_tokens: 626 82 | num_vq: 4 83 | 84 | 85 | -------------------------------------------------------------------------------- /configs/infer/chattts_plus_trt.yaml: -------------------------------------------------------------------------------- 1 | MODELS: 2 | tokenizer: 3 | name: "Tokenizer" 4 | infer_type: "pytorch" 5 | kwargs: 6 | model_path: "checkpoints/asset/tokenizer.pt" 7 | dvae_encode: 8 | name: "DVAE" 9 | infer_type: "pytorch" 10 | kwargs: 11 | model_path: "checkpoints/asset/DVAE_full.pt" 12 | dim: 512 13 | decoder_config: 14 | idim: 512 15 | odim: 512 16 | hidden: 256 17 | n_layer: 12 18 | bn_dim: 128 19 | encoder_config: 20 | idim: 512 21 | odim: 1024 22 | hidden: 256 23 | n_layer: 12 24 | bn_dim: 128 25 | vq_config: 26 | dim: 1024 27 | levels: 28 | - 5 29 | - 5 30 | - 5 31 | - 5 32 | G: 2 33 | R: 2 34 | dvae_decode: 35 | name: "DVAE" 36 | infer_type: "pytorch" 37 | kwargs: 38 | model_path: "checkpoints/asset/Decoder.pt" 39 | dim: 384 40 | decoder_config: 41 | idim: 384 42 | odim: 384 43 | hidden: 512 44 | n_layer: 12 45 | bn_dim: 128 46 | vocos: 47 | name: "Vocos" 48 | infer_type: "pytorch" 49 | kwargs: 50 | model_path: "checkpoints/asset/Vocos.pt" 51 | feature_extractor_config: 52 | sample_rate: 24000 53 | n_fft: 1024 54 | hop_length: 256 55 | n_mels: 100 56 | padding: "center" 57 | backbone_config: 58 | input_channels: 100 59 | dim: 512 60 | intermediate_dim: 1536 61 | num_layers: 8 62 | head_config: 63 | dim: 512 64 | n_fft: 1024 65 | hop_length: 256 66 | padding: "center" 67 | gpt: 68 | name: "GPT" 69 | infer_type: "trt" 70 | kwargs: 71 | model_path: "checkpoints/asset/GPT.pt" 72 | trt_model_path: "checkpoints/chattts_llama.trt" 73 | output_max_shapes: 74 | hidden_states: 75 | - 4 76 | - 2048 77 | - 768 78 | cur_key_values: 79 | - 20 80 | - 2 81 | - 4 82 | - 12 83 | - 2048 84 | - 64 85 | gpt_config: 86 | hidden_size: 768 87 | intermediate_size: 3072 88 | num_attention_heads: 12 89 | num_hidden_layers: 20 90 | use_cache: False 91 | max_position_embeddings: 4096 92 | spk_emb_dim: 192 93 | spk_KL: False 94 | num_audio_tokens: 626 95 | num_vq: 4 96 | 97 | 98 | -------------------------------------------------------------------------------- /configs/train/train_speaker_embedding.yaml: -------------------------------------------------------------------------------- 1 | MODELS: 2 | tokenizer: 3 | name: "Tokenizer" 4 | infer_type: "pytorch" 5 | kwargs: 6 | model_path: "checkpoints/asset/tokenizer.pt" 7 | dvae_encode: 8 | name: "DVAE" 9 | infer_type: "pytorch" 10 | kwargs: 11 | model_path: "checkpoints/asset/DVAE_full.pt" 12 | coef: '' 13 | dim: 512 14 | decoder_config: 15 | idim: 512 16 | odim: 512 17 | hidden: 256 18 | n_layer: 12 19 
| bn_dim: 128 20 | encoder_config: 21 | idim: 512 22 | odim: 1024 23 | hidden: 256 24 | n_layer: 12 25 | bn_dim: 128 26 | vq_config: 27 | dim: 1024 28 | levels: 29 | - 5 30 | - 5 31 | - 5 32 | - 5 33 | G: 2 34 | R: 2 35 | dvae_decode: 36 | name: "DVAE" 37 | infer_type: "pytorch" 38 | kwargs: 39 | model_path: "checkpoints/asset/Decoder.pt" 40 | coef: '' 41 | dim: 384 42 | decoder_config: 43 | idim: 384 44 | odim: 384 45 | hidden: 512 46 | n_layer: 12 47 | bn_dim: 128 48 | gpt: 49 | name: "GPT" 50 | infer_type: "pytorch" 51 | kwargs: 52 | model_path: "checkpoints/asset/GPT.pt" 53 | gpt_config: 54 | hidden_size: 768 55 | intermediate_size: 3072 56 | num_attention_heads: 12 57 | num_hidden_layers: 20 58 | use_cache: False 59 | max_position_embeddings: 4096 60 | spk_emb_dim: 192 61 | spk_KL: False 62 | num_audio_tokens: 626 63 | num_vq: 4 64 | 65 | DATA: 66 | train_bs: 4 67 | meta_infos: 68 | - "data/xionger/slicer_opt.list" 69 | sample_rate: 24000 70 | num_vq: 4 71 | 72 | solver: 73 | gradient_accumulation_steps: 1 74 | mixed_precision: 'fp16' 75 | gradient_checkpointing: false 76 | max_train_steps: 40000 77 | max_grad_norm: 1.0 78 | # lr 79 | learning_rate: 1e-2 80 | scale_lr: false 81 | lr_warmup_steps: 10 82 | lr_scheduler: 'constant' 83 | 84 | # optimizer 85 | use_8bit_adam: false 86 | 87 | weight_dtype: 'fp16' 88 | output_dir: './outputs' 89 | exp_name: "xionger_speaker_emb" 90 | speaker_embeds_path: "outputs/xionger_speaker_emb-1732894509.8451874/checkpoints/step-4000/speaker_embeds.pkl" 91 | checkpointing_steps: 100 92 | use_empty_speaker: false -------------------------------------------------------------------------------- /configs/train/train_voice_clone_lora.yaml: -------------------------------------------------------------------------------- 1 | MODELS: 2 | tokenizer: 3 | name: "Tokenizer" 4 | infer_type: "pytorch" 5 | kwargs: 6 | model_path: "checkpoints/asset/tokenizer.pt" 7 | dvae_encode: 8 | name: "DVAE" 9 | infer_type: "pytorch" 10 | kwargs: 11 | model_path: "checkpoints/asset/DVAE_full.pt" 12 | coef: '' 13 | dim: 512 14 | decoder_config: 15 | idim: 512 16 | odim: 512 17 | hidden: 256 18 | n_layer: 12 19 | bn_dim: 128 20 | encoder_config: 21 | idim: 512 22 | odim: 1024 23 | hidden: 256 24 | n_layer: 12 25 | bn_dim: 128 26 | vq_config: 27 | dim: 1024 28 | levels: 29 | - 5 30 | - 5 31 | - 5 32 | - 5 33 | G: 2 34 | R: 2 35 | dvae_decode: 36 | name: "DVAE" 37 | infer_type: "pytorch" 38 | kwargs: 39 | model_path: "checkpoints/asset/Decoder.pt" 40 | coef: '' 41 | dim: 384 42 | decoder_config: 43 | idim: 384 44 | odim: 384 45 | hidden: 512 46 | n_layer: 12 47 | bn_dim: 128 48 | gpt: 49 | name: "GPT" 50 | infer_type: "pytorch" 51 | kwargs: 52 | model_path: "checkpoints/asset/GPT.pt" 53 | gpt_config: 54 | hidden_size: 768 55 | intermediate_size: 3072 56 | num_attention_heads: 12 57 | num_hidden_layers: 20 58 | use_cache: False 59 | max_position_embeddings: 4096 60 | spk_emb_dim: 192 61 | spk_KL: False 62 | num_audio_tokens: 626 63 | num_vq: 4 64 | 65 | DATA: 66 | train_bs: 4 67 | meta_infos: 68 | - "data/leijun/asr_opt/denoise_opt_new.list" 69 | sample_rate: 24000 70 | num_vq: 4 71 | 72 | LORA: 73 | lora_r: 8 74 | lora_alpha: 16 75 | lora_dropout: 0.01 76 | lora_target_modules: 77 | - "q_proj" 78 | - "v_proj" 79 | - "k_proj" 80 | - "o_proj" 81 | # - "gate_proj" 82 | # - "up_proj" 83 | # - "down_proj" 84 | 85 | solver: 86 | gradient_accumulation_steps: 1 87 | mixed_precision: 'fp16' 88 | gradient_checkpointing: false 89 | max_train_steps: 4000 90 | max_grad_norm: 1.0 91 | # lr 92 | 
learning_rate: 5e-5 93 | min_learning_rate: 1e-5 94 | scale_lr: false 95 | lr_warmup_steps: 10 96 | lr_scheduler: 'constant' 97 | 98 | # optimizer 99 | use_8bit_adam: false 100 | adam_beta1: 0.9 101 | adam_beta2: 0.95 102 | adam_weight_decay: 1e-3 103 | 104 | weight_dtype: 'fp16' 105 | output_dir: './outputs' 106 | exp_name: "leijun_lora" 107 | lora_model_path: "" 108 | checkpointing_steps: 100 109 | use_empty_speaker: true 110 | -------------------------------------------------------------------------------- /demos/notebooklm-podcast/extract_files_to_texts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/11/3 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: extract_files_to_texts.py 7 | 8 | import datetime 9 | import os 10 | import pdb 11 | 12 | import fitz 13 | from tqdm import tqdm 14 | import os 15 | import cv2 16 | from paddleocr import PPStructure, save_structure_res 17 | from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx 18 | from copy import deepcopy 19 | 20 | 21 | def pdf2images(pdf_file): 22 | image_dir = os.path.splitext(pdf_file)[0] + "_images" 23 | os.makedirs(image_dir, exist_ok=True) 24 | 25 | pdfDoc = fitz.open(pdf_file) 26 | totalPage = pdfDoc.page_count 27 | for pg in tqdm(range(totalPage)): 28 | page = pdfDoc[pg] 29 | rotate = int(0) 30 | zoom_x = 2 31 | zoom_y = 2 32 | mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate) 33 | pix = page.get_pixmap(matrix=mat, alpha=False) 34 | pix.save(os.path.join(image_dir, f"page_{pg + 1:03d}.png")) 35 | print(f"save images of pdf at: {image_dir}") 36 | return image_dir 37 | 38 | 39 | def extract_pdf_to_txt(pdf_file, save_txt_file=None, lang='ch'): 40 | # save_folder = os.path.splitext(pdf_file)[0] + "_txts" 41 | # os.makedirs(save_folder, exist_ok=True) 42 | table_engine = PPStructure(recovery=True, lang=lang, show_log=False) 43 | if save_txt_file is None: 44 | save_txt_file = os.path.splitext(pdf_file)[0] + ".txt" 45 | pdf_image_dir = pdf2images(pdf_file) 46 | text = [] 47 | imgs = sorted(os.listdir(pdf_image_dir)) 48 | for img_name in tqdm(imgs, total=len(imgs)): 49 | img = cv2.imread(os.path.join(pdf_image_dir, img_name)) 50 | result = table_engine(img) 51 | # save_structure_res(result, save_folder, os.path.splitext(img_name)[0]) 52 | h, w, _ = img.shape 53 | res = sorted_layout_boxes(result, w) 54 | # convert_info_docx(img, res, save_folder, os.path.splitext(img_name)[0]) 55 | for line in res: 56 | line.pop('img') 57 | for pra in line['res']: 58 | if isinstance(pra, str): 59 | text.append(pra) 60 | else: 61 | text.append(pra['text']) 62 | text.append('\n') 63 | with open(save_txt_file, 'w', encoding='utf-8') as f: 64 | f.write('\n'.join(text)) 65 | print(f"save txt of pdf at: {save_txt_file}") 66 | return save_txt_file 67 | 68 | 69 | def extract_image_to_txt(image_file, save_txt_file=None, lang='ch'): 70 | save_folder = os.path.splitext(image_file)[0] + "_txts" 71 | os.makedirs(save_folder, exist_ok=True) 72 | table_engine = PPStructure(recovery=True, lang=lang, show_log=False) 73 | if save_txt_file is None: 74 | save_txt_file = os.path.splitext(image_file)[0] + ".txt" 75 | if os.path.isdir(image_file): 76 | imgs = [os.path.join(image_file, img_) for img_ in os.listdir(image_file) if 77 | img_.split('.')[-1].lower() in ["jpg", "png", "jpeg"]] 78 | imgs = sorted(imgs) 79 | else: 80 | imgs = [image_file] 81 | text = [] 82 | for img_path in tqdm(imgs, 
total=len(imgs)): 83 | img = cv2.imread(img_path) 84 | result = table_engine(img) 85 | # save_structure_res(result, save_folder, os.path.splitext(img_name)[0]) 86 | h, w, _ = img.shape 87 | res = sorted_layout_boxes(result, w) 88 | # convert_info_docx(img, res, save_folder, os.path.splitext(img_name)[0]) 89 | for line in res: 90 | line.pop('img') 91 | for pra in line['res']: 92 | text.append(pra['text']) 93 | text.append('\n') 94 | with open(save_txt_file, 'w', encoding='utf-8') as f: 95 | f.write('\n'.join(text)) 96 | print(f"save txt of pdf at: {save_txt_file}") 97 | return save_txt_file 98 | 99 | 100 | if __name__ == '__main__': 101 | pdf_file = "../../data/pdfs/AnimateAnyone.pdf" 102 | # pdf2images(pdf_file) 103 | extract_pdf_to_txt(pdf_file, lang='en') 104 | -------------------------------------------------------------------------------- /demos/notebooklm-podcast/llm_api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/11/3 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: llm_api.py 7 | 8 | import os 9 | import json 10 | import copy 11 | import base64 12 | import pdb 13 | 14 | from openai import OpenAI 15 | 16 | 17 | def encode_image(image_path): 18 | with open(image_path, "rb") as image_file: 19 | return base64.b64encode(image_file.read()).decode('utf-8') 20 | 21 | 22 | def init_script_writer_prompt(speaker1_gender="woman", speaker2_gender="man"): 23 | operation_history = [] 24 | sysetm_prompt = f""" 25 | You are a renowned podcast scriptwriter, having worked as a ghostwriter for personalities like Joe Rogan, Lex Fridman, Ben Shapiro, and Tim Ferris. Your task is to create a dialogue based on a provided text by user, 26 | incorporating interjections such as "umm," "hmmm," and "right" from the second speaker. The conversation should be highly engaging; occasional digressions are acceptable, 27 | but the discussion should primarily revolve around the main topic. 28 | 29 | The dialogue involves two speakers: 30 | 31 | Speaker 1: This person, a {speaker1_gender}, leads the conversation, teaching Speaker 2 and providing vivid anecdotes and analogies during explanations. They are a charismatic teacher known for their compelling storytelling. 32 | 33 | Speaker 2: This person, a {speaker2_gender}, keeps the conversation flowing with follow-up questions, exhibiting high levels of excitement or confusion. Their inquisitive nature leads them to ask interesting confirmation questions. 34 | 35 | Ensure that any tangents introduced by Speaker 2 are intriguing or entertaining. Interruptions during explanations and interjections like "hmm" and "umm" from Speaker 2 should be present throughout the dialogue. 36 | 37 | Your output should resemble a genuine podcast, with every minute nuance documented in as much detail as possible. Begin with a captivating overview to welcome the listeners, keeping it intriguing and bordering on clickbait. 38 | 39 | Always commence your response with Speaker 1. Do not provide separate episode titles; instead, let Speaker 1 incorporate it into their dialogue. There should be no chapter titles; strictly provide the dialogues. 
40 | """ 41 | operation_history.append(["system", [{"type": "text", "text": sysetm_prompt}]]) 42 | return operation_history 43 | 44 | 45 | def init_script_rewriter_prompt(): 46 | operation_history = [] 47 | system_prompt = """ 48 | As an Oscar-winning international screenwriter, you've collaborated with many award-winning podcasters. Your current task is to rewrite the following podcast transcript for an AI Text-To-Speech Pipeline. A rudimentary AI initially wrote this transcript, and it now requires your expertise to make it engaging. 49 | 50 | Two different voice engines will simulate Speaker 1 and Speaker 2. 51 | 52 | Speaker 1: This character leads the conversation and educates Speaker 2. They are known for providing captivating teachings, enriched with compelling anecdotes and analogies. 53 | 54 | Speaker 2: This character maintains the conversation's flow by asking follow-up questions. They exhibit high levels of excitement or confusion and have a curious mindset, often asking intriguing confirmation questions. 55 | 56 | Speaker 2's tangents should be wild or interesting. Interruptions during explanations and interjections like "hmm" and "umm" from Speaker 2 should be present throughout the dialogue. 57 | 58 | Note: The Text-To-Speech engine for Speaker 1 struggles with "umms" and "hmms," so maintain a straight text for this speaker. For Speaker 2, feel free to use "umm," "hmm," [sigh], and [laughs] to convey expressions. 59 | 60 | The output should resemble a genuine podcast, with every minute detail documented meticulously. Begin with a captivating overview to welcome the listeners, keeping it intriguing and bordering on clickbait. 61 | 62 | Your response should start directly with Speaker 1 and strictly be returned as a list of tuples. No additional text should be included outside of the list. 63 | 64 | Example of response: 65 | [ 66 | ("Speaker 1", "Welcome to our podcast, where we explore the latest advancements in AI and technology. I'm your host, and today we're joined by a renowned expert in the field of AI. We're going to dive into the exciting world of Llama 3.2, the latest release from Meta AI."), 67 | ("Speaker 2", "Hi, I'm excited to be here! So, what is Llama 3.2?"), 68 | ("Speaker 1", "Ah, great question! Llama 3.2 is an open-source AI model that allows developers to fine-tune, distill, and deploy AI models anywhere. It's a significant update from the previous version, with improved performance, efficiency, and customization options."), 69 | ("Speaker 2", "That sounds amazing! 
What are some of the key features of Llama 3.2?") 70 | ] 71 | """ 72 | operation_history.append(["system", [{"type": "text", "text": system_prompt}]]) 73 | return operation_history 74 | 75 | 76 | def inference_openai_chat(messages, model, api_url, token, max_tokens=2048, temperature=0.4, seed=1234): 77 | client = OpenAI( 78 | base_url=api_url, 79 | api_key=token, 80 | ) 81 | 82 | data = { 83 | "model": model, 84 | "messages": [], 85 | "max_tokens": max_tokens, 86 | 'temperature': temperature, 87 | "seed": seed 88 | } 89 | 90 | for role, content in messages: 91 | data["messages"].append({"role": role, "content": content}) 92 | 93 | completion = client.chat.completions.create( 94 | **data 95 | ) 96 | return completion.choices[0].message.content 97 | 98 | 99 | def add_response(role, prompt, chat_history, image=None): 100 | new_chat_history = copy.deepcopy(chat_history) 101 | if image: 102 | base64_image = encode_image(image) 103 | content = [ 104 | { 105 | "type": "text", 106 | "text": prompt 107 | }, 108 | { 109 | "type": "image_url", 110 | "image_url": { 111 | "url": f"data:image/jpeg;base64,{base64_image}" 112 | } 113 | }, 114 | ] 115 | else: 116 | content = [ 117 | { 118 | "type": "text", 119 | "text": prompt 120 | }, 121 | ] 122 | new_chat_history.append([role, content]) 123 | return new_chat_history 124 | 125 | 126 | if __name__ == '__main__': 127 | import utils 128 | import pickle 129 | 130 | base_url = "" 131 | api_token = "" 132 | gpt_model = "gpt-4o-mini" 133 | pdf_txt_file = "../../data/pdfs/AnimateAnyone.txt" 134 | script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script.pkl" 135 | re_script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script-rewrite.pkl" 136 | 137 | chat_writer = init_script_writer_prompt() 138 | pdf_texts = utils.read_file_to_string(pdf_txt_file) 139 | chat_writer = add_response("user", pdf_texts, chat_writer) 140 | output_writer_texts = [] 141 | output_writer = inference_openai_chat(chat_writer, gpt_model, base_url, api_token, max_tokens=8192) 142 | print(output_writer) 143 | chat_writer = add_response("assistant", output_writer, chat_writer) 144 | 145 | with open(script_pkl, 'wb') as file: 146 | pickle.dump(output_writer, file) 147 | 148 | with open(script_pkl, 'rb') as file: 149 | script_texts = pickle.load(file) 150 | chat_rewriter = init_script_rewriter_prompt() 151 | chat_rewriter = add_response("user", script_texts, chat_rewriter) 152 | output_rewriter = inference_openai_chat(chat_rewriter, gpt_model, base_url, api_token, max_tokens=8192) 153 | print(output_rewriter) 154 | 155 | with open(re_script_pkl, 'wb') as file: 156 | pickle.dump(output_rewriter, file) 157 | -------------------------------------------------------------------------------- /demos/notebooklm-podcast/requirements.txt: -------------------------------------------------------------------------------- 1 | openai 2 | PyMuPDF 3 | paddlepaddle 4 | paddleocr -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/en_man_5200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_5200.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/en_man_8200.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_8200.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/en_man_9400.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_9400.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/en_man_9500.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_9500.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/en_woman_1200.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_woman_1200.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/en_woman_4600.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_woman_4600.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/en_woman_5600.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_woman_5600.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/zh_man_1888.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_man_1888.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/zh_man_2155.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_man_2155.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/zh_man_54.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_man_54.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/zh_woman_1528.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_woman_1528.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/zh_woman_492.pt: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_woman_492.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/speaker_pt/zh_woman_621.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_woman_621.pt -------------------------------------------------------------------------------- /demos/notebooklm-podcast/tts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/11/3 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: tts.py 7 | import pdb 8 | import random 9 | import torch 10 | import math 11 | import numpy as np 12 | import ast 13 | import torchaudio 14 | from omegaconf import OmegaConf 15 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline 16 | from chattts_plus.commons import utils as c_utils 17 | 18 | 19 | cfg = "../../configs/infer/chattts_plus_trt.yaml" 20 | infer_cfg = OmegaConf.load(cfg) 21 | pipe = ChatTTSPlusPipeline(infer_cfg, device=c_utils.get_inference_device()) 22 | 23 | 24 | def generate_audio( 25 | text, 26 | speaker_emb_path, 27 | **kwargs 28 | ): 29 | if not text: 30 | return None 31 | audio_save_path = kwargs.get("audio_save_path", None) 32 | params_infer_code = c_utils.InferCodeParams( 33 | prompt="[speed_3]", 34 | temperature=.0003, 35 | max_new_token=2048, 36 | top_P=0.7, 37 | top_K=20 38 | ) 39 | params_refine_text = c_utils.RefineTextParams( 40 | prompt='[oral_2][laugh_0][break_4]', 41 | top_P=0.7, 42 | top_K=20, 43 | temperature=0.3, 44 | max_new_token=384 45 | ) 46 | infer_seed = kwargs.get("infer_seed", 1234) 47 | with c_utils.TorchSeedContext(infer_seed): 48 | pipe_res_gen = pipe.infer( 49 | text, 50 | params_refine_text=params_refine_text, 51 | params_infer_code=params_infer_code, 52 | use_decoder=True, 53 | stream=False, 54 | skip_refine_text=True, 55 | do_text_normalization=True, 56 | do_homophone_replacement=True, 57 | do_text_optimization=False, 58 | speaker_emb_path=speaker_emb_path, 59 | speaker_audio_path=None, 60 | speaker_audio_text=None 61 | ) 62 | wavs = [] 63 | for wavs_ in pipe_res_gen: 64 | wavs.extend(wavs_) 65 | wavs = torch.cat(wavs).cpu().float().unsqueeze(0) 66 | if audio_save_path: 67 | torchaudio.save(audio_save_path, wavs, 24000) 68 | return wavs 69 | 70 | 71 | if __name__ == '__main__': 72 | import pickle 73 | import os 74 | from tqdm import tqdm 75 | 76 | base_url = "https://api2.aigcbest.top/v1" 77 | api_token = "" 78 | gpt_model = "" 79 | pdf_txt_file = "../../data/pdfs/AnimateAnyone.txt" 80 | script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script.pkl" 81 | re_script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script-rewrite.pkl" 82 | save_audio_path = os.path.splitext(pdf_txt_file)[0] + ".wav" 83 | with open(re_script_pkl, 'rb') as file: 84 | PODCAST_TEXT = pickle.load(file) 85 | 86 | final_audio = None 87 | 88 | i = 1 89 | speaker1_emb_path = "speaker_pt/en_woman_1200.pt" 90 | speaker2_emb_path = "speaker_pt/en_man_5200.pt" 91 | save_dir = os.path.splitext(pdf_txt_file)[0] + "_audios" 92 | os.makedirs(save_dir, exist_ok=True) 93 | wavs = [] 94 | for speaker, text in 
tqdm(ast.literal_eval(PODCAST_TEXT), desc="Generating podcast segments", unit="segment"): 95 | # output_path = os.path.join(save_dir, f"_podcast_segment_{i:03d}.wav") 96 | output_path = None 97 | if speaker == "Speaker 1": 98 | wav_ = generate_audio(text, speaker1_emb_path, audio_save_path=output_path) 99 | else: 100 | wav_ = generate_audio(text, speaker2_emb_path, audio_save_path=output_path) 101 | wavs.append(wav_) 102 | i += 1 103 | wavs = torch.cat(wavs, dim=-1) 104 | torchaudio.save(save_audio_path, wavs, 24000) 105 | print(save_audio_path) 106 | -------------------------------------------------------------------------------- /demos/notebooklm-podcast/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/11/3 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: utils.py 7 | 8 | def read_file_to_string(filename): 9 | # Try UTF-8 first (most common encoding for text files) 10 | try: 11 | with open(filename, 'r', encoding='utf-8') as file: 12 | content = file.read() 13 | return content 14 | except UnicodeDecodeError: 15 | # If UTF-8 fails, try with other common encodings 16 | encodings = ['latin-1', 'cp1252', 'iso-8859-1'] 17 | for encoding in encodings: 18 | try: 19 | with open(filename, 'r', encoding=encoding) as file: 20 | content = file.read() 21 | print(f"Successfully read file using {encoding} encoding.") 22 | return content 23 | except UnicodeDecodeError: 24 | continue 25 | 26 | print(f"Error: Could not decode file '{filename}' with any common encoding.") 27 | return None 28 | except FileNotFoundError: 29 | print(f"Error: File '{filename}' not found.") 30 | return None 31 | except IOError: 32 | print(f"Error: Could not read file '{filename}'.") 33 | return None 34 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<2.0.0 2 | numba 3 | torch>=2.1.0 4 | torchaudio 5 | tqdm 6 | vector_quantize_pytorch==1.17.8 7 | transformers>=4.41.1 8 | vocos 9 | IPython 10 | gradio 11 | pybase16384 12 | pynini==2.1.5; sys_platform == 'linux' 13 | WeTextProcessing; sys_platform == 'linux' 14 | nemo_text_processing; sys_platform == 'linux' 15 | av 16 | pydub 17 | pandas 18 | zh-normalization 19 | transformers 20 | peft 21 | accelerate 22 | sentencepiece 23 | omegaconf 24 | soundfile -------------------------------------------------------------------------------- /scripts/conversions/convert_train_list.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/12/18 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: convert_train_list.py 7 | """ 8 | python scripts/conversions/convert_train_list.py \ 9 | data/leijun/asr_opt/denoise_opt.list \ 10 | data/leijun/asr_opt/denoise_opt_new.list \ 11 | leijun 12 | """ 13 | import argparse 14 | import pdb 15 | 16 | 17 | def convert_list(input_file, output_file, speaker_name): 18 | with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile: 19 | for line in infile: 20 | # strip the trailing newline 21 | line = line.strip() 22 | if line: 23 | # split the line into its fields 24 | parts = line.split('|') 25 | audio_path = parts[0] 26 | lang = parts[-2] 27 | text = parts[-1] 28 | # build the new line without extra spaces 29 | new_line = 
f"{speaker_name}|{audio_path}|{lang}|{text}\n" 30 | outfile.write(new_line) 31 | print(output_file) 32 | 33 | 34 | if __name__ == "__main__": 35 | parser = argparse.ArgumentParser(description='Convert list file format.') 36 | parser.add_argument('input_file', type=str, help='Input list file') 37 | parser.add_argument('output_file', type=str, help='Output list file') 38 | parser.add_argument('speaker_name', type=str, help='Speaker name to add to each line') 39 | 40 | args = parser.parse_args() 41 | 42 | convert_list(args.input_file, args.output_file, args.speaker_name) 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/8/27 22:47 3 | # @Project : ChatTTSPlus 4 | # @FileName: setup.py 5 | 6 | import os 7 | from setuptools import setup, find_packages 8 | 9 | version = "v1.0.0" 10 | 11 | setup( 12 | name="chattts_plus", 13 | version=os.environ.get("CHATTTS_PLUS_VER", version).lstrip("v"), 14 | description="", 15 | long_description=open("README.md", encoding="utf8").read(), 16 | long_description_content_type="text/markdown", 17 | author="wenshao", 18 | author_email="wenshaoguo1026@gmail.com", 19 | url="https://github.com/warmshao/ChatTTSPlus", 20 | packages=[ 21 | 'chattts_plus', 22 | 'chattts_plus.models', 23 | 'chattts_plus.pipelines', 24 | 'chattts_plus.commons', 25 | ], 26 | license="AGPLv3+", 27 | install_requires=[ 28 | "numba", 29 | "numpy<2.0.0", 30 | "pybase16384", 31 | "torch>=2.1.0", 32 | "torchaudio", 33 | "tqdm", 34 | "transformers>=4.41.1", 35 | "vector_quantize_pytorch", 36 | "vocos" 37 | ] 38 | ) 39 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/9/22 15:53 3 | # @Project : ChatTTSPlus 4 | # @FileName: test_models.py 5 | import os 6 | import pdb 7 | import torch 8 | import torchaudio 9 | 10 | 11 | def test_tokenizer(): 12 | from chattts_plus.models.tokenizer import Tokenizer 13 | model_path = "checkpoints/asset/tokenizer.pt" 14 | if torch.cuda.is_available(): 15 | device = torch.device("cuda") 16 | weight_type = torch.float16 17 | else: 18 | device = torch.device("cpu") 19 | weight_type = torch.float32 20 | tokenizer_ = Tokenizer(model_path) 21 | 22 | text = "hello world!" 
23 | input_ids, attention_mask, text_mask = tokenizer_.encode([text], 4, None, device) 24 | 25 | 26 | def test_dvae_encode(): 27 | from chattts_plus.models.dvae import DVAE 28 | if torch.cuda.is_available(): 29 | device = torch.device("cuda") 30 | weight_type = torch.float16 31 | else: 32 | device = torch.device("cpu") 33 | weight_type = torch.float32 34 | 35 | audio_file = "data/xionger/slicer_opt/vocal_5.WAV_10.wav_0000251200_0000423680.wav" 36 | audio_wav, audio_sr_ = torchaudio.load(audio_file) 37 | audio_sr = 24000 38 | audio_wav = torchaudio.functional.resample(audio_wav, orig_freq=audio_sr_, new_freq=audio_sr) 39 | audio_wav = torch.mean(audio_wav, 0).to(device, dtype=weight_type) 40 | 41 | decoder_config = dict( 42 | idim=512, 43 | odim=512, 44 | hidden=256, 45 | n_layer=12, 46 | bn_dim=128, 47 | ) 48 | encoder_config = dict( 49 | idim=512, 50 | odim=1024, 51 | hidden=256, 52 | n_layer=12, 53 | bn_dim=128, 54 | ) 55 | vq_config = dict( 56 | dim=1024, 57 | levels=(5, 5, 5, 5), 58 | G=2, 59 | R=2, 60 | ) 61 | model_path = "checkpoints/asset/DVAE_full.pt" 62 | dvae_encoder = DVAE( 63 | decoder_config=decoder_config, 64 | encoder_config=encoder_config, 65 | vq_config=vq_config, 66 | dim=decoder_config["idim"], 67 | model_path=model_path, 68 | coef=None, 69 | ) 70 | dvae_encoder = dvae_encoder.eval().to(device, dtype=weight_type) 71 | audio_ids = dvae_encoder(audio_wav, "encode") 72 | pdb.set_trace() 73 | 74 | 75 | def test_dvae_decode(): 76 | from chattts_plus.models.dvae import DVAE 77 | if torch.cuda.is_available(): 78 | device = torch.device("cuda") 79 | weight_type = torch.float16 80 | else: 81 | device = torch.device("cpu") 82 | weight_type = torch.float32 83 | decoder_config = dict( 84 | idim=384, 85 | odim=384, 86 | hidden=512, 87 | n_layer=12, 88 | bn_dim=128 89 | ) 90 | model_path = "checkpoints/asset/Decoder.pt" 91 | dvae_decoder = DVAE( 92 | decoder_config=decoder_config, 93 | dim=decoder_config["idim"], 94 | coef=None, 95 | model_path=model_path 96 | ) 97 | dvae_decoder = dvae_decoder.eval().to(device, dtype=weight_type) 98 | 99 | vq_feats = torch.randn(1, 768, 388).to(device, dtype=weight_type) 100 | mel_feats = dvae_decoder(vq_feats) 101 | pdb.set_trace() 102 | 103 | 104 | def test_vocos(): 105 | import vocos 106 | import vocos.feature_extractors 107 | import vocos.models 108 | import vocos.heads 109 | 110 | feature_extractor_cfg = dict( 111 | sample_rate=24000, 112 | n_fft=1024, 113 | hop_length=256, 114 | n_mels=100, 115 | padding="center", 116 | ) 117 | backbone_cfg = dict( 118 | input_channels=100, 119 | dim=512, 120 | intermediate_dim=1536, 121 | num_layers=8 122 | ) 123 | head_cfg = dict( 124 | dim=512, 125 | n_fft=1024, 126 | hop_length=256, 127 | padding="center" 128 | ) 129 | feature_extractor = vocos.feature_extractors.MelSpectrogramFeatures(**feature_extractor_cfg) 130 | backbone = vocos.models.VocosBackbone(**backbone_cfg) 131 | head = vocos.heads.ISTFTHead(**head_cfg) 132 | 133 | device = torch.device("cuda") 134 | dtype = torch.float16 135 | if "mps" in str(device): 136 | device = torch.device("cpu") 137 | dtype = torch.float32 138 | 139 | vocos = vocos.Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head).to( 140 | device, dtype=dtype).eval() 141 | vocos_ckpt_path = "checkpoints/asset/Vocos.pt" 142 | vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True)) 143 | 144 | mel_feats = torch.randn(1, 100, 388 * 2).to(device, dtype=dtype) 145 | audio_wavs = vocos.decode(mel_feats).cpu().float() 146 | result_dir = 
"./results/test_vocos" 147 | os.makedirs(result_dir, exist_ok=True) 148 | torchaudio.save(os.path.join(result_dir, "test.wav"), audio_wavs, sample_rate=24000) 149 | 150 | 151 | def test_gpt(): 152 | from chattts_plus.models.gpt import GPT 153 | 154 | if torch.cuda.is_available(): 155 | device = torch.device("cuda") 156 | weight_type = torch.float16 157 | else: 158 | device = torch.device("cpu") 159 | weight_type = torch.float32 160 | model_path = "checkpoints/asset/GPT.pt" 161 | gpt_cfg = dict( 162 | hidden_size=768, 163 | intermediate_size=3072, 164 | num_attention_heads=12, 165 | num_hidden_layers=20, 166 | use_cache=False, 167 | max_position_embeddings=4096, 168 | spk_emb_dim=192, 169 | spk_KL=False, 170 | num_audio_tokens=626, 171 | num_vq=4, 172 | ) 173 | gpt = GPT(gpt_cfg, model_path=model_path) 174 | 175 | gpt = gpt.eval().to(device, dtype=weight_type) 176 | pdb.set_trace() 177 | 178 | 179 | if __name__ == '__main__': 180 | # test_tokenizer() 181 | test_dvae_encode() 182 | # test_dvae_decode() 183 | # test_vocos() 184 | # test_gpt() 185 | -------------------------------------------------------------------------------- /tests/test_pipelines.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/9/22 12:48 3 | # @Project : ChatTTSPlus 4 | # @FileName: test_pipelines.py 5 | import os 6 | import pdb 7 | 8 | 9 | def test_chattts_plus_pipeline(): 10 | import torch 11 | import time 12 | import torchaudio 13 | 14 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline 15 | from chattts_plus.commons.utils import InferCodeParams, RefineTextParams 16 | from omegaconf import OmegaConf 17 | 18 | infer_cfg_path = "configs/infer/chattts_plus.yaml" 19 | infer_cfg = OmegaConf.load(infer_cfg_path) 20 | 21 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda")) 22 | 23 | params_infer_code = InferCodeParams( 24 | prompt="[speed_3]", 25 | temperature=.0003, 26 | max_new_token=2048, 27 | top_P=0.7, 28 | top_K=20 29 | ) 30 | params_refine_text = RefineTextParams( 31 | prompt='[oral_2][laugh_3][break_4]', 32 | top_P=0.7, 33 | top_K=20, 34 | temperature=0.7, 35 | max_new_token=384 36 | ) 37 | infer_text = ["我们针对对话式任务进行了优化,能够实现自然且富有表现力的合成语音"] 38 | t0 = time.time() 39 | # leijun: outputs/leijun_lora-1732802535.8597126/checkpoints/step-2000 40 | # xionger: outputs/xionger_lora-1732809910.2932503/checkpoints/step-600 41 | lora_path = "outputs/leijun_lora-1734532984.1128285/checkpoints/step-2000" 42 | pipe_res_gen = pipeline.infer( 43 | infer_text, 44 | params_refine_text=params_refine_text, 45 | params_infer_code=params_infer_code, 46 | use_decoder=True, 47 | stream=False, 48 | skip_refine_text=True, 49 | do_text_normalization=True, 50 | do_homophone_replacement=True, 51 | do_text_optimization=True, 52 | lora_path=lora_path, 53 | speaker_emb_path='' 54 | ) 55 | wavs = [] 56 | for wavs_ in pipe_res_gen: 57 | wavs.extend(wavs_) 58 | print("total infer time:{} sec".format(time.time() - t0)) 59 | save_dir = "results/chattts_plus" 60 | os.makedirs(save_dir, exist_ok=True) 61 | audio_save_path = f"{save_dir}/{os.path.basename(lora_path)}-{time.time()}.wav" 62 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000) 63 | print(audio_save_path) 64 | 65 | 66 | def test_chattts_plus_trt_pipeline(): 67 | import torch 68 | import time 69 | import torchaudio 70 | 71 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline 72 | from chattts_plus.commons.utils 
import InferCodeParams, RefineTextParams 73 | from omegaconf import OmegaConf 74 | 75 | infer_cfg_path = "configs/infer/chattts_plus_trt.yaml" 76 | infer_cfg = OmegaConf.load(infer_cfg_path) 77 | 78 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda")) 79 | 80 | speaker_emb_path = "assets/speakers/2222.pt" 81 | params_infer_code = InferCodeParams( 82 | prompt="[speed_5]", 83 | temperature=.0003, 84 | max_new_token=2048, 85 | top_P=0.7, 86 | top_K=20 87 | ) 88 | params_refine_text = RefineTextParams( 89 | prompt='[oral_2][laugh_0][break_4]', 90 | top_P=0.7, 91 | top_K=20, 92 | temperature=0.3, 93 | max_new_token=384 94 | ) 95 | infer_text = [ 96 | "一场雨后,天空和地面互换了身份,抬头万里暗淡,足下星河生辉。这句话真是绝了.你觉得呢.哈哈哈哈", 97 | "本邮件内容是根据招商银行客户提供的个人邮箱发送给其本人的电子邮件,如您并非抬头标明的收件人,请您即刻删除本邮件,勿以任何形式使用及传播本邮件内容,谢谢!" 98 | ] 99 | t0 = time.time() 100 | pipe_res_gen = pipeline.infer( 101 | infer_text, 102 | params_refine_text=params_refine_text, 103 | params_infer_code=params_infer_code, 104 | use_decoder=True, 105 | stream=False, 106 | skip_refine_text=False, 107 | do_text_normalization=True, 108 | do_homophone_replacement=True, 109 | do_text_optimization=True, 110 | speaker_emb_path=speaker_emb_path 111 | ) 112 | wavs = [] 113 | for wavs_ in pipe_res_gen: 114 | wavs.extend(wavs_) 115 | print("total infer time:{} sec".format(time.time() - t0)) 116 | save_dir = "results/chattts_plus" 117 | os.makedirs(save_dir, exist_ok=True) 118 | audio_save_path = f"{save_dir}/{os.path.basename(speaker_emb_path)}-{time.time()}.wav" 119 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000) 120 | print(audio_save_path) 121 | 122 | 123 | def test_chattts_plus_zero_shot_pipeline(): 124 | import torch 125 | import time 126 | import torchaudio 127 | 128 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline 129 | from chattts_plus.commons.utils import InferCodeParams, RefineTextParams 130 | from omegaconf import OmegaConf 131 | 132 | infer_cfg_path = "configs/infer/chattts_plus.yaml" 133 | infer_cfg = OmegaConf.load(infer_cfg_path) 134 | 135 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda")) 136 | 137 | params_infer_code = InferCodeParams( 138 | prompt="[speed_5]", 139 | temperature=.0003, 140 | max_new_token=2048, 141 | top_P=0.7, 142 | top_K=20 143 | ) 144 | params_refine_text = RefineTextParams( 145 | prompt='[oral_2][laugh_0][break_4]', 146 | top_P=0.7, 147 | top_K=20, 148 | temperature=0.3, 149 | max_new_token=384 150 | ) 151 | infer_text = [ 152 | "一场雨后,天空和地面互换了身份,抬头万里暗淡,足下星河生辉。这句话真是绝了.你觉得呢.哈哈哈哈", 153 | "本邮件内容是根据招商银行客户提供的个人邮箱发送给其本人的电子邮件,如您并非抬头标明的收件人,请您即刻删除本邮件,勿以任何形式使用及传播本邮件内容,谢谢!" 
154 | ] 155 | speaker_audio_path = "data/xionger/slicer_opt/vocal_1.WAV_10.wav_0000000000_0000152640.wav" 156 | speaker_audio_text = "嘿嘿,最近我看了寄生虫,真的很推荐哦。" 157 | t0 = time.time() 158 | pipe_res_gen = pipeline.infer( 159 | infer_text, 160 | params_refine_text=params_refine_text, 161 | params_infer_code=params_infer_code, 162 | use_decoder=True, 163 | stream=False, 164 | skip_refine_text=False, 165 | do_text_normalization=True, 166 | do_homophone_replacement=True, 167 | do_text_optimization=True, 168 | speaker_emb_path=None, 169 | speaker_audio_path=speaker_audio_path, 170 | speaker_audio_text=speaker_audio_text 171 | ) 172 | wavs = [] 173 | for wavs_ in pipe_res_gen: 174 | wavs.extend(wavs_) 175 | print("total infer time:{} sec".format(time.time() - t0)) 176 | save_dir = "results/chattts_plus" 177 | os.makedirs(save_dir, exist_ok=True) 178 | audio_save_path = f"{save_dir}/{os.path.basename(speaker_audio_path)}-{time.time()}.wav" 179 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000) 180 | print(audio_save_path) 181 | 182 | 183 | def test_chattts_plus_zero_shot_trt_pipeline(): 184 | import torch 185 | import time 186 | import torchaudio 187 | 188 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline 189 | from chattts_plus.commons.utils import InferCodeParams, RefineTextParams 190 | from omegaconf import OmegaConf 191 | 192 | infer_cfg_path = "configs/infer/chattts_plus_trt.yaml" 193 | infer_cfg = OmegaConf.load(infer_cfg_path) 194 | 195 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda")) 196 | 197 | params_infer_code = InferCodeParams( 198 | prompt="[speed_5]", 199 | temperature=.0003, 200 | max_new_token=2048, 201 | top_P=0.7, 202 | top_K=20 203 | ) 204 | params_refine_text = RefineTextParams( 205 | prompt='[oral_2][laugh_0][break_4]', 206 | top_P=0.7, 207 | top_K=20, 208 | temperature=0.3, 209 | max_new_token=384 210 | ) 211 | infer_text = [ 212 | "请您即刻删除本邮件,勿以任何形式使用及传播本邮件内容,谢谢!" 
213 | ] 214 | speaker_audio_path = "data/xionger/slicer_opt/vocal_1.WAV_10.wav_0000000000_0000152640.wav" 215 | speaker_audio_text = "嘿嘿,最近我看了寄生虫,真的很推荐哦。" 216 | t0 = time.time() 217 | pipe_res_gen = pipeline.infer( 218 | infer_text, 219 | params_refine_text=params_refine_text, 220 | params_infer_code=params_infer_code, 221 | use_decoder=True, 222 | stream=False, 223 | skip_refine_text=False, 224 | do_text_normalization=True, 225 | do_homophone_replacement=True, 226 | do_text_optimization=True, 227 | speaker_emb_path=None, 228 | speaker_audio_path=speaker_audio_path, 229 | speaker_audio_text=speaker_audio_text 230 | ) 231 | wavs = [] 232 | for wavs_ in pipe_res_gen: 233 | wavs.extend(wavs_) 234 | print("total infer time:{} sec".format(time.time() - t0)) 235 | save_dir = "results/chattts_plus" 236 | os.makedirs(save_dir, exist_ok=True) 237 | audio_save_path = f"{save_dir}/{os.path.basename(speaker_audio_path)}-{time.time()}.wav" 238 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000) 239 | print(audio_save_path) 240 | 241 | 242 | if __name__ == '__main__': 243 | test_chattts_plus_pipeline() 244 | # test_chattts_plus_trt_pipeline() 245 | # test_chattts_plus_zero_shot_pipeline() 246 | # test_chattts_plus_zero_shot_trt_pipeline() 247 | -------------------------------------------------------------------------------- /train_lora.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2024/11/23 3 | # @Author : wenshao 4 | # @Email : wenshaoguo1026@gmail.com 5 | # @Project : ChatTTSPlus 6 | # @FileName: train_lora.py 7 | """ 8 | accelerate launch train_lora.py --config configs/train/train_voice_clone_lora.yaml 9 | """ 10 | import logging 11 | import math 12 | import os.path 13 | import pdb 14 | import numpy as np 15 | from datetime import timedelta 16 | from tqdm import tqdm 17 | import peft 18 | import pickle 19 | from accelerate import InitProcessGroupKwargs 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import torch.utils.checkpoint 24 | import transformers 25 | from accelerate import Accelerator 26 | from accelerate.utils import DistributedDataParallelKwargs 27 | from omegaconf import OmegaConf 28 | import warnings 29 | from peft import LoraConfig, get_peft_model 30 | import time 31 | import pybase16384 as b14 32 | from huggingface_hub import hf_hub_download 33 | from transformers.trainer_pt_utils import LabelSmoother 34 | from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ 35 | from einops import rearrange 36 | from peft import PeftConfig, PeftModel 37 | from chattts_plus.commons.logger import get_logger 38 | from chattts_plus.commons import constants 39 | from chattts_plus.models.tokenizer import Tokenizer 40 | from chattts_plus.models.dvae import DVAE 41 | from chattts_plus.models.gpt import GPT 42 | from chattts_plus.datasets.base_dataset import BaseDataset 43 | from chattts_plus.datasets.collator import BaseCollator 44 | from chattts_plus.commons import norm 45 | 46 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 47 | AUDIO_PAD_TOKEN_ID: int = 0 48 | 49 | warnings.filterwarnings("ignore") 50 | 51 | 52 | def get_mel_attention_mask( 53 | waveform_attention_mask: torch.Tensor, # (batch_size, time) 54 | mel_len: int, 55 | ): 56 | batch_size = waveform_attention_mask.size(0) 57 | mel_attention_mask = torch.ones( 58 | (batch_size, mel_len), 59 | device=waveform_attention_mask.device, 60 | ) 61 | indices = 
waveform_attention_mask.int().sum(dim=1) # (batch_size,) 62 | indices = torch.ceil(indices * mel_len / waveform_attention_mask.size(1)).int() 63 | for i in range(batch_size): 64 | mel_attention_mask[i, indices[i]:] = 0 65 | return mel_attention_mask # (batch_size, mel_len) 66 | 67 | 68 | def main(cfg): 69 | output_dir = os.path.join(cfg.output_dir, f"{cfg.exp_name}-{time.time()}") 70 | os.makedirs(output_dir, exist_ok=True) 71 | checkpoints_dir = os.path.join(output_dir, "checkpoints") 72 | os.makedirs(checkpoints_dir, exist_ok=True) 73 | log_dir = os.path.join(output_dir, "logs") 74 | logger = get_logger("Lora Training", log_file=os.path.join(log_dir, "train.log")) 75 | logger.info(cfg) 76 | kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters=False)] 77 | accelerator = Accelerator( 78 | gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps, 79 | mixed_precision=cfg.solver.mixed_precision, 80 | kwargs_handlers=kwargs_handlers, 81 | ) 82 | if accelerator.is_local_main_process: 83 | from torch.utils.tensorboard import SummaryWriter 84 | tf_writer = SummaryWriter(log_dir=os.path.join(output_dir, "tf_logs")) 85 | 86 | # load model 87 | if cfg.weight_dtype == "fp16": 88 | weight_dtype = torch.float16 89 | elif cfg.weight_dtype == "fp32": 90 | weight_dtype = torch.float32 91 | else: 92 | raise ValueError( 93 | f"Do not support weight dtype: {cfg.weight_dtype} during training" 94 | ) 95 | logger.info(f"weight_dtype: {str(weight_dtype)}") 96 | logger.info("loading tokenizer >>>") 97 | tokenizer_kwargs = cfg.MODELS["tokenizer"]["kwargs"] 98 | model_path_org = tokenizer_kwargs["model_path"] 99 | model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", "")) 100 | if not os.path.exists(model_path_new): 101 | logger.info(f"{model_path_new} not exists! Need to download from HuggingFace") 102 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset", 103 | filename=os.path.basename(model_path_new), 104 | local_dir=constants.CHECKPOINT_DIR) 105 | logger.info(f"download {model_path_new} from 2Noise/ChatTTS") 106 | tokenizer_kwargs["model_path"] = model_path_new 107 | tokenizer = Tokenizer(**tokenizer_kwargs) 108 | 109 | # load DVAE encoder 110 | logger.info("loading DVAE encode >>>") 111 | dvae_kwargs = cfg.MODELS["dvae_encode"]["kwargs"] 112 | if not dvae_kwargs["coef"]: 113 | coef_ = torch.rand(100) 114 | coef = b14.encode_to_string(coef_.numpy().astype(np.float32).tobytes()) 115 | dvae_kwargs["coef"] = coef 116 | logger.info(f"Set DAVE Encode Coef: {dvae_kwargs['coef']}") 117 | else: 118 | coef = dvae_kwargs["coef"] 119 | model_path_org = dvae_kwargs["model_path"] 120 | model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", "")) 121 | if not os.path.exists(model_path_new): 122 | logger.info(f"{model_path_new} not exists! 
Need to download from HuggingFace") 123 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset", 124 | filename=os.path.basename(model_path_new), 125 | local_dir=constants.CHECKPOINT_DIR) 126 | logger.info(f"download {model_path_new} from 2Noise/ChatTTS") 127 | dvae_kwargs["model_path"] = model_path_new 128 | dvae_encoder = DVAE(**dvae_kwargs) 129 | dvae_encoder.eval().to(accelerator.device, dtype=weight_dtype) 130 | dvae_encoder.requires_grad_(False) 131 | 132 | # load DVAE decoder 133 | # logger.info("loading DVAE decode >>>") 134 | # dvae_kwargs = cfg.MODELS["dvae_decode"]["kwargs"] 135 | # if not dvae_kwargs["coef"]: 136 | # dvae_kwargs["coef"] = coef 137 | # logger.info(f"Set DAVE Decode Coef: {dvae_kwargs['coef']}") 138 | # model_path_org = dvae_kwargs["model_path"] 139 | # model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", "")) 140 | # if not os.path.exists(model_path_new): 141 | # logger.info(f"{model_path_new} not exists! Need to download from HuggingFace") 142 | # hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset", 143 | # filename=os.path.basename(model_path_new), 144 | # local_dir=constants.CHECKPOINT_DIR) 145 | # logger.info(f"download {model_path_new} from 2Noise/ChatTTS") 146 | # dvae_kwargs["model_path"] = model_path_new 147 | # dvae_decoder = DVAE(**dvae_kwargs) 148 | # dvae_decoder.eval().to(accelerator.device, dtype=weight_dtype) 149 | # dvae_decoder.requires_grad_(False) 150 | 151 | # Load GPT 152 | logger.info("loading GPT model >>>") 153 | gpt_kwargs = cfg.MODELS["gpt"]["kwargs"] 154 | model_path_org = gpt_kwargs["model_path"] 155 | model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", "")) 156 | if not os.path.exists(model_path_new): 157 | logger.info(f"{model_path_new} not exists! Need to download from HuggingFace") 158 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset", 159 | filename=os.path.basename(model_path_new), 160 | local_dir=constants.CHECKPOINT_DIR) 161 | logger.info(f"download {model_path_new} from 2Noise/ChatTTS") 162 | gpt_kwargs["model_path"] = model_path_new 163 | gpt_model = GPT(**gpt_kwargs) 164 | gpt_model.to(accelerator.device, dtype=weight_dtype) 165 | gpt_model.requires_grad_(False) 166 | 167 | # speaker embedding 168 | spk_stat_path = os.path.join(constants.CHECKPOINT_DIR, "asset/spk_stat.pt") 169 | if not os.path.exists(spk_stat_path): 170 | logger.warning(f"{spk_stat_path} not exists! Need to download from HuggingFace") 171 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset", 172 | filename=os.path.basename(spk_stat_path), 173 | local_dir=constants.CHECKPOINT_DIR) 174 | logger.info(f"download {spk_stat_path} from 2Noise/ChatTTS") 175 | logger.info(f"loading speaker stat: {spk_stat_path}") 176 | assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}" 177 | spk_stat: torch.Tensor = torch.load( 178 | spk_stat_path, 179 | weights_only=True, 180 | mmap=True 181 | ).to(accelerator.device, dtype=weight_dtype) 182 | speaker_std, speaker_mean = spk_stat.chunk(2) 183 | 184 | # dataset 185 | normalizer_json = os.path.join(constants.CHECKPOINT_DIR, "homophones_map.json") 186 | if not os.path.exists(normalizer_json): 187 | logger.warning(f"{normalizer_json} not exists! 
Need to download from HuggingFace") 188 | hf_hub_download(repo_id="warmshao/ChatTTSPlus", 189 | filename=os.path.basename(normalizer_json), 190 | local_dir=constants.CHECKPOINT_DIR) 191 | logger.info(f"download {normalizer_json} from warmshao/ChatTTSPlus") 192 | logger.info(f"loading normalizer: {normalizer_json}") 193 | normalizer = norm.Normalizer(normalizer_json) 194 | train_dataset = BaseDataset( 195 | meta_infos=cfg.DATA.meta_infos, 196 | sample_rate=cfg.DATA.sample_rate, 197 | num_vq=cfg.DATA.num_vq, 198 | tokenizer=tokenizer, 199 | normalizer=normalizer, 200 | use_empty_speaker=cfg.use_empty_speaker 201 | ) 202 | 203 | # Lora 204 | if cfg.use_empty_speaker: 205 | logger.info("Setting Lora model >>>") 206 | lora_cfg = OmegaConf.to_container(cfg.LORA, resolve=True) 207 | lora_config = LoraConfig( 208 | r=lora_cfg['lora_r'], 209 | lora_alpha=lora_cfg['lora_alpha'], 210 | target_modules=lora_cfg['lora_target_modules'], 211 | lora_dropout=lora_cfg['lora_dropout'] 212 | ) 213 | peft_model = get_peft_model(gpt_model.gpt, lora_config) 214 | peft_model.config.use_cache = False 215 | if cfg.lora_model_path and os.path.exists(cfg.lora_model_path): 216 | logger.info(f"loading lora weight: {cfg.lora_model_path} >>>") 217 | state_dict = None 218 | if cfg.lora_model_path.endswith(".safetensors"): 219 | from safetensors.torch import load_file 220 | state_dict = load_file(cfg.lora_model_path) 221 | elif cfg.lora_model_path.endswith(".pth") or cfg.lora_model_path.endswith(".pt"): 222 | state_dict = torch.load(cfg.lora_model_path) 223 | elif os.path.isdir(cfg.lora_model_path): 224 | state_dict = peft.load_peft_weights(cfg.lora_model_path) 225 | else: 226 | logger.error(f"cannot load {cfg.lora_model_path}") 227 | if state_dict is not None: 228 | peft.set_peft_model_state_dict(peft_model, state_dict) 229 | gpt_model.gpt = peft_model 230 | else: 231 | if cfg.speaker_embeds_path: 232 | with open(cfg.speaker_embeds_path, "rb") as fin: 233 | speaker_embeds = pickle.load(fin) 234 | for speaker in speaker_embeds: 235 | spk_emb = torch.from_numpy(speaker_embeds[speaker]).to(accelerator.device, dtype=torch.float32) 236 | spk_emb = torch.nn.Parameter(spk_emb) 237 | spk_emb.requires_grad_(True) 238 | speaker_embeds[speaker] = spk_emb 239 | else: 240 | speaker_embeds = dict() 241 | for speaker in train_dataset.speakers: 242 | if speaker not in speaker_embeds: 243 | dim: int = speaker_std.shape[-1] 244 | spk_emb = torch.randn(dim, device=speaker_std.device, dtype=torch.float32) 245 | spk_emb = torch.nn.Parameter(spk_emb) 246 | spk_emb.requires_grad_(True) 247 | speaker_embeds[speaker] = spk_emb 248 | train_dataloader = torch.utils.data.DataLoader( 249 | train_dataset, batch_size=cfg.DATA.train_bs, shuffle=True, 250 | num_workers=min(cfg.DATA.train_bs, 4), drop_last=True, collate_fn=BaseCollator() 251 | ) 252 | 253 | if cfg.solver.scale_lr: 254 | learning_rate = ( 255 | cfg.solver.learning_rate 256 | * cfg.solver.gradient_accumulation_steps 257 | * cfg.DATA.train_bs 258 | * accelerator.num_processes 259 | ) 260 | else: 261 | learning_rate = cfg.solver.learning_rate 262 | 263 | # Initialize the optimizer 264 | if cfg.solver.use_8bit_adam: 265 | try: 266 | import bitsandbytes as bnb 267 | except ImportError: 268 | raise ImportError( 269 | "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" 270 | ) 271 | 272 | optimizer_cls = bnb.optim.AdamW8bit 273 | else: 274 | optimizer_cls = torch.optim.AdamW 275 | 276 | trainable_params = list(filter(lambda p: p.requires_grad, gpt_model.gpt.parameters())) 277 | logger.info(f"Total trainable params {len(trainable_params)}") 278 | if cfg.use_empty_speaker: 279 | optimizer = optimizer_cls( 280 | trainable_params, 281 | lr=learning_rate, 282 | betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2), 283 | weight_decay=cfg.solver.adam_weight_decay 284 | ) 285 | else: 286 | optimizer = torch.optim.SGD( 287 | list(speaker_embeds.values()), 288 | lr=learning_rate 289 | ) 290 | 291 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 292 | num_update_steps_per_epoch = math.ceil( 293 | len(train_dataloader) / cfg.solver.gradient_accumulation_steps 294 | ) 295 | # Afterwards we recalculate our number of training epochs 296 | num_train_epochs = math.ceil( 297 | cfg.solver.max_train_steps / num_update_steps_per_epoch 298 | ) 299 | # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 300 | # optimizer, num_train_epochs, cfg.solver.min_learning_rate 301 | # ) 302 | lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer) 303 | 304 | # Prepare everything with our `accelerator`. 305 | ( 306 | gpt_model, 307 | optimizer, 308 | train_dataloader, 309 | lr_scheduler, 310 | ) = accelerator.prepare( 311 | gpt_model, 312 | optimizer, 313 | train_dataloader, 314 | lr_scheduler, 315 | ) 316 | 317 | # Train! 318 | total_batch_size = ( 319 | cfg.DATA.train_bs 320 | * accelerator.num_processes 321 | * cfg.solver.gradient_accumulation_steps 322 | ) 323 | 324 | logger.info("***** Running training *****") 325 | logger.info(f" Num examples = {len(train_dataset)}") 326 | logger.info(f" Num Epochs = {num_train_epochs}") 327 | logger.info(f" Instantaneous batch size per device = {cfg.DATA.train_bs}") 328 | logger.info( 329 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 330 | ) 331 | logger.info( 332 | f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}" 333 | ) 334 | logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}") 335 | global_step = 0 336 | first_epoch = 0 337 | 338 | # Only show the progress bar once on each machine. 
339 | progress_bar = tqdm( 340 | range(global_step, cfg.solver.max_train_steps), 341 | disable=not accelerator.is_local_main_process, 342 | ) 343 | progress_bar.set_description("Steps") 344 | 345 | for epoch in range(first_epoch, num_train_epochs): 346 | train_loss = 0.0 347 | for step, batch in enumerate(train_dataloader): 348 | with accelerator.accumulate(gpt_model): 349 | text_input_ids: torch.Tensor = batch[ 350 | "text_input_ids" 351 | ] # (batch_size, text_len, num_vq) 352 | text_attention_mask: torch.Tensor = batch[ 353 | "text_mask" 354 | ] # (batch_size, text_len) 355 | audio_wavs: torch.Tensor = batch["audio_wavs"] # (batch_size, time) 356 | audio_wavs_mask: torch.Tensor = batch["audio_mask"] # (batch_size, time) 357 | 358 | batch_size = text_input_ids.size(0) 359 | text_input_ids = text_input_ids.to(accelerator.device, non_blocking=True) 360 | text_attention_mask = text_attention_mask.to(accelerator.device, dtype=weight_dtype, non_blocking=True) 361 | audio_wavs = audio_wavs.to(accelerator.device, dtype=weight_dtype, non_blocking=True) 362 | audio_wavs_mask = audio_wavs_mask.to(accelerator.device, dtype=weight_dtype, non_blocking=True) 363 | with torch.no_grad(): 364 | # mel_specs = dvae_encoder.preprocessor_mel(audio_wavs) 365 | dvae_audio_input_ids = dvae_encoder(audio_wavs, mode="encode").permute(0, 2, 1).clone() 366 | mel_attention_mask = get_mel_attention_mask(audio_wavs_mask, 367 | mel_len=dvae_audio_input_ids.size(1) * 2) 368 | # if mel_attention_mask.shape[1] > mel_specs.shape[2]: 369 | # mel_attention_mask = mel_attention_mask[:, :mel_specs.shape[2]] 370 | # else: 371 | # mel_specs = mel_specs[:, :, :mel_attention_mask.shape[1]] 372 | # mel_specs = mel_specs * mel_attention_mask.unsqueeze(1) 373 | audio_attention_mask = mel_attention_mask[:, ::2] 374 | dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID 375 | 376 | # add audio eos token 377 | extended_audio_attention_mask = torch.cat( 378 | [ 379 | audio_attention_mask, 380 | torch.zeros( 381 | (batch_size, 1), 382 | dtype=audio_attention_mask.dtype, 383 | device=audio_attention_mask.device, 384 | ), 385 | ], 386 | dim=1, 387 | ) # (batch_size, mel_len+1) 388 | extended_audio_input_ids = torch.cat( 389 | [ 390 | dvae_audio_input_ids, 391 | AUDIO_PAD_TOKEN_ID 392 | * torch.ones( 393 | (batch_size, 1, gpt_model.num_vq), 394 | dtype=dvae_audio_input_ids.dtype, 395 | device=dvae_audio_input_ids.device, 396 | ), 397 | ], 398 | dim=1, 399 | ) # (batch_size, mel_len+1, num_vq) 400 | indices = audio_attention_mask.int().sum(dim=1) # (batch_size,) 401 | AUDIO_EOS_TOKEN_ID = int(gpt_model.emb_code[0].num_embeddings - 1) 402 | for i in range(batch_size): 403 | extended_audio_attention_mask[i, indices[i]] = 1 404 | extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID 405 | 406 | # combine text and audio 407 | input_ids = torch.cat( # (batch_size, text_len + mel_len + 1, num_vq) 408 | [ 409 | text_input_ids, 410 | extended_audio_input_ids, # (batch_size, mel_len, num_vq) 411 | ], 412 | dim=1, 413 | ) 414 | attention_mask = torch.cat( # (batch_size, text_len + mel_len + 1) 415 | [text_attention_mask, extended_audio_attention_mask], 416 | dim=1, 417 | ) 418 | text_mask = torch.cat( # (batch_size, text_len + mel_len + 1) 419 | [ 420 | torch.ones_like(text_attention_mask, dtype=bool), 421 | torch.zeros_like(extended_audio_attention_mask, dtype=bool), 422 | ], 423 | dim=1, 424 | ) 425 | 426 | # set labels 427 | labels = input_ids.detach().clone() # (batch_size, text_len + mel_len + 1, num_vq) 428 | 
labels[~attention_mask.bool()] = IGNORE_TOKEN_ID 429 | # (batch_size, text_len + mel_len, 768) 430 | inputs_embeds = gpt_model.forward(input_ids=input_ids, text_mask=text_mask) 431 | text_len = text_input_ids.size(1) 432 | if not cfg.use_empty_speaker: 433 | for i, speaker in enumerate(batch['speaker']): 434 | spk_emb = speaker_embeds[speaker].mul(speaker_std.detach()).add(speaker_mean.detach()) 435 | spk_emb = F.normalize(spk_emb, p=2.0, dim=0, eps=1e-12).unsqueeze_(0) 436 | cond = text_input_ids[i].narrow(-1, 0, 1).eq(tokenizer.spk_emb_ids) 437 | inputs_embeds[i, :text_len] = torch.where(cond, spk_emb.to(inputs_embeds.dtype), 438 | inputs_embeds[i, :text_len]) 439 | outputs = gpt_model.gpt.forward(inputs_embeds=inputs_embeds.to(dtype=weight_dtype), 440 | attention_mask=attention_mask.to(dtype=weight_dtype)) 441 | hidden_states = outputs.last_hidden_state 442 | 443 | audio_hidden_states = hidden_states[ 444 | :, text_len - 1: -1 445 | ] # (batch_size, mel_len+1, 768) 446 | audio_labels = labels[:, text_len:] 447 | audio_logits = torch.stack( 448 | [ 449 | gpt_model.head_code[i](audio_hidden_states) 450 | for i in range(gpt_model.num_vq) 451 | ], 452 | dim=2, 453 | ) # (batch_size, mel_len+1, num_vq, num_class_audio) 454 | 455 | audio_loss = torch.nn.functional.cross_entropy( 456 | audio_logits.flatten(0, 2), audio_labels.flatten(0, 2), ignore_index=IGNORE_TOKEN_ID 457 | ) 458 | loss = audio_loss 459 | 460 | with torch.no_grad(): 461 | # Get predictions 462 | predictions = audio_logits.flatten(0, 2).argmax(dim=-1) # (batch_size * mel_len * num_vq) 463 | labels_flat = audio_labels.flatten(0, 2) # (batch_size * mel_len * num_vq) 464 | # Create mask for valid tokens (not IGNORE_TOKEN_ID) 465 | valid_mask = (labels_flat != IGNORE_TOKEN_ID) 466 | 467 | # Calculate accuracy only on valid tokens 468 | correct = (predictions[valid_mask] == labels_flat[valid_mask]).float() 469 | accuracy = correct.mean() if valid_mask.any() else torch.tensor(0.0).to(correct.device) 470 | 471 | # Gather the accuracy across all processes 472 | avg_accuracy = accelerator.gather(accuracy.repeat(cfg.DATA.train_bs)).mean() 473 | 474 | # Gather the losses across all processes for logging (if we use distributed training). 
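# (A recap of the shapes involved above, plus a tiny hypothetical check that is
# not taken from a real batch: audio_logits is (batch_size, mel_len+1, num_vq,
# num_class_audio) and audio_labels is (batch_size, mel_len+1, num_vq), so
# flatten(0, 2) turns both into a flat token-classification problem, e.g.
#   logits = torch.randn(2, 5, 4, 626)   # batch=2, mel_len+1=5, num_vq=4, hypothetical vocab size
#   labels = torch.randint(0, 626, (2, 5, 4))
#   F.cross_entropy(logits.flatten(0, 2), labels.flatten(0, 2))  # scalar loss
# Positions padded with IGNORE_TOKEN_ID are skipped via ignore_index in the loss
# and masked out of the accuracy, so both metrics cover real audio tokens only;
# the gather below then averages these per-process scalars for logging.)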
475 | avg_loss = accelerator.gather(loss.repeat(cfg.DATA.train_bs)).mean() 476 | train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps 477 | train_accuracy = avg_accuracy.item() 478 | 479 | # Backpropagate 480 | accelerator.backward(loss) 481 | # if accelerator.sync_gradients: 482 | # accelerator.clip_grad_norm_( 483 | # trainable_params, 484 | # cfg.solver.max_grad_norm, 485 | # ) 486 | optimizer.step() 487 | lr_scheduler.step() 488 | optimizer.zero_grad() 489 | 490 | if accelerator.sync_gradients: 491 | progress_bar.update(1) 492 | global_step += 1 493 | accelerator.log({"train_loss": train_loss, "train_acc": train_accuracy}, step=global_step) 494 | if accelerator.is_main_process: 495 | tf_writer.add_scalar('train_loss', train_loss, global_step) 496 | tf_writer.add_scalar('train_acc', train_accuracy, global_step) 497 | tf_writer.add_scalar('train_audio_loss', audio_loss.detach().item(), global_step) 498 | train_loss = 0.0 499 | 500 | if global_step == 1 or global_step % cfg.checkpointing_steps == 0: 501 | if accelerator.is_main_process: 502 | step_checkpoint_dir = os.path.join(checkpoints_dir, f"step-{global_step}") 503 | os.makedirs(step_checkpoint_dir, exist_ok=True) 504 | if not cfg.use_empty_speaker: 505 | for spk_name in speaker_embeds: 506 | spk_emb = speaker_embeds[spk_name].detach().mul(speaker_std).add(speaker_mean) 507 | spk_emb = tokenizer._encode_spk_emb(spk_emb) 508 | output_path = os.path.join(step_checkpoint_dir, f"{spk_name}.pt") 509 | torch.save(spk_emb, output_path) 510 | 511 | speaker_embeds_w = {} 512 | for speaker in speaker_embeds: 513 | speaker_embeds_w[speaker] = speaker_embeds[speaker].detach().float().cpu().data.numpy() 514 | with open(os.path.join(step_checkpoint_dir, "speaker_embeds.pkl"), "wb") as fw: 515 | pickle.dump(speaker_embeds_w, fw) 516 | else: 517 | unwrap_net = accelerator.unwrap_model(gpt_model) 518 | unwrap_net.gpt.save_pretrained(step_checkpoint_dir) 519 | 520 | logs = { 521 | "loss": loss.detach().item(), 522 | "audio_loss": audio_loss.detach().item(), 523 | "step_acc": train_accuracy, 524 | "lr": lr_scheduler.get_last_lr()[0] 525 | } 526 | progress_bar.set_postfix(**logs) 527 | 528 | if global_step >= cfg.solver.max_train_steps: 529 | break 530 | 531 | accelerator.wait_for_everyone() 532 | accelerator.end_training() 533 | 534 | 535 | if __name__ == "__main__": 536 | import argparse 537 | 538 | parser = argparse.ArgumentParser() 539 | parser.add_argument("--config", type=str, default="./configs/train/train_voice_clone_lora.yaml") 540 | args = parser.parse_args() 541 | config = OmegaConf.load(args.config) 542 | main(config) 543 | -------------------------------------------------------------------------------- /update.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | git fetch origin 3 | git reset --hard origin/master 4 | "venv\Scripts\pip.exe" install -r requirements.txt 5 | pause -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | ".\venv\python.exe" webui.py --cfg .\configs\infer\chattts_plus_trt.yaml --server_port 7890 3 | pause -------------------------------------------------------------------------------- /webui.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pdb 3 | import uuid 4 | 5 | if sys.platform == "darwin": 6 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" 7 
| import random 8 | import argparse 9 | import gradio as gr 10 | from omegaconf import OmegaConf 11 | import numpy as np 12 | import math 13 | import torch 14 | import subprocess 15 | import shutil 16 | 17 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline 18 | from chattts_plus.commons import utils 19 | from chattts_plus.commons import constants 20 | 21 | # ChatTTSPlus pipeline 22 | pipe: ChatTTSPlusPipeline = None 23 | 24 | js_func = """ 25 | function refresh() { 26 | const url = new URL(window.location); 27 | 28 | if (url.searchParams.get('__theme') !== 'dark') { 29 | url.searchParams.set('__theme', 'dark'); 30 | window.location.href = url.href; 31 | } 32 | } 33 | """ 34 | 35 | seed_min = 1 36 | seed_max = 4294967295 37 | 38 | 39 | def generate_seed(): 40 | return gr.update(value=random.randint(seed_min, seed_max)) 41 | 42 | 43 | def update_spk_emb_path(file): 44 | spk_emb_path = file.name 45 | return spk_emb_path 46 | 47 | 48 | def update_spk_lora_path(files): 49 | spk_lora_path = os.path.join(constants.CHECKPOINT_DIR, "lora", str(uuid.uuid4())) 50 | os.makedirs(spk_lora_path, exist_ok=True) 51 | is_valid = 0 52 | for file in files: 53 | if file.name.endswith(".json"): 54 | shutil.copy(file.name, os.path.join(spk_lora_path, "adapter_config.json")) 55 | is_valid += 1 56 | elif file.name.endswith(".safetensors"): 57 | shutil.copy(file.name, os.path.join(spk_lora_path, "adapter_model.safetensors")) 58 | is_valid += 1 59 | if is_valid == 2: 60 | return spk_lora_path 61 | else: 62 | return '' 63 | 64 | 65 | def list_pt_files_in_dir(directory): 66 | if os.path.isdir(directory): 67 | pt_files = [f for f in os.listdir(directory) if f.endswith('.pt')] 68 | return gr.Dropdown(label="Select Speaker Embedding", choices=pt_files) if pt_files else gr.Dropdown( 69 | label="Select Speaker Embedding", choices=[]) 70 | return gr.Dropdown(label="Select Speaker Embedding", choices=[]) 71 | 72 | 73 | def set_spk_emb_path_from_dir(directory, selected_file): 74 | spk_emb_path = os.path.join(directory, selected_file) 75 | return spk_emb_path 76 | 77 | 78 | def float_to_int16(audio: np.ndarray) -> np.ndarray: 79 | am = int(math.ceil(float(np.abs(audio).max())) * 32768) 80 | am = 32767 * 32768 // am 81 | return np.multiply(audio, am).astype(np.int16) 82 | 83 | 84 | def refine_text( 85 | text, 86 | prompt, 87 | temperature, 88 | top_P, 89 | top_K, 90 | text_seed_input, 91 | refine_text_flag, 92 | ): 93 | global pipe 94 | 95 | if not refine_text_flag: 96 | return text 97 | 98 | with utils.TorchSeedContext(text_seed_input): 99 | params_refine_text = utils.RefineTextParams( 100 | prompt=prompt, 101 | top_P=top_P, 102 | top_K=top_K, 103 | temperature=temperature, 104 | max_new_token=384 105 | ) 106 | text_gen = pipe.infer( 107 | text, 108 | skip_refine_text=False, 109 | refine_text_only=True, 110 | do_text_normalization=True, 111 | do_homophone_replacement=True, 112 | do_text_optimization=True, 113 | params_refine_text=params_refine_text 114 | ) 115 | texts = [] 116 | for text_ in text_gen: 117 | texts.extend(text_) 118 | 119 | return "\n".join(texts) 120 | 121 | 122 | def generate_audio( 123 | text, 124 | prompt, 125 | temperature, 126 | top_P, 127 | top_K, 128 | spk_emb_path, 129 | stream, 130 | audio_seed_input, 131 | sample_text_input, 132 | sample_audio_input, 133 | spk_lora_path 134 | ): 135 | global pipe 136 | 137 | if not text: 138 | return None 139 | 140 | params_infer_code = utils.InferCodeParams( 141 | prompt=prompt, 142 | temperature=temperature, 143 | top_P=top_P, 144 | 
top_K=top_K, 145 | max_new_token=2048, 146 | ) 147 | 148 | with utils.TorchSeedContext(audio_seed_input): 149 | wav_gen = pipe.infer( 150 | text, 151 | skip_refine_text=True, 152 | do_text_normalization=False, 153 | do_homophone_replacement=False, 154 | do_text_optimization=False, 155 | params_infer_code=params_infer_code, 156 | stream=stream, 157 | speaker_audio_path=sample_audio_input, 158 | speaker_audio_text=sample_text_input, 159 | speaker_emb_path=spk_emb_path, 160 | lora_path=spk_lora_path 161 | ) 162 | if stream: 163 | for gen in wav_gen: 164 | audio = gen[0].cpu().float().numpy() 165 | if audio is not None and len(audio) > 0: 166 | yield 24000, float_to_int16(audio).T 167 | else: 168 | wavs = [] 169 | for wavs_ in wav_gen: 170 | wavs.extend(wavs_) 171 | wavs = torch.cat(wavs).cpu().float().numpy() 172 | yield 24000, float_to_int16(wavs).T 173 | 174 | 175 | def update_active_tab(tab_name): 176 | return gr.State(tab_name) 177 | 178 | 179 | # Reset helper: clears the Speaker Embedding Path 180 | def reset_spk_emb_path(): 181 | return "", [] 182 | 183 | 184 | # Reset helper: clears Sample Audio and Sample Text 185 | def reset_sample_inputs(): 186 | return None, "" # return None to clear the audio, an empty string to clear the textbox 187 | 188 | 189 | def reset_lora_inputs(): 190 | return None, "" # return None to clear the audio, an empty string to clear the textbox 191 | 192 | 193 | def main(args): 194 | with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")]), js=js_func) as demo: 195 | gr.Markdown("# ChatTTSPlus WebUI") 196 | with gr.Row(): 197 | with gr.Column(): 198 | text_input = gr.Textbox( 199 | label="Input Text", 200 | lines=4, 201 | max_lines=4, 202 | placeholder="Please Input Text...", 203 | interactive=True, 204 | ) 205 | activate_tag_name = gr.State(value="Speaker Embedding") 206 | 207 | with gr.Tabs() as tabs: 208 | with gr.Tab("Speaker Embedding"): 209 | with gr.Column(): 210 | with gr.Row(equal_height=True): 211 | with gr.Column(): 212 | spk_emb_dir = gr.Textbox(label="Input Speaker Embedding Directory", 213 | placeholder="Please input speaker embedding directory", 214 | value=os.path.abspath( 215 | os.path.join(constants.PROJECT_DIR, "assets/speakers"))) 216 | reload_chat_button = gr.Button("Reload", scale=1) 217 | pt_files_dropdown = gr.Dropdown(label="Select Speaker Embedding") 218 | 219 | upload_emb_file = gr.File(label="Upload Speaker Embedding File (.pt)") 220 | 221 | spk_emb_path = gr.Textbox( 222 | label="Speaker Embedding Path", 223 | max_lines=3, 224 | show_copy_button=True, 225 | interactive=True, 226 | scale=2, 227 | ) 228 | spk_emb_reset = gr.Button("Reset", scale=1) 229 | 230 | upload_emb_file.upload(update_spk_emb_path, inputs=upload_emb_file, outputs=spk_emb_path) 231 | reload_chat_button.click( 232 | list_pt_files_in_dir, inputs=spk_emb_dir, outputs=pt_files_dropdown 233 | ) 234 | pt_files_dropdown.select( 235 | set_spk_emb_path_from_dir, inputs=[spk_emb_dir, pt_files_dropdown], outputs=spk_emb_path 236 | ) 237 | 238 | # Clicking Reset clears the Speaker Embedding Path 239 | spk_emb_reset.click( 240 | reset_spk_emb_path, inputs=None, outputs=[spk_emb_path, pt_files_dropdown] 241 | ) 242 | 243 | with gr.Tab("Speaker Audio (ZeroShot)"): 244 | with gr.Column(): 245 | with gr.Row(equal_height=True): 246 | sample_audio_input = gr.Audio( 247 | value=None, 248 | type="filepath", 249 | interactive=True, 250 | show_label=False, 251 | waveform_options=gr.WaveformOptions( 252 | sample_rate=24000, 253 | ), 254 | ) 255 | sample_text_input = gr.Textbox( 256 | label="Sample Text (ZeroShot)", 257 | lines=4, 258 | max_lines=4, 259 | placeholder="If Sample Audio and 
Sample Text are available, the Speaker Embedding will be disabled.", 260 | interactive=True, 261 | ) 262 | sample_reset = gr.Button("Reset", scale=1) 263 | # Clicking Reset clears Sample Audio and Sample Text 264 | sample_reset.click( 265 | reset_sample_inputs, inputs=None, outputs=[sample_audio_input, sample_text_input] 266 | ) 267 | 268 | with gr.Tab("Speaker Lora"): 269 | with gr.Column(): 270 | lora_files = gr.Files(label="Lora files: config.json and safetensors") 271 | spk_lora_path = gr.Textbox( 272 | label="Speaker Lora Path", 273 | max_lines=3, 274 | show_copy_button=True, 275 | interactive=True, 276 | scale=2, 277 | ) 278 | lora_files.upload(update_spk_lora_path, inputs=lora_files, outputs=spk_lora_path) 279 | lora_reset = gr.Button("Reset", scale=1) 280 | lora_reset.click( 281 | reset_lora_inputs, inputs=None, outputs=[lora_files, spk_lora_path] 282 | ) 283 | 284 | with gr.Row(equal_height=True): 285 | refine_text_checkbox = gr.Checkbox( 286 | label="Refine text", interactive=True, value=True 287 | ) 288 | text_prompt = gr.Text( 289 | interactive=True, 290 | value="[oral_2][laugh_0][break_4]", 291 | label="text_prompt" 292 | ) 293 | text_temperature_slider = gr.Number( 294 | minimum=0.00001, 295 | maximum=1.0, 296 | value=0.3, 297 | step=0.05, 298 | label="Text Temperature", 299 | interactive=True, 300 | ) 301 | text_top_p_slider = gr.Number( 302 | minimum=0.1, 303 | maximum=0.9, 304 | value=0.7, 305 | step=0.05, 306 | label="text_top_P", 307 | interactive=True, 308 | ) 309 | text_top_k_slider = gr.Number( 310 | minimum=1, 311 | maximum=30, 312 | value=20, 313 | step=1, 314 | label="text_top_K", 315 | interactive=True, 316 | ) 317 | text_seed_input = gr.Number( 318 | label="Text Seed", 319 | interactive=True, 320 | value=1, 321 | minimum=seed_min, 322 | maximum=seed_max, 323 | ) 324 | generate_text_seed = gr.Button("\U0001F3B2", interactive=True) 325 | 326 | with gr.Row(equal_height=True): 327 | audio_prompt = gr.Text( 328 | interactive=True, 329 | value="[speed_4]", 330 | label="audio_prompt" 331 | ) 332 | audio_temperature_slider = gr.Number( 333 | minimum=0.00001, 334 | maximum=1.0, 335 | step=0.0001, 336 | value=0.0003, 337 | label="Audio Temperature", 338 | interactive=True, 339 | ) 340 | audio_top_p_slider = gr.Number( 341 | minimum=0.1, 342 | maximum=0.9, 343 | value=0.7, 344 | step=0.05, 345 | label="audio_top_P", 346 | interactive=True, 347 | ) 348 | audio_top_k_slider = gr.Number( 349 | minimum=1, 350 | maximum=100, 351 | value=20, 352 | step=1, 353 | label="audio_top_K", 354 | interactive=True, 355 | ) 356 | audio_seed_input = gr.Number( 357 | label="Audio Seed", 358 | interactive=True, 359 | value=1, 360 | minimum=seed_min, 361 | maximum=seed_max, 362 | ) 363 | generate_audio_seed = gr.Button("\U0001F3B2", interactive=True) 364 | 365 | with gr.Row(): 366 | auto_play_checkbox = gr.Checkbox( 367 | label="Auto Play", value=False, scale=1, interactive=True 368 | ) 369 | stream_mode_checkbox = gr.Checkbox( 370 | label="Stream Mode", 371 | value=False, 372 | scale=1, 373 | interactive=True, 374 | ) 375 | generate_button = gr.Button( 376 | "Generate", scale=2, variant="primary", interactive=True 377 | ) 378 | interrupt_button = gr.Button( 379 | "Interrupt", 380 | scale=2, 381 | variant="stop", 382 | visible=False, 383 | interactive=False, 384 | ) 385 | 386 | text_output = gr.Textbox( 387 | label="Output Text", 388 | interactive=False, 389 | show_copy_button=True, 390 | lines=4, 391 | ) 392 | 393 | generate_audio_seed.click(generate_seed, outputs=audio_seed_input) 394 | 
generate_text_seed.click(generate_seed, outputs=text_seed_input) 395 | 396 | @gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox]) 397 | def make_audio(autoplay, stream): 398 | audio_output = gr.Audio( 399 | label="Output Audio", 400 | value=None, 401 | format="wav", 402 | autoplay=autoplay, 403 | streaming=stream, 404 | interactive=False, 405 | show_label=True, 406 | waveform_options=gr.WaveformOptions( 407 | sample_rate=24000, 408 | ), 409 | show_download_button=True 410 | ) 411 | 412 | generate_button.click( 413 | fn=refine_text, 414 | inputs=[ 415 | text_input, 416 | text_prompt, 417 | text_temperature_slider, 418 | text_top_p_slider, 419 | text_top_k_slider, 420 | text_seed_input, 421 | refine_text_checkbox, 422 | ], 423 | outputs=text_output, 424 | ).then( 425 | generate_audio, 426 | inputs=[ 427 | text_output, 428 | audio_prompt, 429 | audio_temperature_slider, 430 | audio_top_p_slider, 431 | audio_top_k_slider, 432 | spk_emb_path, 433 | stream_mode_checkbox, 434 | audio_seed_input, 435 | sample_text_input, 436 | sample_audio_input, 437 | spk_lora_path 438 | ], 439 | outputs=audio_output, 440 | ) 441 | 442 | demo.launch( 443 | server_name=args.server_name, 444 | server_port=args.server_port, 445 | inbrowser=True, 446 | show_api=False, 447 | ) 448 | 449 | 450 | if __name__ == "__main__": 451 | parser = argparse.ArgumentParser(description="ChatTTS demo Launch") 452 | parser.add_argument( 453 | "--cfg", type=str, default="configs/infer/chattts_plus.yaml", help="config of chattts plus" 454 | ) 455 | parser.add_argument( 456 | "--server_name", type=str, default="0.0.0.0", help="server name" 457 | ) 458 | parser.add_argument("--server_port", type=int, default=7890, help="server port") 459 | args = parser.parse_args() 460 | 461 | infer_cfg = OmegaConf.load(args.cfg) 462 | pipe = ChatTTSPlusPipeline(infer_cfg, device=utils.get_inference_device()) 463 | main(args) 464 | --------------------------------------------------------------------------------
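For completeness, below is a minimal headless sketch of driving the pipeline the same way webui.py does, without the Gradio front end. It only reuses calls that appear above (OmegaConf config loading, ChatTTSPlusPipeline, utils.InferCodeParams, pipe.infer, and the float-to-int16 conversion); the scipy.io.wavfile dependency, the example text, and the choice of the bundled assets/speakers/2222.pt speaker file are illustrative assumptions rather than a documented entry point.

import math
import numpy as np
import torch
from omegaconf import OmegaConf
from scipy.io import wavfile  # assumption: any 16-bit PCM wav writer works here

from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
from chattts_plus.commons import utils

cfg = OmegaConf.load("configs/infer/chattts_plus.yaml")
pipe = ChatTTSPlusPipeline(cfg, device=utils.get_inference_device())

params_infer_code = utils.InferCodeParams(
    prompt="[speed_4]", temperature=0.0003, top_P=0.7, top_K=20, max_new_token=2048,
)
wav_gen = pipe.infer(
    "Hello from ChatTTSPlus.",          # example text, not from the repo
    skip_refine_text=True,
    do_text_normalization=False,
    do_homophone_replacement=False,
    do_text_optimization=False,
    params_infer_code=params_infer_code,
    stream=False,
    speaker_audio_path=None,
    speaker_audio_text="",
    speaker_emb_path="assets/speakers/2222.pt",  # speaker embedding shipped with the repo
    lora_path="",
)

# Collect the generated chunks exactly as the non-streaming branch of webui.py does.
wavs = []
for wavs_ in wav_gen:
    wavs.extend(wavs_)
audio = torch.cat(wavs).cpu().float().numpy()

# Same float -> int16 conversion used by webui.py before playback, written out at 24 kHz.
am = int(math.ceil(float(np.abs(audio).max())) * 32768)
am = 32767 * 32768 // am
wavfile.write("output.wav", 24000, np.multiply(audio, am).astype(np.int16))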