├── .gitignore
├── LICENSE
├── README.md
├── README_ZH.md
├── assets
│   ├── docs
│   │   ├── voice_clone.md
│   │   └── voice_clone_ZH.md
│   └── speakers
│       └── 2222.pt
├── chattts_plus
│   ├── __init__.py
│   ├── commons
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── logger.py
│   │   ├── norm.py
│   │   ├── onnx2trt.py
│   │   ├── text_utils.py
│   │   └── utils.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── base_dataset.py
│   │   └── collator.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── dvae.py
│   │   ├── gpt.py
│   │   ├── llama.py
│   │   ├── processors.py
│   │   └── tokenizer.py
│   ├── pipelines
│   │   ├── __init__.py
│   │   └── chattts_plus_pipeline.py
│   └── trt_models
│       ├── __init__.py
│       ├── base_model.py
│       ├── gpt_trt.py
│       ├── llama_trt_model.py
│       └── predictor.py
├── configs
│   ├── accelerate
│   │   └── deepspeed_config.yaml
│   ├── infer
│   │   ├── chattts_plus.yaml
│   │   └── chattts_plus_trt.yaml
│   └── train
│       ├── train_speaker_embedding.yaml
│       └── train_voice_clone_lora.yaml
├── demos
│   └── notebooklm-podcast
│       ├── extract_files_to_texts.py
│       ├── llm_api.py
│       ├── requirements.txt
│       ├── speaker_pt
│       │   ├── en_man_5200.pt
│       │   ├── en_man_8200.pt
│       │   ├── en_man_9400.pt
│       │   ├── en_man_9500.pt
│       │   ├── en_woman_1200.pt
│       │   ├── en_woman_4600.pt
│       │   ├── en_woman_5600.pt
│       │   ├── zh_man_1888.pt
│       │   ├── zh_man_2155.pt
│       │   ├── zh_man_54.pt
│       │   ├── zh_woman_1528.pt
│       │   ├── zh_woman_492.pt
│       │   └── zh_woman_621.pt
│       ├── tts.py
│       └── utils.py
├── requirements.txt
├── scripts
│   └── conversions
│       └── convert_train_list.py
├── setup.py
├── tests
│   ├── test_models.py
│   └── test_pipelines.py
├── train_lora.py
├── update.bat
├── webui.bat
└── webui.py
/.gitignore:
--------------------------------------------------------------------------------
1 | results
2 | checkpoints
3 | .idea
4 | data
5 | *.egg-info
6 | .DS_store
7 | logs
8 | */__pycache__/
9 | scripts/conversions
10 | *.pyc
11 | venv
12 | dist
13 | build
14 | outputs
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## ChatTTSPlus: Extension of ChatTTS
2 |
3 | 中文 | English
4 |
5 | ChatTTSPlus is an extension of [ChatTTS](https://github.com/2noise/ChatTTS), adding features such as TensorRT acceleration, voice cloning, and mobile model deployment.
6 |
7 | **If you find this project useful, please give it a star! ✨✨**
8 |
9 | ### Some fun demos based on ChatTTSPlus
10 | * NotebookLM podcast: [Open in Colab](https://colab.research.google.com/drive/1jz8HPdRe_igoNjMSv0RaTn3l2c3seYFT?usp=sharing).
11 | Use ChatTTSPlus to turn the `AnimateAnyone` paper into a podcast.
12 |
13 |
14 |
15 | ### New Features
16 | - [x] Refactored ChatTTS code in a way I'm familiar with.
17 | - [x] **Achieved over 3x acceleration with TensorRT**, increasing performance on a Windows 3060 GPU from 28 tokens/s to 110 tokens/s.
18 | - [x] Windows integration package for one-click extraction and use.
19 | - [x] Implemented voice cloning using techniques such as LoRA. Please refer to [voice_clone](assets/docs/voice_clone.md).
20 | - [ ] Model compression and acceleration using techniques like pruning and knowledge distillation, targeting mobile deployment.
21 |
22 | ### Environment Setup
23 | * Install Python 3; it's recommended to use [Miniforge](https://github.com/conda-forge/miniforge). Run: `conda create -n chattts_plus python=3.10 && conda activate chattts_plus`
24 | * Download the source code: `git clone https://github.com/warmshao/ChatTTSPlus`, and navigate to the project root directory: `cd ChatTTSPlus`
25 | * Install necessary Python libraries: `pip install -r requirements.txt`
26 | * [Optional] If you want to use TensorRT, please install [tensorrt10](https://developer.nvidia.com/tensorrt/download)
27 | * [Recommended for Windows users] Download the integration package directly from [Google Drive Link](https://drive.google.com/file/d/1yOnU5dRTJvFnc4wyw02nAeJH5_FgNod2/view?usp=sharing), extract it, and double-click `webui.bat` to use. If you want to update the code, please double-click `update.bat`. Note: **This will overwrite all your local code modifications.**
28 |
29 | ### Demo
30 | * Web UI with TensorRT: `python webui.py --cfg configs/infer/chattts_plus_trt.yaml`.
31 | * Web UI with PyTorch: `python webui.py --cfg configs/infer/chattts_plus.yaml`
32 |
33 |
34 |
35 | ### License
36 | ChatTTSPlus inherits the license from [ChatTTS](https://github.com/2noise/ChatTTS); please refer to [ChatTTS](https://github.com/2noise/ChatTTS) as the standard.
37 |
38 | The code is published under the AGPLv3+ license.
39 |
40 | The model is published under the CC BY-NC 4.0 license. It is intended for educational and research use and should not be used for any commercial or illegal purposes. The authors do not guarantee the accuracy, completeness, or reliability of the information. The information and data used in this repository are for academic and research purposes only. The data is obtained from publicly available sources, and the authors do not claim any ownership or copyright over the data.
41 |
42 | ### Disclaimer
43 | We do not hold any responsibility for any illegal usage of the codebase. Please refer to your local laws about DMCA and other related laws.
--------------------------------------------------------------------------------
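A quick, optional way to verify the environment from the README's "Environment Setup" section is the short check below. It is only an illustrative snippet (not part of the repository): it confirms that PyTorch sees a GPU and that a TensorRT 10 build is importable.

```python
# Optional environment sanity check for ChatTTSPlus (illustrative, not part of the repo).
import torch

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())

try:
    import tensorrt as trt
    print("TensorRT:", trt.__version__)  # expect a 10.x version for the TensorRT pipeline
except ImportError:
    print("TensorRT not installed; only the PyTorch pipeline (configs/infer/chattts_plus.yaml) will work")
```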
/README_ZH.md:
--------------------------------------------------------------------------------
1 | ## ChatTTSPlus: Extension of ChatTTS
2 |
3 | 中文 | English
4 |
5 | ChatTTSPlus是[ChatTTS](https://github.com/2noise/ChatTTS)的扩展,增加使用TensorRT加速、声音克隆和模型移动端运行等功能。
6 |
7 | **如果你觉得这个项目有用,帮我点个star吧✨✨**
8 |
9 | ### 基于ChatTTSPlus做的有趣的demo
10 | * NotebookLM播客: [Open in Colab](https://colab.research.google.com/drive/1jz8HPdRe_igoNjMSv0RaTn3l2c3seYFT?usp=sharing)
11 | ,使用ChatTTSPlus把 `AnimateAnyone` 这篇文章变成播客。
12 |
13 |
14 |
15 | ### 新增功能
16 | - [x] 将ChatTTS的代码以我熟悉的方式重构。
17 | - [x] **使用TensorRT实现3倍以上的加速**, 在windows的3060显卡上从28token/s提升到110token/s。
18 | - [x] windows整合包,一键解压使用。
19 | - [x] 使用Lora等技术实现声音克隆。请参考 [声音克隆](assets/docs/voice_clone_ZH.md)
20 | - [ ] 使用剪枝、知识蒸馏等做模型压缩和加速,目标在移动端运行。
21 |
22 | ### 环境安装
23 | * 安装python3,推荐可以用[miniforge](https://github.com/conda-forge/miniforge).`conda create -n chattts_plus python=3.10 && conda activate chattts_plus`
24 | * 下载源码: `git clone https://github.com/warmshao/ChatTTSPlus`, 并到项目根目录下: `cd ChatTTSPlus`
25 | * 安装必要的python库, `pip install -r requirements.txt`
26 | * 【可选】如果你要使用tensorrt的话,请安装[tensorrt10](https://developer.nvidia.com/tensorrt/download)
27 | * 【windows用户推荐】直接从[Google Drive链接](https://drive.google.com/file/d/1yOnU5dRTJvFnc4wyw02nAeJH5_FgNod2/view?usp=sharing)下载整合包,解压后双击`webui.bat`即可使用。如果要更新代码的话,请先双击`update.bat`, 注意:**这会覆盖你本地所有的代码修改**。
28 |
29 | ### Demo
30 | * Webui with TensorRT: `python webui.py --cfg configs/infer/chattts_plus_trt.yaml`.
31 | * Webui with Pytorch: `python webui.py --cfg configs/infer/chattts_plus.yaml`
32 |
33 |
34 |
35 | ### License
36 | ChatTTSPlus继承[ChatTTS](https://github.com/2noise/ChatTTS)的license,请以[ChatTTS](https://github.com/2noise/ChatTTS)为标准。
37 |
38 | The code is published under AGPLv3+ license.
39 |
40 | The model is published under CC BY-NC 4.0 license. It is intended for educational and research use, and should not be used for any commercial or illegal purposes. The authors do not guarantee the accuracy, completeness, or reliability of the information. The information and data used in this repo, are for academic and research purposes only. The data obtained from publicly available sources, and the authors do not claim any ownership or copyright over the data.
41 |
42 | ### 免责声明
43 | 我们不对代码库的任何非法使用承担任何责任. 请参阅您当地关于 DMCA (数字千年法案) 和其他相关法律法规.
--------------------------------------------------------------------------------
/assets/docs/voice_clone.md:
--------------------------------------------------------------------------------
1 | ## ChatTTSPlus Voice Cloning
2 | Currently supports two methods of voice cloning:
3 | * Training with lora
4 | * Training speaker embedding
5 |
6 | Note:
7 | * The voice cloning feature of ChatTTSPlus is for learning purposes only. Please do not use it for illegal or criminal activities. We take no responsibility for any illegal use of the codebase.
8 | * Some code references: [ChatTTS PR](https://github.com/2noise/ChatTTS/pull/680)
9 |
10 | ### Data Collection and Preprocessing
11 | * Prepare over 30 minutes of audio from the person you want to clone.
12 | * Process it with the GPT-SoVITS preprocessing workflow, executing in sequence: audio splitting, UVR5 background separation, noise reduction, speech-to-text, etc.
13 | * Finally get a `.list` file in this format: speaker_name | audio_path | lang | text, like this:
14 | ```text
15 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000000000_0000152640.wav|EN|Hehe, I watched Parasite recently, really recommend it.
16 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000152640_0000323520.wav|EN|The plot is tight and the filming technique is very unique.
17 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000323520_0000474880.wav|EN|It won many awards too, what type of movies do you like?
18 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_2.WAV_10.wav_0000000000_0000114560.wav|EN|I like mystery films, any other recommendations?
19 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_3.WAV_10.wav_0000000000_0000133760.wav|EN|Woof woof woof, then you must watch And Then There Were None.
20 | ```
21 |
22 | ### Model Training
23 | #### LoRA Training (Recommended)
24 | * Modify `DATA/meta_infos` in `configs/train/train_voice_clone_lora.yaml` to point to the `.list` file produced in the previous step.
25 | * Modify `exp_name` in `configs/train/train_voice_clone_lora.yaml` to the experiment name, preferably using speaker_name for identification.
26 | * Then run `accelerate launch train_lora.py --config configs/train/train_voice_clone_lora.yaml` to start training.
27 | * Trained models will be saved in the `outputs` folder, like `outputs/xionger_lora-1732809910.2932503/checkpoints/step-900`
28 | * You can visualize training logs using tensorboard, e.g., `tensorboard --logdir=outputs/xionger_lora-1732809910.2932503/tf_logs`
29 |
30 | #### Speaker Embedding Training (Not Recommended, Hard to Converge)
31 | * Modify `DATA/meta_infos` in `configs/train/train_speaker_embedding.yaml` to point to the `.list` file produced in the previous step.
32 | * Modify `exp_name` in `configs/train/train_speaker_embedding.yaml` to the experiment name, preferably using speaker_name for identification.
33 | * Then run `accelerate launch train_lora.py --config configs/train/train_speaker_embedding.yaml` to start training.
34 | * Trained speaker embeddings will be saved in the `outputs` folder, like `outputs/xionger_speaker_emb-1732931630.7137222/checkpoints/step-1/xionger.pt`
35 | * You can visualize training logs using tensorboard, e.g., `tensorboard --logdir=outputs/xionger_speaker_emb-1732931630.7137222/tf_logs`
36 |
37 | #### Some Tips
38 | * For better results, it's best to prepare more than 1 hour of audio. I tried training LoRA with only 1 minute of audio, but it was prone to overfitting and the results were mediocre.
39 | * Don't train for too long, otherwise it can easily overfit. When I trained with 1 hour of Lei Jun's audio, it converged between 2000 to 3000 steps.
40 | * If you understand lora training, you can try adjusting the parameters in the config file.
41 |
42 | ### Model Inference
43 | * Launch webui: `python webui.py --cfg configs/infer/chattts_plus.yaml`
44 | * Refer to the following video tutorial for usage:
45 |
46 |
--------------------------------------------------------------------------------
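Before launching training with the `.list` file described in assets/docs/voice_clone.md, it can help to validate it the same way `chattts_plus/datasets/base_dataset.py` will read it. The sketch below is a hypothetical helper (not part of the repository) that checks the `speaker|audio_path|lang|text` format and reports missing audio files:

```python
# Hypothetical validator for a GPT-SoVITS-style .list metadata file (illustrative only).
import os


def check_list_file(list_path: str) -> None:
    n_ok, n_bad = 0, 0
    with open(list_path, "r", encoding="utf-8") as fin:
        for line_num, line in enumerate(fin, 1):
            line = line.strip()
            if not line:
                continue
            fields = line.split("|")
            if len(fields) < 4:
                print(f"line {line_num}: expected speaker|audio_path|lang|text, got: {line}")
                n_bad += 1
            elif not os.path.exists(fields[1]):
                print(f"line {line_num}: audio file not found: {fields[1]}")
                n_bad += 1
            else:
                n_ok += 1
    print(f"{n_ok} valid entries, {n_bad} problems")


check_list_file("data/xionger/train.list")  # example path
```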
/assets/docs/voice_clone_ZH.md:
--------------------------------------------------------------------------------
1 | ## ChatTTSPlus 声音克隆
2 | 目前支持两种声音克隆的方式:
3 | * 训练lora
4 | * 训练speaker embedding
5 |
6 | 注意:
7 | * ChatTTSPlus 的声音克隆功能仅供学习使用,请勿用于非法或犯罪活动。我们对代码库的任何非法使用不承担责任。
8 | * 部分代码参考: [ChatTTS PR](https://github.com/2noise/ChatTTS/pull/680)
9 |
10 |
11 | ### 数据收集和预处理
12 | * 准备想要克隆的某个人的30分钟以上的音频。
13 | * 使用GPT-SoVITS的预处理流程处理,依次执行:音频切分、UVR5背景声分离、语音降噪、语音识别文本等。
14 | * 最后得到这样的`.list`文件: speaker_name | audio_path | lang | text,类似这样:
15 | ```text
16 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000000000_0000152640.wav|ZH|嘿嘿,最近我看了寄生虫,真的很推荐哦。
17 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000152640_0000323520.wav|ZH|这部电影剧情紧凑,拍摄手法也很独特。
18 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_1.WAV_10.wav_0000323520_0000474880.wav|ZH|还得了很多奖项,你有喜欢的电影类型吗?
19 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_2.WAV_10.wav_0000000000_0000114560.wav|ZH|我喜欢悬疑片,有其他推荐吗?
20 | xionger|E:\my_projects\ChatTTSPlus\data\xionger\slicer_opt\vocal_3.WAV_10.wav_0000000000_0000133760.wav|ZH|汪汪汪,那你一定要看无人生还。
21 | ```
22 |
23 | ### 模型训练:
24 | #### lora训练(推荐):
25 | * 修改`configs/train/train_voice_clone_lora.yaml`里 `DATA/meta_infos`为上一步处理得到的`.list`文件
26 | * 修改`configs/train/train_voice_clone_lora.yaml`里 `exp_name`为实验的名称,最好用speaker_name做区分识别。
27 | * 然后运行`accelerate launch train_lora.py --config configs/train/train_voice_clone_lora.yaml`开始训练。
28 | * 训练的模型会保存在: `outputs` 文件夹下,比如`outputs/xionger_lora-1732809910.2932503/checkpoints/step-900`
29 | * 你可以使用tensorboard可视化训练的log, 比如`tensorboard --logdir=outputs/xionger_lora-1732809910.2932503/tf_logs`
30 |
31 | #### speaker embedding训练(不推荐,很难收敛):
32 | * 修改`configs/train/train_speaker_embedding.yaml`里 `DATA/meta_infos`为上一步处理得到的`.list`文件
33 | * 修改`configs/train/train_speaker_embedding.yaml`里 `exp_name`为实验的名称,最好用speaker_name做区分识别。
34 | * 然后运行`accelerate launch train_lora.py --config configs/train/train_speaker_embedding.yaml`开始训练。
35 | * 训练的speaker embedding 会保存在: `outputs` 文件夹下,比如`outputs/xionger_speaker_emb-1732931630.7137222/checkpoints/step-1/xionger.pt`
36 | * 你可以使用tensorboard可视化训练的log, 比如`tensorboard --logdir=outputs/xionger_speaker_emb-1732931630.7137222/tf_logs`
37 |
38 | #### 一些Tips
39 | * 如果要效果好的话,最好准备1小时以上的音频。我试过1分钟训练lora,但是比较容易过拟合,效果一般。
40 | * 不要训练太久,不然容易过拟合,我用1小时的雷军音频训练,2000到3000 step就收敛了。
41 | * 如果你懂lora训练的话,你可以尝试调整config里面的参数。
42 |
43 | ### 模型推理
44 | * 启动webui: ` python webui.py --cfg configs/infer/chattts_plus.yaml`
45 | * 参考以下视频教程使用:
46 |
47 |
48 |
--------------------------------------------------------------------------------
/assets/speakers/2222.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/assets/speakers/2222.pt
--------------------------------------------------------------------------------
/chattts_plus/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/27 22:47
3 | # @Project : ChatTTSPlus
4 | # @FileName: __init__.py.py
5 |
--------------------------------------------------------------------------------
/chattts_plus/commons/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/9/22 11:48
3 | # @Project : ChatTTSPlus
4 | # @FileName: __init__.py.py
5 |
--------------------------------------------------------------------------------
/chattts_plus/commons/constants.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/10/8 10:37
3 | # @Author : wenshao
4 | # @ProjectName: ChatTTSPlus
5 | # @FileName: constants.py
6 |
7 | import os
8 |
9 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
10 | PROJECT_DIR = os.environ.get("CHATTTS_PLUS_PROJECT_DIR", os.path.join(CURRENT_DIR, '..', '..'))
11 | PROJECT_DIR = os.path.abspath(PROJECT_DIR)
12 | CHECKPOINT_DIR = os.environ.get("CHATTTS_PLUS_CHECKPOINT_DIR",
13 | os.path.abspath(os.path.join(CURRENT_DIR, '..', 'checkpoints')))
14 | CHECKPOINT_DIR = os.path.abspath(CHECKPOINT_DIR)
15 | LOG_DIR = os.environ.get("CHATTTS_PLUS_LOG_DIR", os.path.join(PROJECT_DIR, 'logs'))
16 | LOG_DIR = os.path.abspath(LOG_DIR)
17 |
--------------------------------------------------------------------------------
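The constants above resolve the project, checkpoint, and log directories once at import time, and each can be overridden with an environment variable. A minimal usage sketch (the override paths are placeholders):

```python
# Override ChatTTSPlus directory constants before importing modules that read them.
import os

os.environ["CHATTTS_PLUS_CHECKPOINT_DIR"] = "/data/chattts_plus/checkpoints"  # placeholder path
os.environ["CHATTTS_PLUS_LOG_DIR"] = "/data/chattts_plus/logs"                # placeholder path

from chattts_plus.commons import constants

print(constants.PROJECT_DIR)     # repository root by default
print(constants.CHECKPOINT_DIR)  # now points at the overridden location
print(constants.LOG_DIR)
```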
/chattts_plus/commons/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import logging
4 | from datetime import datetime, timezone
5 | from pathlib import Path
6 | from .constants import LOG_DIR
7 |
8 | logging.getLogger("numba").setLevel(logging.WARNING)
9 | logging.getLogger("httpx").setLevel(logging.WARNING)
10 | logging.getLogger("wetext-zh_normalizer").setLevel(logging.WARNING)
11 | logging.getLogger("NeMo-text-processing").setLevel(logging.WARNING)
12 |
13 | # from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96
14 | colorCodePanic = "\x1b[1;31m"
15 | colorCodeFatal = "\x1b[1;31m"
16 | colorCodeError = "\x1b[31m"
17 | colorCodeWarn = "\x1b[33m"
18 | colorCodeInfo = "\x1b[37m"
19 | colorCodeDebug = "\x1b[32m"
20 | colorCodeTrace = "\x1b[36m"
21 | colorReset = "\x1b[0m"
22 |
23 | log_level_color_code = {
24 | logging.DEBUG: colorCodeDebug,
25 | logging.INFO: colorCodeInfo,
26 | logging.WARN: colorCodeWarn,
27 | logging.ERROR: colorCodeError,
28 | logging.FATAL: colorCodeFatal,
29 | }
30 |
31 | log_level_msg_str = {
32 | logging.DEBUG: "DEBU",
33 | logging.INFO: "INFO",
34 | logging.WARN: "WARN",
35 | logging.ERROR: "ERRO",
36 | logging.FATAL: "FATL",
37 | }
38 |
39 |
40 | class Formatter(logging.Formatter):
41 | def __init__(self, color=platform.system().lower() != "windows"):
42 | # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone
43 | self.tz = datetime.now(timezone.utc).astimezone().tzinfo
44 | self.color = color
45 |
46 | def format(self, record: logging.LogRecord):
47 | logstr = "[" + datetime.now(self.tz).strftime("%Y%m%d %H:%M:%S") + "] ["
48 | if self.color:
49 | logstr += log_level_color_code.get(record.levelno, colorCodeInfo)
50 | logstr += log_level_msg_str.get(record.levelno, record.levelname)
51 | if self.color:
52 | logstr += colorReset
53 | fn = record.filename.removesuffix(".py")
54 | logstr += f"] {str(record.name)} | {fn} | {str(record.msg) % record.args}"
55 | return logstr
56 |
57 |
58 | def get_logger(name: str, lv=logging.INFO, remove_exist=False, format_root=False, log_file=None):
59 | """
60 | Configure and return a logger with specified settings.
61 |
62 | Args:
63 | name (str): The name of the logger.
64 | lv (int): The logging level (default: logging.INFO).
65 | remove_exist (bool): Whether to remove existing handlers (default: False).
66 | format_root (bool): Whether to format the root logger as well (default: False).
67 | log_file (str | Path | None): Path to the log file. If provided, logs will also be written to this file.
68 |
69 | Returns:
70 | logging.Logger: Configured logger instance.
71 | """
72 | logger = logging.getLogger(name)
73 | logger.setLevel(lv)
74 |
75 | if remove_exist and logger.hasHandlers():
76 | logger.handlers.clear()
77 |
78 | formatter = Formatter()
79 |
80 | # Add console handler if no handlers exist
81 | if not logger.hasHandlers():
82 | console_handler = logging.StreamHandler()
83 | console_handler.setFormatter(formatter)
84 | logger.addHandler(console_handler)
85 | else:
86 | # Update formatter for existing handlers
87 | for h in logger.handlers:
88 | h.setFormatter(formatter)
89 |
90 | # Add file handler if log_file is specified
91 | if log_file is None:
92 | os.makedirs(LOG_DIR, exist_ok=True)
93 | date_str = datetime.now().strftime("%y%m%d")
94 | log_file = os.path.join(LOG_DIR, f"chattts_plus_{date_str}.log")
95 |
96 | log_file = Path(log_file)
97 | # Create directory if it doesn't exist
98 | log_file.parent.mkdir(parents=True, exist_ok=True)
99 |
100 | file_handler = logging.FileHandler(str(log_file), encoding='utf-8')
101 | file_handler.setFormatter(formatter)
102 | logger.addHandler(file_handler)
103 |
104 | if format_root:
105 | for h in logger.root.handlers:
106 | h.setFormatter(formatter)
107 |
108 | return logger
109 |
--------------------------------------------------------------------------------
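A short usage sketch for `get_logger` above. With no `log_file` argument it logs to the console and to a daily `chattts_plus_YYMMDD.log` under `LOG_DIR`; the logger names and the explicit file path here are arbitrary examples:

```python
import logging

from chattts_plus.commons.logger import get_logger

# Console output plus the default daily log file under LOG_DIR (see constants.py).
logger = get_logger("webui")
logger.info("loaded %d speaker embeddings", 3)  # printf-style args are handled by the custom Formatter

# Explicit level and log file.
train_logger = get_logger("trainer", lv=logging.DEBUG, log_file="logs/train_debug.log")
train_logger.debug("step %d, loss %.4f", 100, 0.1234)
```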
/chattts_plus/commons/norm.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import re
4 | from typing import Dict, Tuple, List, Literal, Callable, Optional
5 | import sys
6 |
7 | from numba import jit
8 | import numpy as np
9 | from . import logger as logger_
10 |
11 |
12 | @jit
13 | def _find_index(table: np.ndarray, val: np.uint16):
14 | for i in range(table.size):
15 | if table[i] == val:
16 | return i
17 | return -1
18 |
19 |
20 | @jit
21 | def _fast_replace(
22 | table: np.ndarray, text: bytes
23 | ) -> Tuple[np.ndarray, List[Tuple[str, str]]]:
24 | result = np.frombuffer(text, dtype=np.uint16).copy()
25 | replaced_words = []
26 | for i in range(result.size):
27 | ch = result[i]
28 | p = _find_index(table[0], ch)
29 | if p >= 0:
30 | repl_char = table[1][p]
31 | result[i] = repl_char
32 | replaced_words.append((chr(ch), chr(repl_char)))
33 | return result, replaced_words
34 |
35 |
36 | class Normalizer:
37 | def __init__(self, map_file_path: str, logger=None):
38 | if logger is None:
39 | logger = logger_.get_logger(self.__class__.__name__)
40 | self.logger = logger
41 | self.normalizers: Dict[str, Callable[[str], str]] = {}
42 | self.homophones_map = self._load_homophones_map(map_file_path)
43 | """
44 | homophones_map
45 |
46 | Replace the mispronounced characters with correctly pronounced ones.
47 |
48 | Creation process of homophones_map.json:
49 |
50 | 1. Establish a word corpus using the [Tencent AI Lab Embedding Corpora v0.2.0 large] with 12 million entries. After cleaning, approximately 1.8 million entries remain. Use ChatTTS to infer the text.
51 | 2. Record discrepancies between the inferred and input text, identifying about 180,000 misread words.
52 | 3. Create a pinyin to common characters mapping using correctly read characters by ChatTTS.
53 | 4. For each discrepancy, extract the correct pinyin using [python-pinyin] and find homophones with the correct pronunciation from the mapping.
54 |
55 | Thanks to:
56 | [Tencent AI Lab Embedding Corpora for Chinese and English Words and Phrases](https://ai.tencent.com/ailab/nlp/en/embedding.html)
57 | [python-pinyin](https://github.com/mozillazg/python-pinyin)
58 |
59 | """
60 | self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be"
61 | self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]")
62 | self.sub_pattern = re.compile(r"\[uv_break\]|\[laugh\]|\[lbreak\]")
63 | self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]")
64 | self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b")
65 | self.character_simplifier = str.maketrans(
66 | {
67 | ":": ",",
68 | ";": ",",
69 | "!": "。",
70 | "(": ",",
71 | ")": ",",
72 | "【": ",",
73 | "】": ",",
74 | "『": ",",
75 | "』": ",",
76 | "「": ",",
77 | "」": ",",
78 | "《": ",",
79 | "》": ",",
80 | "-": ",",
81 | ":": ",",
82 | ";": ",",
83 | "!": ".",
84 | "(": ",",
85 | ")": ",",
86 | # "[": ",",
87 | # "]": ",",
88 | ">": ",",
89 | "<": ",",
90 | "-": ",",
91 | }
92 | )
93 | self.halfwidth_2_fullwidth = str.maketrans(
94 | {
95 | "!": "!",
96 | '"': "“",
97 | "'": "‘",
98 | "#": "#",
99 | "$": "$",
100 | "%": "%",
101 | "&": "&",
102 | "(": "(",
103 | ")": ")",
104 | ",": ",",
105 | "-": "-",
106 | "*": "*",
107 | "+": "+",
108 | ".": "。",
109 | "/": "/",
110 | ":": ":",
111 | ";": ";",
112 | "<": "<",
113 | "=": "=",
114 | ">": ">",
115 | "?": "?",
116 | "@": "@",
117 | # '[': '[',
118 | "\\": "\",
119 | # ']': ']',
120 | "^": "^",
121 | # '_': '_',
122 | "`": "`",
123 | "{": "{",
124 | "|": "|",
125 | "}": "}",
126 | "~": "~",
127 | }
128 | )
129 |
130 | def __call__(
131 | self,
132 | text: str,
133 | do_text_normalization=True,
134 | do_homophone_replacement=True,
135 | lang: Optional[Literal["zh", "en"]] = None,
136 | ) -> str:
137 | if do_text_normalization:
138 | _lang = self._detect_language(text) if lang is None else lang
139 | if _lang in self.normalizers:
140 | text = self.normalizers[_lang](text)
141 | if _lang == "zh":
142 | text = self._apply_half2full_map(text)
143 | invalid_characters = self._count_invalid_characters(text)
144 | if len(invalid_characters):
145 | self.logger.debug(f"found invalid characters: {invalid_characters}")
146 | text = self._apply_character_map(text)
147 | if do_homophone_replacement:
148 | arr, replaced_words = _fast_replace(
149 | self.homophones_map,
150 | text.encode(self.coding),
151 | )
152 | if replaced_words:
153 | text = arr.tobytes().decode(self.coding)
154 | repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words])
155 | self.logger.debug(f"replace homophones: {repl_res}")
156 | if len(invalid_characters):
157 | text = self.reject_pattern.sub("", text)
158 | return text
159 |
160 | def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
161 | if name in self.normalizers:
162 | self.logger.warning(f"name {name} has been registered")
163 | return False
164 | try:
165 | val = normalizer("test string 测试字符串")
166 | if not isinstance(val, str):
167 | self.logger.warning("normalizer must have caller type (str) -> str")
168 | return False
169 | except Exception as e:
170 | self.logger.warning(e)
171 | return False
172 | self.normalizers[name] = normalizer
173 | return True
174 |
175 | def unregister(self, name: str):
176 | if name in self.normalizers:
177 | del self.normalizers[name]
178 |
179 | def destroy(self):
180 | del self.homophones_map
181 |
182 | def _load_homophones_map(self, map_file_path: str) -> np.ndarray:
183 | with open(map_file_path, "r", encoding="utf-8") as f:
184 | homophones_map: Dict[str, str] = json.load(f)
185 | map = np.empty((2, len(homophones_map)), dtype=np.uint32)
186 | for i, k in enumerate(homophones_map.keys()):
187 | map[:, i] = (ord(k), ord(homophones_map[k]))
188 | del homophones_map
189 | return map
190 |
191 | def _count_invalid_characters(self, s: str):
192 | s = self.sub_pattern.sub("", s)
193 | non_alphabetic_chinese_chars = self.reject_pattern.findall(s)
194 | return set(non_alphabetic_chinese_chars)
195 |
196 | def _apply_half2full_map(self, text: str) -> str:
197 | return text.translate(self.halfwidth_2_fullwidth)
198 |
199 | def _apply_character_map(self, text: str) -> str:
200 | return text.translate(self.character_simplifier)
201 |
202 | def _detect_language(self, sentence: str) -> Literal["zh", "en"]:
203 | chinese_chars = self.chinese_char_pattern.findall(sentence)
204 | english_words = self.english_word_pattern.findall(sentence)
205 |
206 | if len(chinese_chars) > len(english_words):
207 | return "zh"
208 | else:
209 | return "en"
210 |
--------------------------------------------------------------------------------
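A hedged usage sketch for `Normalizer` above. The homophones map path is an assumption (ChatTTS ships a `homophones_map.json` with its assets); no language-specific normalizer is registered here, so only the half-width-to-full-width mapping, character simplification, and homophone replacement are applied:

```python
from chattts_plus.commons.norm import Normalizer

# The map location is an assumption; point it at the homophones_map.json from your ChatTTS assets.
normalizer = Normalizer("checkpoints/homophones_map.json")

text = normalizer(
    "ChatTTSPlus支持声音克隆![uv_break]",
    do_text_normalization=True,     # would call a registered normalizer for the detected language
    do_homophone_replacement=True,  # swap commonly mispronounced characters using the map
    lang=None,                      # auto-detect zh/en from the text
)
print(text)
```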
/chattts_plus/commons/onnx2trt.py:
--------------------------------------------------------------------------------
1 | #
2 | # SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3 | # SPDX-License-Identifier: Apache-2.0
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | #
17 |
18 | import os
19 | import pdb
20 | import sys
21 | import logging
22 | import argparse
23 |
24 | import tensorrt as trt
25 | from polygraphy.backend.trt import (
26 | engine_from_bytes,
27 | engine_from_network,
28 | network_from_onnx_path,
29 | save_engine,
30 | )
31 |
32 | logging.basicConfig(level=logging.INFO)
33 | logging.getLogger("EngineBuilder").setLevel(logging.INFO)
34 | log = logging.getLogger("EngineBuilder")
35 |
36 |
37 | class EngineBuilder:
38 | """
39 | Parses an ONNX graph and builds a TensorRT engine from it.
40 | """
41 |
42 | def __init__(self, verbose=False):
43 | """
44 | :param verbose: If enabled, a higher verbosity level will be set on the TensorRT logger.
45 | """
46 | self.trt_logger = trt.Logger(trt.Logger.INFO)
47 | if verbose:
48 | self.trt_logger.min_severity = trt.Logger.Severity.VERBOSE
49 |
50 | trt.init_libnvinfer_plugins(self.trt_logger, namespace="")
51 |
52 | self.builder = trt.Builder(self.trt_logger)
53 | self.config = self.builder.create_builder_config()
54 | # self.config.max_workspace_size = 24 * (2 ** 30) # 12 GB
55 |
56 | profile = self.builder.create_optimization_profile()
57 |
58 | # GPT
59 | profile.set_shape("inputs_embeds", min=[1, 1, 768], opt=[2, 1, 768], max=[4, 2048, 768])
60 | profile.set_shape("attention_mask", min=[1, 1], opt=[2, 256], max=[4, 2048])
61 | profile.set_shape("position_ids", min=[1, 1], opt=[2, 1], max=[4, 2048])
62 | profile.set_shape("past_key_values", min=[20, 2, 1, 12, 1, 64], opt=[20, 2, 2, 12, 256, 64],
63 | max=[20, 2, 4, 12, 2048, 64])
64 |
65 | self.config.add_optimization_profile(profile)
66 |
67 | self.batch_size = None
68 | self.network = None
69 | self.parser = None
70 |
71 | def create_network(self, onnx_path):
72 | """
73 | Parse the ONNX graph and create the corresponding TensorRT network definition.
74 | :param onnx_path: The path to the ONNX graph to load.
75 | """
76 |
77 | onnx_path = os.path.realpath(onnx_path)
78 | self.network_infos = self.builder, self.network, _ = network_from_onnx_path(
79 | onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]
80 | )
81 |
82 | inputs = [self.network.get_input(i) for i in range(self.network.num_inputs)]
83 | outputs = [self.network.get_output(i) for i in range(self.network.num_outputs)]
84 | log.info("Network Description")
85 | for input in inputs:
86 | self.batch_size = input.shape[0]
87 | log.info("Input '{}' with shape {} and dtype {}".format(input.name, input.shape, input.dtype))
88 | for output in outputs:
89 | log.info("Output '{}' with shape {} and dtype {}".format(output.name, output.shape, output.dtype))
90 |
91 | def create_engine(
92 | self,
93 | engine_path,
94 | precision
95 | ):
96 | """
97 | Build the TensorRT engine and serialize it to disk.
98 | :param engine_path: The path where to serialize the engine to.
99 | :param precision: The datatype to use for the engine, either 'fp32', 'fp16' or 'int8'.
100 | """
101 | engine_path = os.path.realpath(engine_path)
102 | engine_dir = os.path.dirname(engine_path)
103 | os.makedirs(engine_dir, exist_ok=True)
104 | log.info("Building {} Engine in {}".format(precision, engine_path))
105 |
106 | if precision == "fp16":
107 | if not self.builder.platform_has_fast_fp16:
108 | log.warning("FP16 is not supported natively on this platform/device")
109 | else:
110 | self.config.set_flag(trt.BuilderFlag.FP16)
111 | elif precision == "int8":
112 | if not self.builder.platform_has_fast_int8:
113 | log.warning("Int8 is not supported natively on this platform/device")
114 | else:
115 | self.config.set_flag(trt.BuilderFlag.INT8)
116 |
117 | try:
118 | engine = engine_from_network(
119 | self.network_infos,
120 | self.config
121 | )
122 | except Exception as e:
123 | print(f"Failed to build engine: {e}")
124 | return 1
125 | try:
126 | save_engine(engine, path=engine_path)
127 | except Exception as e:
128 | print(f"Failed to save engine: {e}")
129 | return 1
130 | return 0
131 |
132 |
133 | def convert_onnx_to_trt(onnx_path, trt_path, verbose=False, precision="fp16"):
134 | builder = EngineBuilder(verbose)
135 | builder.create_network(onnx_path)
136 | builder.create_engine(
137 | trt_path,
138 | precision
139 | )
140 |
141 |
142 | if __name__ == "__main__":
143 | parser = argparse.ArgumentParser()
144 | parser.add_argument("-o", "--onnx", required=True, help="The input ONNX model file to load")
145 | parser.add_argument("-e", "--engine", help="The output path for the TRT engine")
146 | parser.add_argument(
147 | "-p",
148 | "--precision",
149 | default="fp16",
150 | choices=["fp32", "fp16", "int8"],
151 | help="The precision mode to build in, either 'fp32', 'fp16' or 'int8', default: 'fp16'",
152 | )
153 | parser.add_argument("-v", "--verbose", action="store_true", help="Enable more verbose log output")
154 | args = parser.parse_args()
155 | if args.engine is None:
156 | args.engine = args.onnx.replace(".onnx", ".trt")
157 | convert_onnx_to_trt(args.onnx, args.engine, args.verbose, args.precision)
158 |
--------------------------------------------------------------------------------
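`onnx2trt.py` above can be run as a CLI or called from Python. A sketch of the programmatic path (the file names are assumptions; TensorRT 10 and polygraphy must be installed, and the optimization profile is hard-wired for the GPT graph exported by this project):

```python
# Build a TensorRT engine from an exported ONNX graph. Equivalent CLI:
#   python chattts_plus/commons/onnx2trt.py -o checkpoints/gpt.onnx -e checkpoints/gpt.trt -p fp16
from chattts_plus.commons.onnx2trt import convert_onnx_to_trt

convert_onnx_to_trt(
    onnx_path="checkpoints/gpt.onnx",  # example path to the exported ONNX model
    trt_path="checkpoints/gpt.trt",    # example output path for the serialized engine
    verbose=False,
    precision="fp16",                  # one of "fp32", "fp16", "int8"
)
```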
/chattts_plus/commons/text_utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | # ref: https://github.com/PaddlePaddle/PaddleSpeech/tree/develop/paddlespeech/t2s/frontend/zh_normalization
3 | from zh_normalization import TextNormalizer
4 | from functools import partial
5 |
6 |
7 | # 数字转为英文读法
8 | def num_to_english(num):
9 | num_str = str(num)
10 | # English representations for numbers 0-9
11 | english_digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
12 | units = ["", "ten", "hundred", "thousand"]
13 | big_units = ["", "thousand", "million", "billion", "trillion"]
14 | result = ""
15 | need_and = False # Indicates whether 'and' needs to be added
16 | part = [] # Stores each group of 4 digits
17 | is_first_part = True # Indicates if it is the first part for not adding 'and' at the beginning
18 |
19 | # Split the number into 3-digit groups
20 | while num_str:
21 | part.append(num_str[-3:])
22 | num_str = num_str[:-3]
23 |
24 | part.reverse()
25 |
26 | for i, p in enumerate(part):
27 | p_str = ""
28 | digit_len = len(p)
29 | if int(p) == 0 and i < len(part) - 1:
30 | continue
31 |
32 | hundreds_digit = int(p) // 100 if digit_len == 3 else None
33 | tens_digit = int(p) % 100 if digit_len >= 2 else int(p[0] if digit_len == 1 else p[1])
34 |
35 | # Process hundreds
36 | if hundreds_digit is not None and hundreds_digit != 0:
37 | p_str += english_digits[hundreds_digit] + " hundred"
38 | if tens_digit != 0:
39 | p_str += " and "
40 |
41 | # Process tens and ones
42 | if 10 <= tens_digit < 20:  # Ten and the teens are irregular
43 | teen_map = {
44 | 10: "ten", 11: "eleven", 12: "twelve", 13: "thirteen", 14: "fourteen", 15: "fifteen",
45 | 16: "sixteen", 17: "seventeen", 18: "eighteen", 19: "nineteen"
46 | }
47 | p_str += teen_map[tens_digit]
48 | else:
49 | tens_map = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
50 | tens_val = tens_digit // 10
51 | ones_val = tens_digit % 10
52 | if tens_val >= 2:
53 | p_str += tens_map[tens_val] + (" " + english_digits[ones_val] if ones_val != 0 else "")
54 | elif tens_digit != 0 and tens_val < 2: # When tens_digit is in [1, 9]
55 | p_str += english_digits[tens_digit]
56 |
57 | if p_str and not is_first_part and need_and:
58 | result += " and "
59 | result += p_str
60 | if i < len(part) - 1 and int(p) != 0:
61 | result += " " + big_units[len(part) - i - 1] + ", "
62 |
63 | is_first_part = False
64 | if int(p) != 0:
65 | need_and = True
66 |
67 | return result.capitalize()
68 |
69 |
70 | def get_lang(text):
71 | # 定义中文标点符号的模式
72 | chinese_punctuation = "[。?!,、;:‘’“”()《》【】…—\u3000]"
73 | # 使用正则表达式替换所有中文标点为""
74 | cleaned_text = re.sub(chinese_punctuation, "", text)
75 | # 使用正则表达式来匹配中文字符范围
76 | return "zh" if re.search('[\u4e00-\u9fff]', cleaned_text) is not None else "en"
77 |
78 |
79 | def fraction_to_words(match):
80 | numerator, denominator = match.groups()
81 | # Spell the fraction as "<numerator> over <denominator>"; the digits themselves are
82 | # converted to English words later by num2text
83 | return numerator + " over " + denominator
84 |
85 |
86 | # 数字转为英文读法
87 | def num2text(text):
88 | numtext = [' zero ', ' one ', ' two ', ' three ', ' four ', ' five ', ' six ', ' seven ', ' eight ', ' nine ']
89 | point = ' point '
90 | text = re.sub(r'(\d)\,(\d)', r'\1\2', text)
91 | text = re.sub(r'(\d+)\s*\+', r'\1 plus ', text)
92 | text = re.sub(r'(\d+)\s*\-', r'\1 minus ', text)
93 | text = re.sub(r'(\d+)\s*[\*x]', r'\1 times ', text)
94 | text = re.sub(r'((?:\d+\.)?\d+)\s*/\s*(\d+)', fraction_to_words, text)
95 |
96 | # 取出数字 number_list= [('1000200030004000.123', '1000200030004000', '123'), ('23425', '23425', '')]
97 | number_list = re.findall(r'((\d+)(?:\.(\d+))?%?)', text)
98 | if len(number_list) > 0:
99 | # dc= ('1000200030004000.123', '1000200030004000', '123','')
100 | for m, dc in enumerate(number_list):
101 | if len(dc[1]) > 16:
102 | continue
103 | int_text = num_to_english(dc[1])
104 | if len(dc) > 2 and dc[2]:
105 | int_text += point + "".join([numtext[int(i)] for i in dc[2]])
106 | if dc[0][-1] == '%':
107 | int_text = f' the pronunciation of {int_text}'
108 | text = text.replace(dc[0], int_text)
109 |
110 | return text.replace('1', ' one ').replace('2', ' two ').replace('3', ' three ').replace('4', ' four ').replace('5',
111 | ' five ').replace(
112 | '6', ' six ').replace('7', ' seven ').replace('8', ' eight ').replace('9', ' nine ').replace('0',
113 | ' zero ').replace(
114 | '=', ' equals ')
115 |
116 |
117 | def remove_brackets(text):
118 | # 正则表达式
119 | text = re.sub(r'\[(uv_break|laugh|lbreak|break)\]', r' \1 ', text, re.I | re.S | re.M)
120 |
121 | # 使用 re.sub 替换掉 [ ] 对
122 | newt = re.sub(r'\[|\]|!|:|{|}', '', text)
123 | return re.sub(r'\s(uv_break|laugh|lbreak|break)(?=\s|$)', r' [\1] ', newt)
124 |
125 |
126 | # 中英文数字转换为文字,特殊符号处理
127 | def split_text(text_list):
128 | tx = TextNormalizer()
129 | haserror = False
130 | result = []
131 | for i, text in enumerate(text_list):
132 | text = remove_brackets(text)
133 | if get_lang(text) == 'zh':
134 | tmp = "".join(tx.normalize(text))
135 | elif haserror:
136 | tmp = num2text(text)
137 | else:
138 | try:
139 | # 先尝试使用 nemo_text_processing 处理英文
140 | from nemo_text_processing.text_normalization.normalize import Normalizer
141 | fun = partial(Normalizer(input_case='cased', lang="en").normalize, verbose=False,
142 | punct_post_process=True)
143 | tmp = fun(text)
144 | print(f'使用nemo处理英文ok')
145 | except Exception as e:
146 | print(f"nemo处理英文失败,改用自定义预处理")
147 | print(e)
148 | haserror = True
149 | tmp = num2text(text)
150 |
151 | if len(tmp) > 200:
152 | tmp_res = split_text_by_punctuation(tmp)
153 | result = result + tmp_res
154 | else:
155 | result.append(tmp)
156 | return result
157 |
158 |
159 | # Split a long line (over ~200 chars) at punctuation into segments of at least min_length (150) characters
160 | def split_text_by_punctuation(text):
161 | # 定义长度限制
162 | min_length = 150
163 | punctuation_marks = "。?!,、;:”’》」』)】…—"
164 | english_punctuation = ".?!,:;)}…"
165 |
166 | # 结果列表
167 | result = []
168 | # 起始位置
169 | pos = 0
170 |
171 | # 遍历文本中的每个字符
172 | text_length = len(text)
173 | for i, char in enumerate(text):
174 | if char in punctuation_marks or char in english_punctuation:
175 | if char == '.' and i < text_length - 1 and re.match(r'\d', text[i + 1]):
176 | continue
177 | # When punctuation is reached, check whether the current segment exceeds min_length
178 | if i - pos > min_length:
179 | # If it does, append the current segment to the result list
180 | result.append(text[pos:i + 1])
181 | # 更新起始位置到当前标点的下一个字符
182 | pos = i + 1
183 | # print(f'{pos=},{len(text)=}')
184 |
185 | # Append whatever text remains after the last split point as the final segment
186 | if pos < len(text):
187 | result.append(text[pos:])
188 |
189 | return result
190 |
--------------------------------------------------------------------------------
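A quick usage sketch of the helpers in `text_utils.py` above. Note that the module imports `zh_normalization` at load time, so that package must be installed even for these lighter helpers, and `split_text` can additionally use `nemo_text_processing` for English. The outputs noted in the comments are approximate:

```python
from chattts_plus.commons.text_utils import get_lang, num2text, num_to_english, remove_brackets

print(get_lang("今天天气不错。"))            # -> "zh"
print(get_lang("It costs 42 dollars."))      # -> "en"

print(num_to_english(1903))                  # spells the integer out in English words
print(num2text("It costs 42 dollars."))      # digits inside the sentence become English words

print(remove_brackets("你好[uv_break]世界"))  # keeps control tokens like [uv_break], strips stray brackets
```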
/chattts_plus/commons/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/10/8 10:36
3 | # @Author : wenshao
4 | # @ProjectName: ChatTTSPlus
5 | # @FileName: utils.py
6 |
7 | import torch
8 | from dataclasses import dataclass, asdict
9 | from typing import Literal, Optional, List, Tuple, Dict, Union
10 |
11 |
12 | @dataclass(repr=False, eq=False)
13 | class RefineTextParams:
14 | prompt: str = ""
15 | top_P: float = 0.7
16 | top_K: int = 20
17 | temperature: float = 0.7
18 | repetition_penalty: float = 1.0
19 | max_new_token: int = 384
20 | min_new_token: int = 0
21 | show_tqdm: bool = True
22 | ensure_non_empty: bool = True
23 |
24 |
25 | @dataclass(repr=False, eq=False)
26 | class InferCodeParams(RefineTextParams):
27 | prompt: str = "[speed_5]"
28 | spk_emb: Optional[str] = None
29 | spk_smp: Optional[str] = None
30 | txt_smp: Optional[str] = None
31 | temperature: float = 0.3
32 | repetition_penalty: float = 1.05
33 | max_new_token: int = 2048
34 | stream_batch: int = 24
35 | stream_speed: int = 12000
36 | pass_first_n_batches: int = 2
37 |
38 |
39 | def get_inference_device():
40 | if torch.cuda.is_available():
41 | return torch.device("cuda")
42 | elif torch.backends.mps.is_available():
43 | return torch.device("mps")
44 | else:
45 | return torch.device("cpu")
46 |
47 |
48 | class TorchSeedContext:
49 | def __init__(self, seed):
50 | self.seed = seed
51 | self.state = None
52 |
53 | def __enter__(self):
54 | self.state = torch.random.get_rng_state()
55 | torch.manual_seed(self.seed)
56 |
57 | def __exit__(self, type, value, traceback):
58 | torch.random.set_rng_state(self.state)
59 |
--------------------------------------------------------------------------------
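A small usage sketch for the helpers above. The prompt strings follow ChatTTS's control-token conventions and are only examples:

```python
import torch

from chattts_plus.commons.utils import (
    InferCodeParams,
    RefineTextParams,
    TorchSeedContext,
    get_inference_device,
)

device = get_inference_device()  # prefers cuda, then mps, then cpu
refine_params = RefineTextParams(prompt="[oral_2][laugh_0][break_6]")              # example refine-text prompt
infer_params = InferCodeParams(prompt="[speed_5]", temperature=0.3, spk_emb=None)  # example code-infer params

# TorchSeedContext fixes the RNG for the duration of the block and restores it afterwards,
# which makes sampling (e.g. random speaker embeddings) reproducible.
with TorchSeedContext(seed=2222):
    sample_a = torch.randn(3)
with TorchSeedContext(seed=2222):
    sample_b = torch.randn(3)
print(torch.equal(sample_a, sample_b))  # True
```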
/chattts_plus/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/10/7 12:14
3 | # @Author : wenshao
4 | # @ProjectName: ChatTTSPlus
5 | # @FileName: __init__.py.py
6 |
--------------------------------------------------------------------------------
/chattts_plus/datasets/base_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/11/23
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: base_dataset.py
7 |
8 | import os.path
9 | import pdb
10 |
11 | import torch
12 | import torchaudio
13 | from torch.utils.data import Dataset
14 |
15 |
16 | class BaseDataset(Dataset):
17 | """
18 | A generic training dataset loader.
19 | """
20 |
21 | def __init__(self, meta_infos=[],
22 | tokenizer=None,
23 | normalizer=None,
24 | sample_rate=24_000,
25 | num_vq=4,
26 | use_empty_speaker=False,
27 | **kwargs
28 | ):
29 | super(BaseDataset, self).__init__()
30 | self.meta_infos = meta_infos
31 | self.tokenizer = tokenizer
32 | self.normalizer = normalizer
33 | self.sample_rate = sample_rate
34 | self.num_vq = num_vq
35 | self.use_empty_speaker = use_empty_speaker
36 | self.data_infos, self.speakers = self.load_data()
37 |
38 | def load_data(self, **kwargs):
39 | data_infos = []
40 | speakers = set()
41 | for info_path in self.meta_infos:
42 | try:
43 | with open(info_path, "r", encoding='UTF-8') as fin:
44 | for line_num, line in enumerate(fin.readlines(), 1):
45 | line = line.strip().replace('\n', '')
46 | if not line: # 跳过空行
47 | continue
48 |
49 | line_splits = line.split("|")
50 | if len(line_splits) < 4: # 确保有足够的字段
51 | print(f"警告: 第 {line_num} 行格式不正确: {line}")
52 | continue
53 |
54 | try:
55 | data_info = {
56 | "speaker": line_splits[0],
57 | "audio_path": line_splits[1],
58 | "text": line_splits[-1],
59 | "lang": line_splits[-2].lower()
60 | }
61 | # 验证音频文件是否存在
62 | if not os.path.exists(data_info["audio_path"]):
63 | print(f"警告: 音频文件不存在: {data_info['audio_path']}")
64 | continue
65 |
66 | speakers.add(data_info["speaker"])
67 | data_infos.append(data_info)
68 | except Exception as e:
69 | print(f"错误: 处理第 {line_num} 行时出错: {str(e)}")
70 | print(f"行内容: {line}")
71 | continue
72 |
73 | except Exception as e:
74 | print(f"错误: 读取文件 {info_path} 时出错: {str(e)}")
75 | continue
76 |
77 | if not data_infos:
78 | raise ValueError(f"没有找到有效的训练数据。请检查数据文件: {self.meta_infos}")
79 |
80 | return data_infos, speakers
81 |
82 | def __len__(self):
83 | return len(self.data_infos)
84 |
85 | def __getitem__(self, i):
86 | data_info_ = self.data_infos[i]
87 | audio_wavs, audio_mask = self.preprocess_audio(data_info_["audio_path"])
88 | text_input_ids, text_mask = self.preprocess_text(data_info_["text"], data_info_["lang"])
89 | return {
90 | "speaker": data_info_["speaker"],
91 | "text": data_info_["text"],
92 | "audio_wavs": audio_wavs,
93 | "audio_mask": audio_mask,
94 | "text_input_ids": text_input_ids,
95 | "text_mask": text_mask
96 | }
97 |
98 | def preprocess_text(
99 | self,
100 | text,
101 | lang="zh",
102 | do_text_normalization=True,
103 | do_homophone_replacement=True,
104 | ):
105 |
106 | text = self.normalizer(
107 | text,
108 | do_text_normalization,
109 | do_homophone_replacement,
110 | lang,
111 | )
112 | if self.use_empty_speaker:
113 | text = f'[Stts][empty_spk]{text}[Ptts]'
114 | else:
115 | text = f'[Stts][spk_emb]{text}[Ptts]'
116 | input_ids, attention_mask, text_mask = self.tokenizer.encode([text], num_vq=self.num_vq)
117 | return input_ids.squeeze(0), text_mask.squeeze(0)
118 |
119 | def preprocess_audio(self, audio_path):
120 | # 如果是相对路径,转换为绝对路径
121 | if not os.path.isabs(audio_path):
122 | audio_path = os.path.join(os.path.dirname(self.meta_infos[0]), audio_path)
123 |
124 | # 确保文件存在
125 | if not os.path.exists(audio_path):
126 | raise FileNotFoundError(f"Audio file not found: {audio_path}")
127 |
128 | audio_wavs, sample_rate = torchaudio.load(audio_path)
129 | if sample_rate != self.sample_rate:
130 | audio_wavs = torchaudio.functional.resample(
131 | audio_wavs,
132 | orig_freq=sample_rate,
133 | new_freq=self.sample_rate,
134 | )
135 | audio_wavs = audio_wavs.mean(0)
136 | audio_mask = torch.ones(len(audio_wavs))
137 | return audio_wavs, audio_mask
138 |
--------------------------------------------------------------------------------
/chattts_plus/datasets/collator.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/11/23
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: collator.py
7 | import pdb
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 |
13 | class BaseCollator:
14 | def __init__(self, text_pad: int = 0, audio_pad: int = 0):
15 | self.text_pad = text_pad
16 | self.audio_pad = audio_pad
17 |
18 | def __call__(self, batch):
19 | batch = [x for x in batch if x is not None]
20 |
21 | audio_maxlen = max(len(item["audio_wavs"]) for item in batch)
22 | text_maxlen = max(len(item["text_input_ids"]) for item in batch)
23 |
24 | speaker = []
25 | text = []
26 | text_input_ids = []
27 | text_mask = []
28 | audio_wavs = []
29 | audio_mask = []
30 | for x in batch:
31 | text.append(x["text"])
32 | speaker.append(x["speaker"])
33 | text_input_ids.append(
34 | F.pad(
35 | x["text_input_ids"],
36 | (0, 0, text_maxlen - len(x["text_input_ids"]), 0),
37 | value=self.text_pad,
38 | )
39 | )
40 | text_mask.append(
41 | F.pad(
42 | x["text_mask"],
43 | (text_maxlen - len(x["text_mask"]), 0),
44 | value=0,
45 | )
46 | )
47 | audio_wavs.append(
48 | F.pad(
49 | x["audio_wavs"],
50 | (0, audio_maxlen - len(x["audio_wavs"])),
51 | value=self.audio_pad,
52 | )
53 | )
54 | audio_mask.append(
55 | F.pad(
56 | x["audio_mask"],
57 | (0, audio_maxlen - len(x["audio_mask"])),
58 | value=0,
59 | )
60 | )
61 | return {
62 | "speaker": speaker,
63 | "text": text,
64 | "text_input_ids": torch.stack(text_input_ids),
65 | "text_mask": torch.stack(text_mask),
66 | "audio_wavs": torch.stack(audio_wavs),
67 | "audio_mask": torch.stack(audio_mask),
68 | }
69 |
--------------------------------------------------------------------------------
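To illustrate how `BaseCollator` pads a batch, the sketch below feeds it two synthetic items shaped like the output of `BaseDataset.__getitem__` (the shapes, with `num_vq=4` and 24 kHz audio, are assumptions drawn from the dataset code):

```python
import torch

from chattts_plus.datasets.collator import BaseCollator


def fake_item(n_text: int, n_audio: int) -> dict:
    # Mimics BaseDataset.__getitem__: (T_text, num_vq) token ids plus masks and a mono waveform.
    return {
        "speaker": "xionger",
        "text": "placeholder",
        "text_input_ids": torch.zeros(n_text, 4, dtype=torch.long),
        "text_mask": torch.ones(n_text, dtype=torch.bool),
        "audio_wavs": torch.zeros(n_audio),
        "audio_mask": torch.ones(n_audio),
    }


collator = BaseCollator(text_pad=0, audio_pad=0)
batch = collator([fake_item(5, 24_000), fake_item(8, 36_000)])

print(batch["text_input_ids"].shape)  # torch.Size([2, 8, 4]); texts are left-padded to the longest
print(batch["audio_wavs"].shape)      # torch.Size([2, 36000]); audio is right-padded to the longest
```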
/chattts_plus/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/27 22:53
3 | # @Project : ChatTTSPlus
4 | # @FileName: __init__.py.py
5 |
6 | from .tokenizer import Tokenizer
7 | from .gpt import GPT
8 | from .dvae import DVAE
9 |
--------------------------------------------------------------------------------
/chattts_plus/models/dvae.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pdb
3 | from typing import List, Optional, Literal, Tuple
4 |
5 | import numpy as np
6 | import pybase16384 as b14
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchaudio
11 | from vector_quantize_pytorch import GroupedResidualFSQ
12 |
13 | from ..commons import logger
14 |
15 |
16 | class ConvNeXtBlock(nn.Module):
17 | def __init__(
18 | self,
19 | dim: int,
20 | intermediate_dim: int,
21 | kernel: int,
22 | dilation: int,
23 | layer_scale_init_value: float = 1e-6,
24 | ):
25 | # ConvNeXt Block copied from Vocos.
26 | super().__init__()
27 | self.dwconv = nn.Conv1d(
28 | dim,
29 | dim,
30 | kernel_size=kernel,
31 | padding=dilation * (kernel // 2),
32 | dilation=dilation,
33 | groups=dim,
34 | ) # depthwise conv
35 |
36 | self.norm = nn.LayerNorm(dim, eps=1e-6)
37 | self.pwconv1 = nn.Linear(
38 | dim, intermediate_dim
39 | ) # pointwise/1x1 convs, implemented with linear layers
40 | self.act = nn.GELU()
41 | self.pwconv2 = nn.Linear(intermediate_dim, dim)
42 | self.gamma = (
43 | nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
44 | if layer_scale_init_value > 0
45 | else None
46 | )
47 |
48 | def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
49 | residual = x
50 |
51 | y = self.dwconv(x)
52 | y.transpose_(1, 2) # (B, C, T) -> (B, T, C)
53 | x = self.norm(y)
54 | y = self.pwconv1(x)
55 | x = self.act(y)
56 | y = self.pwconv2(x)
57 | if self.gamma is not None:
58 | y *= self.gamma
59 | y.transpose_(1, 2) # (B, T, C) -> (B, C, T)
60 |
61 | x = y + residual
62 |
63 | return x
64 |
65 |
66 | class GFSQ(nn.Module):
67 |
68 | def __init__(
69 | self, dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose=True
70 | ):
71 | super(GFSQ, self).__init__()
72 | self.quantizer = GroupedResidualFSQ(
73 | dim=dim,
74 | levels=list(levels),
75 | num_quantizers=R,
76 | groups=G,
77 | )
78 | self.n_ind = math.prod(levels)
79 | self.eps = eps
80 | self.transpose = transpose
81 | self.G = G
82 | self.R = R
83 |
84 | def _embed(self, x: torch.Tensor):
85 | if self.transpose:
86 | x = x.transpose(1, 2)
87 | """
88 | x = rearrange(
89 | x, "b t (g r) -> g b t r", g = self.G, r = self.R,
90 | )
91 | """
92 | x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3)
93 | feat = self.quantizer.get_output_from_indices(x)
94 | return feat.transpose_(1, 2) if self.transpose else feat
95 |
96 | def __call__(self, x: torch.Tensor) -> torch.Tensor:
97 | return super().__call__(x)
98 |
99 | def forward(self, x: torch.Tensor) -> torch.Tensor:
100 | if self.transpose:
101 | x.transpose_(1, 2)
102 | # feat, ind = self.quantizer(x)
103 | with torch.autocast(device_type=x.device.type, dtype=torch.float32):
104 | _, ind = self.quantizer(x)
105 | """
106 | ind = rearrange(
107 | ind, "g b t r ->b t (g r)",
108 | )
109 | """
110 | ind = ind.permute(1, 2, 0, 3).contiguous()
111 | ind = ind.view(ind.size(0), ind.size(1), -1)
112 | """
113 | embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
114 | embed_onehot = embed_onehot_tmp.to(x.dtype)
115 | del embed_onehot_tmp
116 | e_mean = torch.mean(embed_onehot, dim=[0, 1])
117 | # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
118 | torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean)
119 | perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
120 |
121 | return
122 | torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device),
123 | feat.transpose_(1, 2) if self.transpose else feat,
124 | perplexity,
125 | """
126 | return ind.transpose_(1, 2) if self.transpose else ind
127 |
128 |
129 | class DVAEDecoder(nn.Module):
130 | def __init__(
131 | self,
132 | idim: int,
133 | odim: int,
134 | n_layer=12,
135 | bn_dim=64,
136 | hidden=256,
137 | kernel=7,
138 | dilation=2,
139 | up=False,
140 | ):
141 | super().__init__()
142 | self.up = up
143 | self.conv_in = nn.Sequential(
144 | nn.Conv1d(idim, bn_dim, 3, 1, 1),
145 | nn.GELU(),
146 | nn.Conv1d(bn_dim, hidden, 3, 1, 1),
147 | )
148 | self.decoder_block = nn.ModuleList(
149 | [
150 | ConvNeXtBlock(
151 | hidden,
152 | hidden * 4,
153 | kernel,
154 | dilation,
155 | )
156 | for _ in range(n_layer)
157 | ]
158 | )
159 | self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)
160 |
161 | def forward(self, x: torch.Tensor, conditioning=None) -> torch.Tensor:
162 | # B, C, T
163 | y = self.conv_in(x)
164 | for f in self.decoder_block:
165 | y = f(y, conditioning)
166 |
167 | x = self.conv_out(y)
168 | return x
169 |
170 |
171 | class MelSpectrogramFeatures(torch.nn.Module):
172 | def __init__(
173 | self,
174 | sample_rate=24000,
175 | n_fft=1024,
176 | hop_length=256,
177 | n_mels=100,
178 | padding: Literal["center", "same"] = "center",
179 | ):
180 | super().__init__()
181 | if padding not in ["center", "same"]:
182 | raise ValueError("Padding must be 'center' or 'same'.")
183 | self.padding = padding
184 | self.mel_spec = torchaudio.transforms.MelSpectrogram(
185 | sample_rate=sample_rate,
186 | n_fft=n_fft,
187 | hop_length=hop_length,
188 | n_mels=n_mels,
189 | center=padding == "center",
190 | power=1,
191 | )
192 |
193 | def __call__(self, audio: torch.Tensor) -> torch.Tensor:
194 | return super().__call__(audio)
195 |
196 | def forward(self, audio: torch.Tensor) -> torch.Tensor:
197 | mel: torch.Tensor = self.mel_spec(audio)
198 | features = torch.log(torch.clip(mel, min=1e-5))
199 | return features
200 |
201 |
202 | class DVAE(nn.Module):
203 | def __init__(
204 | self,
205 | decoder_config: dict,
206 | encoder_config: Optional[dict] = None,
207 | vq_config: Optional[dict] = None,
208 | dim=512,
209 | coef: Optional[str] = None,
210 | **kwargs
211 | ):
212 | super().__init__()
213 | self.logger = logger.get_logger(self.__class__.__name__)
214 |
215 | if coef is None:
216 | coef = torch.rand(100)
217 | else:
218 | coef = torch.from_numpy(
219 | np.copy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32))
220 | )
221 |
222 | self.register_buffer("coef", coef.unsqueeze(0).unsqueeze_(2))
223 |
224 | if encoder_config is not None:
225 | self.downsample_conv = nn.Sequential(
226 | nn.Conv1d(100, dim, 3, 1, 1),
227 | nn.GELU(),
228 | nn.Conv1d(dim, dim, 4, 2, 1),
229 | nn.GELU(),
230 | )
231 | self.preprocessor_mel = MelSpectrogramFeatures()
232 | self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config)
233 |
234 | self.decoder = DVAEDecoder(**decoder_config)
235 | self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
236 | if vq_config is not None:
237 | self.vq_layer = GFSQ(**vq_config)
238 | else:
239 | self.vq_layer = None
240 |
241 | self.model_path = kwargs.get("model_path", None)
242 | if self.model_path:
243 | self.logger.info(f"loading DVAE pretrained model: {self.model_path}")
244 | self.from_pretrained(self.model_path)
245 |
246 | def from_pretrained(self, file_path: str):
247 | self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True))
248 |
249 | def __repr__(self) -> str:
250 | return b14.encode_to_string(
251 | self.coef.cpu().numpy().astype(np.float32).tobytes()
252 | )
253 |
254 | def __call__(
255 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
256 | ) -> torch.Tensor:
257 | return super().__call__(inp, mode)
258 |
259 | @torch.inference_mode()
260 | def forward(
261 | self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"
262 | ) -> torch.Tensor:
263 | if mode == "encode" and hasattr(self, "encoder") and self.vq_layer is not None:
264 | mel = self.preprocessor_mel(inp)
265 | x: torch.Tensor = self.downsample_conv(
266 | torch.div(mel, self.coef.view(1, 100, 1).expand(mel.shape), out=mel),
267 | )
268 | x = self.encoder(x)
269 | ind = self.vq_layer(x)
270 | return ind
271 |
272 | if self.vq_layer is not None:
273 | vq_feats = self.vq_layer._embed(inp)
274 | else:
275 | vq_feats = inp
276 |
277 | vq_feats = (
278 | vq_feats.view(
279 | (vq_feats.size(0), 2, vq_feats.size(1) // 2, vq_feats.size(2)),
280 | )
281 | .permute(0, 2, 3, 1)
282 | .flatten(2)
283 | )
284 |
285 | dec_out = self.out_conv(
286 | self.decoder(
287 | x=vq_feats,
288 | ),
289 | )
290 |
291 | return torch.mul(dec_out, self.coef, out=dec_out)
292 |
--------------------------------------------------------------------------------
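Constructing a full `DVAE` needs the decoder/encoder/VQ configs that ship with the checkpoints, so a complete encode/decode round trip is not shown here. The self-contained sketch below only exercises `MelSpectrogramFeatures`, the mel front-end that the `encode` path feeds into `downsample_conv` and the GFSQ quantizer:

```python
import torch

from chattts_plus.models.dvae import MelSpectrogramFeatures

mel_extractor = MelSpectrogramFeatures(sample_rate=24000, n_fft=1024, hop_length=256, n_mels=100)

audio = torch.randn(1, 24000)  # one second of (random) 24 kHz audio, shape (batch, samples)
mel = mel_extractor(audio)     # log-mel features, clipped at 1e-5 before the log
print(mel.shape)               # roughly (1, 100, 94): (batch, n_mels, frames)
```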
/chattts_plus/models/gpt.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import platform
3 | from dataclasses import dataclass
4 | import logging
5 | from typing import Union, List, Optional, Tuple
6 | import gc
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.nn.utils.parametrize as P
12 | from torch.nn.utils.parametrizations import weight_norm
13 | from tqdm import tqdm
14 | from transformers import LlamaConfig, LogitsWarper
15 | from transformers.cache_utils import Cache
16 | from transformers.modeling_outputs import BaseModelOutputWithPast
17 | from transformers.utils import is_flash_attn_2_available
18 |
19 | from .llama import LlamaModel
20 | from .processors import CustomRepetitionPenaltyLogitsProcessorRepeat
21 | from ..commons import logger
22 |
23 |
24 | class GPT(nn.Module):
25 | def __init__(
26 | self,
27 | gpt_config: dict,
28 | num_audio_tokens: int = 626,
29 | num_text_tokens: int = 21178,
30 | num_vq=4,
31 | use_flash_attn=False,
32 | **kwargs
33 | ):
34 | super().__init__()
35 | self.logger = logger.get_logger(self.__class__.__name__)
36 | self.num_vq = num_vq
37 | self.num_audio_tokens = num_audio_tokens
38 |
39 | self.use_flash_attn = use_flash_attn
40 |
41 | self.gpt, self.llama_config = self._build_llama(gpt_config)
42 | self.is_te_llama = False
43 | self.model_dim = int(self.gpt.config.hidden_size)
44 | self.emb_code = nn.ModuleList(
45 | [
46 | nn.Embedding(
47 | num_audio_tokens,
48 | self.model_dim
49 | )
50 | for _ in range(num_vq)
51 | ],
52 | )
53 | self.emb_text = nn.Embedding(
54 | num_text_tokens, self.model_dim
55 | )
56 |
57 | self.head_text = weight_norm(
58 | nn.Linear(
59 | self.model_dim,
60 | num_text_tokens,
61 | bias=False,
62 | ),
63 | name="weight",
64 | )
65 | self.head_code = nn.ModuleList(
66 | [
67 | weight_norm(
68 | nn.Linear(
69 | self.model_dim,
70 | num_audio_tokens,
71 | bias=False,
72 | ),
73 | name="weight",
74 | )
75 | for _ in range(self.num_vq)
76 | ],
77 | )
78 |
79 | self.model_path = kwargs.get("model_path", None)
80 | if self.model_path:
81 | self.logger.info(f"loading GPT pretrained model: {self.model_path}")
82 | self.from_pretrained(self.model_path)
83 |
84 | def from_pretrained(self, file_path: str):
85 | self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True))
86 |
87 | class Context:
88 | def __init__(self):
89 | self._interrupt = False
90 |
91 | def set(self, v: bool):
92 | self._interrupt = v
93 |
94 | def get(self) -> bool:
95 | return self._interrupt
96 |
97 | def _build_llama(
98 | self,
99 | config: dict,
100 | ) -> Tuple[LlamaModel, LlamaConfig]:
101 |
102 | if self.use_flash_attn and is_flash_attn_2_available():
103 | llama_config = LlamaConfig(
104 | **config,
105 | attn_implementation="flash_attention_2",
106 | )
107 | self.logger.info(
108 | "enabling flash_attention_2 may make gpt be even slower"
109 | )
110 | else:
111 | llama_config = LlamaConfig(**config)
112 |
113 | model = LlamaModel(llama_config)
114 | del model.embed_tokens
115 | return model, llama_config
116 |
117 | def __call__(
118 | self, input_ids: torch.Tensor, text_mask: torch.Tensor
119 | ) -> torch.Tensor:
120 | """
121 | get_emb
122 | """
123 | return super().__call__(input_ids, text_mask)
124 |
125 | def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
126 | """
127 | get_emb
128 | """
129 |
130 | emb_text: torch.Tensor = self.emb_text(
131 | input_ids[text_mask].narrow(1, 0, 1).squeeze_(1)
132 | )
133 |
134 | text_mask_inv = text_mask.logical_not()
135 | masked_input_ids: torch.Tensor = input_ids[text_mask_inv]
136 |
137 | emb_code = [
138 | self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
139 | ]
140 | emb_code = torch.stack(emb_code, 2).sum(2)
141 |
142 | emb = torch.zeros(
143 | (input_ids.shape[:-1]) + (emb_text.shape[-1],),
144 | device=emb_text.device,
145 | dtype=emb_text.dtype,
146 | )
147 | emb[text_mask] = emb_text
148 | emb[text_mask_inv] = emb_code.to(emb.dtype)
149 | return emb
150 |
151 | @dataclass(repr=False, eq=False)
152 | class _GenerationInputs:
153 | position_ids: torch.Tensor
154 | cache_position: torch.Tensor
155 | use_cache: bool
156 | input_ids: Optional[torch.Tensor] = None
157 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
158 | attention_mask: Optional[torch.Tensor] = None
159 | inputs_embeds: Optional[torch.Tensor] = None
160 |
161 | def to(self, device: torch.device, dtype: torch.dtype):
162 | if self.attention_mask is not None:
163 | self.attention_mask = self.attention_mask.to(device, dtype=dtype)
164 | if self.position_ids is not None:
165 | self.position_ids = self.position_ids.to(device, dtype=dtype)
166 | if self.inputs_embeds is not None:
167 | self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
168 | if self.cache_position is not None:
169 | self.cache_position = self.cache_position.to(device, dtype=dtype)
170 |
171 | def _prepare_generation_inputs(
172 | self,
173 | input_ids: torch.Tensor,
174 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
175 | attention_mask: Optional[torch.Tensor] = None,
176 | inputs_embeds: Optional[torch.Tensor] = None,
177 | cache_position: Optional[torch.Tensor] = None,
178 | position_ids: Optional[torch.Tensor] = None,
179 | use_cache=True,
180 | ) -> _GenerationInputs:
181 | # With static cache, the `past_key_values` is None
182 |         # TODO joao: standardize interface for the different Cache classes and remove this if
183 | has_static_cache = False
184 | if past_key_values is None:
185 | if hasattr(self.gpt.layers[0], "self_attn"):
186 | past_key_values = getattr(
187 | self.gpt.layers[0].self_attn, "past_key_value", None
188 | )
189 | has_static_cache = past_key_values is not None
190 |
191 | past_length = 0
192 | if past_key_values is not None:
193 | if isinstance(past_key_values, Cache):
194 | past_length = (
195 | int(cache_position[0])
196 | if cache_position is not None
197 | else past_key_values.get_seq_length()
198 | )
199 | max_cache_length = past_key_values.get_max_length()
200 | cache_length = (
201 | past_length
202 | if max_cache_length is None
203 | else min(max_cache_length, past_length)
204 | )
205 | # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
206 | else:
207 | cache_length = past_length = past_key_values[0][0].shape[2]
208 | max_cache_length = None
209 |
210 | # Keep only the unprocessed tokens:
211 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
212 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
213 | # input)
214 | if (
215 | attention_mask is not None
216 | and attention_mask.shape[1] > input_ids.shape[1]
217 | ):
218 | start = attention_mask.shape[1] - past_length
219 | input_ids = input_ids.narrow(1, -start, start)
220 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
221 | # input_ids based on the past_length.
222 | elif past_length < input_ids.shape[1]:
223 | input_ids = input_ids.narrow(
224 | 1, past_length, input_ids.size(1) - past_length
225 | )
226 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
227 |
228 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
229 | if (
230 | max_cache_length is not None
231 | and attention_mask is not None
232 | and cache_length + input_ids.shape[1] > max_cache_length
233 | ):
234 | attention_mask = attention_mask.narrow(
235 | 1, -max_cache_length, max_cache_length
236 | )
237 |
238 | if attention_mask is not None and position_ids is None:
239 | # create position_ids on the fly for batch generation
240 | position_ids = attention_mask.long().cumsum(-1) - 1
241 | position_ids.masked_fill_(attention_mask.eq(0), 1)
242 | if past_key_values:
243 | position_ids = position_ids.narrow(
244 | 1, -input_ids.shape[1], input_ids.shape[1]
245 | )
246 |
247 | input_length = (
248 | position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
249 | )
250 | if cache_position is None:
251 | cache_position = torch.arange(
252 | past_length, past_length + input_length, device=input_ids.device
253 | )
254 | else:
255 | cache_position = cache_position.narrow(0, -input_length, input_length)
256 |
257 | if has_static_cache:
258 | past_key_values = None
259 |
260 | model_inputs = self._GenerationInputs(
261 | position_ids=position_ids,
262 | cache_position=cache_position,
263 | use_cache=use_cache,
264 | )
265 |
266 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
267 | if inputs_embeds is not None and past_key_values is None:
268 | model_inputs.inputs_embeds = inputs_embeds
269 | else:
270 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
271 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
272 | # TODO: use `next_tokens` directly instead.
273 | model_inputs.input_ids = input_ids.contiguous()
274 |
275 | model_inputs.past_key_values = past_key_values
276 | model_inputs.attention_mask = attention_mask
277 |
278 | return model_inputs
279 |
280 | @dataclass(repr=False, eq=False)
281 | class GenerationOutputs:
282 | ids: List[torch.Tensor]
283 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
284 | hiddens: List[torch.Tensor]
285 |
286 | def _prepare_generation_outputs(
287 | self,
288 | inputs_ids: torch.Tensor,
289 | start_idx: int,
290 | end_idx: torch.Tensor,
291 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
292 | hiddens: List[torch.Tensor],
293 | infer_text: bool,
294 | ) -> GenerationOutputs:
295 | inputs_ids = [
296 | inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
297 | ]
298 | if infer_text:
299 | inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids]
300 |
301 | if len(hiddens) > 0:
302 | hiddens = torch.stack(hiddens, 1)
303 | hiddens = [
304 | hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int())
305 | ]
306 |
307 | return self.GenerationOutputs(
308 | ids=inputs_ids,
309 | attentions=attentions,
310 | hiddens=hiddens,
311 | )
312 |
313 | @torch.no_grad()
314 | def generate(
315 | self,
316 | emb: torch.Tensor,
317 | inputs_ids: torch.Tensor,
318 | temperature: torch.Tensor,
319 | eos_token: Union[int, torch.Tensor],
320 | attention_mask: Optional[torch.Tensor] = None,
321 | max_new_token=2048,
322 | min_new_token=0,
323 | logits_warpers: List[LogitsWarper] = [],
324 | logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
325 | infer_text=False,
326 | return_attn=False,
327 | return_hidden=False,
328 | stream=False,
329 | show_tqdm=True,
330 | ensure_non_empty=True,
331 | stream_batch=24,
332 | context=Context(),
333 | ):
334 |
335 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
336 | hiddens = []
337 | stream_iter = 0
338 |
339 | start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
340 | inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
341 | )
342 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
343 |
344 | old_temperature = temperature
345 |
346 | temperature = (
347 | temperature.unsqueeze(0)
348 | .expand(inputs_ids.shape[0], -1)
349 | .contiguous()
350 | .view(-1, 1)
351 | )
352 |
353 | attention_mask_cache = torch.ones(
354 | (
355 | inputs_ids.shape[0],
356 | inputs_ids.shape[1] + max_new_token,
357 | ),
358 | dtype=torch.bool,
359 | device=inputs_ids.device,
360 | )
361 | if attention_mask is not None:
362 | attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
363 | attention_mask
364 | )
365 |
366 | progress = inputs_ids.size(1)
367 | # pre-allocate inputs_ids
368 | inputs_ids_buf = torch.zeros(
369 | inputs_ids.size(0),
370 | progress + max_new_token,
371 | inputs_ids.size(2),
372 | dtype=inputs_ids.dtype,
373 | device=inputs_ids.device,
374 | )
375 | inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
376 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
377 |
378 | pbar: Optional[tqdm] = None
379 |
380 | if show_tqdm:
381 | pbar = tqdm(
382 | total=max_new_token,
383 | desc="text" if infer_text else "code",
384 | bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
385 | )
386 |
387 | past_key_values = None
388 |
389 | for i in range(max_new_token):
390 |
391 | model_input = self._prepare_generation_inputs(
392 | inputs_ids,
393 | past_key_values,
394 | attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
395 | use_cache=not self.is_te_llama,
396 | )
397 |
398 | if i > 0:
399 | inputs_ids_emb = model_input.input_ids.to(emb.device)
400 | if infer_text:
401 | emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
402 | else:
403 | code_emb = [
404 | self.emb_code[i](inputs_ids_emb[:, :, i])
405 | for i in range(self.num_vq)
406 | ]
407 | emb = torch.stack(code_emb, 3).sum(3)
408 | model_input.inputs_embeds = emb
409 | model_input.to(emb.device, emb.dtype)
410 | outputs: BaseModelOutputWithPast = self.gpt(
411 | attention_mask=model_input.attention_mask,
412 | position_ids=model_input.position_ids,
413 | past_key_values=model_input.past_key_values,
414 | inputs_embeds=model_input.inputs_embeds,
415 | use_cache=model_input.use_cache,
416 | output_attentions=return_attn,
417 | cache_position=model_input.cache_position,
418 | )
419 | attentions.append(outputs.attentions)
420 | hidden_states = outputs.last_hidden_state
421 | past_key_values = outputs.past_key_values
422 | if return_hidden:
423 | hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
424 | with P.cached():
425 | if infer_text:
426 | logits: torch.Tensor = self.head_text(hidden_states)
427 | else:
428 | # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
429 | logits = torch.empty(
430 | hidden_states.size(0),
431 | hidden_states.size(1),
432 | self.num_audio_tokens,
433 | self.num_vq,
434 | dtype=emb.dtype,
435 | device=emb.device,
436 | )
437 | for num_vq_iter in range(self.num_vq):
438 | x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
439 | logits[..., num_vq_iter] = x
440 |
441 | # logits = logits[:, -1].float()
442 | logits = logits.narrow(1, -1, 1).squeeze_(1)
443 |
444 | if not infer_text:
445 | # logits = rearrange(logits, "b c n -> (b n) c")
446 | logits = logits.permute(0, 2, 1)
447 | logits = logits.reshape(-1, logits.size(2))
448 | # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
449 | inputs_ids_sliced = inputs_ids.narrow(
450 | 1,
451 | start_idx,
452 | inputs_ids.size(1) - start_idx,
453 | ).permute(0, 2, 1)
454 | logits_token = inputs_ids_sliced.reshape(
455 | inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
456 | -1,
457 | ).to(emb.device)
458 | else:
459 | logits_token = (
460 | inputs_ids.narrow(
461 | 1,
462 | start_idx,
463 | inputs_ids.size(1) - start_idx,
464 | )
465 | .narrow(2, 0, 1)
466 | .to(emb.device)
467 | )
468 |
469 | logits /= temperature
470 |
471 | for logitsProcessors in logits_processors:
472 | logits = logitsProcessors(logits_token, logits)
473 |
474 | for logitsWarpers in logits_warpers:
475 | logits = logitsWarpers(logits_token, logits)
476 |
477 | if i < min_new_token:
478 | logits[:, eos_token] = -torch.inf
479 |
480 | scores = F.softmax(logits, dim=-1)
481 | idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
482 |
483 | if not infer_text:
484 | # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
485 | idx_next = idx_next.view(-1, self.num_vq)
486 | finish_or = idx_next.eq(eos_token).any(1)
487 | finish.logical_or_(finish_or)
488 | inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
489 | else:
490 | finish_or = idx_next.eq(eos_token).any(1)
491 | finish.logical_or_(finish_or)
492 | inputs_ids_buf.narrow(1, progress, 1).copy_(
493 | idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
494 | )
495 |
496 | if i == 0 and finish.any():
497 | self.logger.info(
498 | "unexpected end at index %s" % str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
499 | )
500 | if ensure_non_empty:
501 | if show_tqdm:
502 | pbar.close()
503 | self.logger.info("regenerate in order to ensure non-empty")
504 | new_gen = self.generate(
505 | emb,
506 | inputs_ids,
507 | old_temperature,
508 | eos_token,
509 | attention_mask,
510 | max_new_token,
511 | min_new_token,
512 | logits_warpers,
513 | logits_processors,
514 | infer_text,
515 | return_attn,
516 | return_hidden,
517 | stream,
518 | show_tqdm,
519 | ensure_non_empty,
520 | stream_batch,
521 | context,
522 | )
523 | for result in new_gen:
524 | yield result
525 | return
526 |
527 | progress += 1
528 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
529 |
530 | not_finished = finish.logical_not().to(end_idx.device)
531 | end_idx.add_(not_finished.int())
532 | stream_iter += not_finished.any().int()
533 | if stream:
534 | if stream_iter > 0 and stream_iter % stream_batch == 0:
535 | self.logger.info("yield stream result, end: %d", end_idx)
536 | yield self._prepare_generation_outputs(
537 | inputs_ids,
538 | start_idx,
539 | end_idx,
540 | attentions,
541 | hiddens,
542 | infer_text,
543 | )
544 |
545 | if finish.all() or context.get():
546 | break
547 |
548 | if pbar is not None:
549 | pbar.update(1)
550 |
551 | if pbar is not None:
552 | pbar.close()
553 |
554 | if not finish.all():
555 | if context.get():
556 | self.logger.info("generation is interrupted")
557 | else:
558 | self.logger.info(
559 | f"incomplete result. hit max_new_token: {max_new_token}"
560 | )
561 |
562 | yield self._prepare_generation_outputs(
563 | inputs_ids,
564 | start_idx,
565 | end_idx,
566 | attentions,
567 | hiddens,
568 | infer_text,
569 | )
570 |
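571 | # Usage sketch (illustrative only): `gpt`, `emb`, `input_ids`, `eos_id` and the sampling values are
572 | # assumed to come from the pipeline (see chattts_plus/pipelines/chattts_plus_pipeline.py).
573 | #
574 | #   warpers, processors = gen_logits(num_code=626, top_P=0.7, top_K=20, repetition_penalty=1.05)
575 | #   for out in gpt.generate(emb, input_ids, temperature=torch.tensor([0.3] * gpt.num_vq),
576 | #                           eos_token=eos_id, logits_warpers=warpers,
577 | #                           logits_processors=processors, infer_text=False, stream=False):
578 | #       codes = out.ids  # list of per-sample audio-code tensors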
--------------------------------------------------------------------------------
/chattts_plus/models/processors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
4 |
5 |
6 | class CustomRepetitionPenaltyLogitsProcessorRepeat:
7 |
8 | def __init__(self, penalty: float, max_input_ids: int, past_window: int):
9 | if not isinstance(penalty, float) or not (penalty > 0):
10 | raise ValueError(
11 | f"`penalty` has to be a strictly positive float, but is {penalty}"
12 | )
13 |
14 | self.penalty = penalty
15 | self.max_input_ids = max_input_ids
16 | self.past_window = past_window
17 |
18 | def __call__(
19 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
20 | ) -> torch.FloatTensor:
21 | if input_ids.size(1) > self.past_window:
22 | input_ids = input_ids.narrow(1, -self.past_window, self.past_window)
23 | freq = F.one_hot(input_ids, scores.size(1)).sum(1)
24 | if freq.size(0) > self.max_input_ids:
25 | freq.narrow(
26 | 0, self.max_input_ids, freq.size(0) - self.max_input_ids
27 | ).zero_()
28 | alpha = torch.pow(self.penalty, freq)
29 | scores = scores.contiguous()
30 | inp = scores.multiply(alpha)
31 | oth = scores.divide(alpha)
32 | con = scores < 0
33 | out = torch.where(con, inp, oth)
34 | return out
35 |
36 |
37 | def gen_logits(
38 | num_code: int,
39 | top_P=0.7,
40 | top_K=20,
41 | repetition_penalty=1.0,
42 | ):
43 | logits_warpers = []
44 | if top_P is not None:
45 | logits_warpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
46 | if top_K is not None:
47 | logits_warpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))
48 |
49 | logits_processors = []
50 | if repetition_penalty is not None and repetition_penalty != 1:
51 | logits_processors.append(
52 | CustomRepetitionPenaltyLogitsProcessorRepeat(
53 | repetition_penalty, num_code, 16
54 | )
55 | )
56 |
57 | return logits_warpers, logits_processors
58 |
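59 | # Usage sketch (values are illustrative, not project defaults): num_code matches the
60 | # num_audio_tokens of the GPT config, e.g.
61 | #   logits_warpers, logits_processors = gen_logits(num_code=626, top_P=0.7, top_K=20, repetition_penalty=1.05)
62 | # Both lists are passed through unchanged to GPT.generate(..., logits_warpers=..., logits_processors=...).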
--------------------------------------------------------------------------------
/chattts_plus/models/tokenizer.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
4 | """
5 | https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
6 | """
7 |
8 | from typing import List, Tuple, Optional
9 | import lzma
10 |
11 | import numpy as np
12 | import pybase16384 as b14
13 | import torch
14 | import torch.nn.functional as F
15 | from transformers import BertTokenizerFast
16 |
17 | from ..commons import logger
18 |
19 |
20 | class Tokenizer:
21 | def __init__(
22 | self, model_path, **kwargs
23 | ):
24 | self.logger = logger.get_logger(self.__class__.__name__)
25 | self.logger.info(f"loading Tokenizer pretrained model: {model_path}")
26 |
27 | tokenizer: BertTokenizerFast = torch.load(
28 | model_path, map_location="cpu", mmap=True
29 | )
30 | self._tokenizer = tokenizer
31 |
32 |         # set special tokens
33 |         self._tokenizer.eos_token = '[SEP]'  # use BERT's default end-of-sequence token
34 |         self._tokenizer.pad_token = '[PAD]'  # use BERT's default padding token
35 | 
36 |         # get or set the corresponding token ids
37 | if not hasattr(self._tokenizer, 'pad_token_id'):
38 | self._tokenizer.pad_token_id = self._tokenizer.convert_tokens_to_ids('[PAD]')
39 | if not hasattr(self._tokenizer, 'eos_token_id'):
40 | self._tokenizer.eos_token_id = self._tokenizer.convert_tokens_to_ids('[SEP]')
41 |
42 | self.len = len(tokenizer)
43 | self.spk_emb_ids = tokenizer.convert_tokens_to_ids("[spk_emb]")
44 | self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
45 | self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
46 |
47 | self.decode = self._tokenizer.batch_decode
48 |
49 | @torch.inference_mode()
50 | def encode(
51 | self,
52 | text: List[str],
53 | num_vq: int,
54 | prompt_str: Optional[str] = None,
55 | device="cpu",
56 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
57 |
58 | input_ids_lst = []
59 | attention_mask_lst = []
60 | max_input_ids_len = -1
61 | max_attention_mask_len = -1
62 | prompt_size = 0
63 |
64 | prompt = self._decode_prompt(prompt_str) if prompt_str is not None else None
65 |
66 | if prompt is not None:
67 |             assert prompt.size(0) == num_vq, "prompt dim 0 must equal num_vq"
68 | prompt_size = prompt.size(1)
69 |
70 | # avoid random speaker embedding of tokenizer in the other dims
71 | for t in text:
72 | x = self._tokenizer.encode_plus(
73 | t, return_tensors="pt", add_special_tokens=False, padding=True
74 | )
75 | input_ids_lst.append(x["input_ids"].squeeze_(0))
76 | attention_mask_lst.append(x["attention_mask"].squeeze_(0))
77 | ids_sz = input_ids_lst[-1].size(0)
78 | if ids_sz > max_input_ids_len:
79 | max_input_ids_len = ids_sz
80 | attn_sz = attention_mask_lst[-1].size(0)
81 | if attn_sz > max_attention_mask_len:
82 | max_attention_mask_len = attn_sz
83 |
84 | if prompt is not None:
85 | max_input_ids_len += prompt_size
86 | max_attention_mask_len += prompt_size
87 |
88 | input_ids = torch.zeros(
89 | len(input_ids_lst),
90 | max_input_ids_len,
91 | device=device,
92 | dtype=input_ids_lst[0].dtype,
93 | )
94 | for i in range(len(input_ids_lst)):
95 | input_ids.narrow(0, i, 1).narrow(
96 | 1,
97 | max_input_ids_len - prompt_size - input_ids_lst[i].size(0),
98 | input_ids_lst[i].size(0),
99 | ).copy_(
100 | input_ids_lst[i]
101 | ) # left padding
102 |
103 | attention_mask = torch.zeros(
104 | len(attention_mask_lst),
105 | max_attention_mask_len,
106 | device=device,
107 | dtype=attention_mask_lst[0].dtype,
108 | )
109 | for i in range(len(attention_mask_lst)):
110 | attn = attention_mask.narrow(0, i, 1)
111 | attn.narrow(
112 | 1,
113 | max_attention_mask_len - prompt_size - attention_mask_lst[i].size(0),
114 | attention_mask_lst[i].size(0),
115 | ).copy_(
116 | attention_mask_lst[i]
117 | ) # left padding
118 | if prompt_size > 0:
119 | attn.narrow(
120 | 1,
121 | max_attention_mask_len - prompt_size,
122 | prompt_size,
123 | ).fill_(1)
124 |
125 | text_mask = attention_mask.bool()
126 | new_input_ids = input_ids.unsqueeze_(-1).expand(-1, -1, num_vq).clone()
127 |
128 | if prompt_size > 0:
129 | text_mask.narrow(1, max_input_ids_len - prompt_size, prompt_size).fill_(0)
130 | prompt_t = prompt.t().unsqueeze_(0).expand(new_input_ids.size(0), -1, -1)
131 | new_input_ids.narrow(
132 | 1,
133 | max_input_ids_len - prompt_size,
134 | prompt_size,
135 | ).copy_(prompt_t)
136 |
137 | return new_input_ids, attention_mask, text_mask
138 |
139 | @staticmethod
140 | def _decode_spk_emb(spk_emb: str) -> np.ndarray:
141 | return np.frombuffer(
142 | lzma.decompress(
143 | b14.decode_from_string(spk_emb),
144 | format=lzma.FORMAT_RAW,
145 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
146 | ),
147 | dtype=np.float16,
148 | ).copy()
149 |
150 | @torch.no_grad()
151 | def apply_spk_emb(
152 | self,
153 | emb: torch.Tensor,
154 | spk_emb,
155 | input_ids: torch.Tensor,
156 | device: torch.device,
157 | ):
158 | if isinstance(spk_emb, str):
159 | spk_emb_tensor = torch.from_numpy(self._decode_spk_emb(spk_emb))
160 | else:
161 | spk_emb_tensor = spk_emb
162 |
163 | n = (
164 | F.normalize(
165 | spk_emb_tensor,
166 | p=2.0,
167 | dim=0,
168 | eps=1e-12,
169 | )
170 | .to(emb.device, dtype=emb.dtype)
171 | .unsqueeze_(0)
172 | .expand(emb.size(0), -1)
173 | .unsqueeze_(1)
174 | .expand(emb.shape)
175 | )
176 | cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape)
177 | torch.where(cond, n, emb, out=emb)
178 | return emb
179 |
180 | @staticmethod
181 | @torch.no_grad()
182 | def _decode_prompt(prompt: str) -> torch.Tensor:
183 | dec = b14.decode_from_string(prompt)
184 |         shp = np.frombuffer(dec[:4], dtype="<u2")
185 |         p = np.frombuffer(dec[4:], dtype="<u2").copy()
186 |         del dec
187 |         return torch.from_numpy(
188 |             p.astype(np.int32),
189 |         ).view(*shp)
190 | 
191 |     @staticmethod
192 |     @torch.no_grad()
193 |     def _encode_prompt(prompt: torch.Tensor) -> str:
198 | arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy()
199 | shp = arr.shape
200 | assert len(shp) == 2, "prompt must be a 2D tensor"
201 | s = b14.encode_to_string(
202 |             np.array(shp, dtype="<u2").tobytes() + arr.tobytes(),
203 |         )
204 |         del arr
205 |         return s
206 | 
207 |     @staticmethod
208 |     @torch.no_grad()
209 |     def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
214 | arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
215 | s = b14.encode_to_string(
216 | lzma.compress(
217 | arr.tobytes(),
218 | format=lzma.FORMAT_RAW,
219 | filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
220 | ),
221 | )
222 | return s
223 |
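224 | # Usage sketch (path and arguments are illustrative; see configs/infer/chattts_plus.yaml):
225 | #   tokenizer = Tokenizer(model_path="checkpoints/asset/tokenizer.pt")
226 | #   input_ids, attention_mask, text_mask = tokenizer.encode(["hello world"], num_vq=4, device="cpu")
227 | #   # input_ids: (batch, seq_len, num_vq), left-padded; text_mask marks the text (non-prompt) positions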
--------------------------------------------------------------------------------
/chattts_plus/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/27 22:53
3 | # @Project : ChatTTSPlus
4 | # @FileName: __init__.py
5 |
--------------------------------------------------------------------------------
/chattts_plus/trt_models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/10/6 19:59
3 | # @Author : wenshao
4 | # @ProjectName: ChatTTSPlus
5 | # @FileName: __init__.py
6 |
7 | from .gpt_trt import GPT
--------------------------------------------------------------------------------
/chattts_plus/trt_models/base_model.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from .predictor import get_predictor
4 |
5 |
6 | class BaseModel:
7 | """
8 |     Base class for model inference predictors.
9 | """
10 |
11 | def __init__(self, **kwargs):
12 | self.kwargs = copy.deepcopy(kwargs)
13 | self.predictor = get_predictor(**self.kwargs)
14 | self.device = torch.cuda.current_device()
15 |
16 | if self.predictor is not None:
17 | self.input_shapes = self.predictor.input_spec()
18 | self.output_shapes = self.predictor.output_spec()
19 |
20 | def input_process(self, *data):
21 | """
22 |         Preprocess the inputs.
23 | :return:
24 | """
25 | pass
26 |
27 | def output_process(self, *data):
28 | """
29 |         Postprocess the outputs.
30 | :return:
31 | """
32 | pass
33 |
34 | def predict(self, *data):
35 | """
36 |         Run inference.
37 | :return:
38 | """
39 | pass
40 |
41 | def __del__(self):
42 | """
43 |         Release the predictor when the instance is deleted.
44 | :return:
45 | """
46 | if self.predictor is not None:
47 | del self.predictor
48 |
--------------------------------------------------------------------------------
/chattts_plus/trt_models/gpt_trt.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import platform
3 | from dataclasses import dataclass
4 | import logging
5 | from typing import Union, List, Optional, Tuple
6 | import gc
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.nn.utils.parametrize as P
12 | from torch.nn.utils.parametrizations import weight_norm
13 | from tqdm import tqdm
14 | from transformers import LlamaConfig, LogitsWarper
15 | from transformers.cache_utils import Cache
16 | from transformers.modeling_outputs import BaseModelOutputWithPast
17 | from transformers.utils import is_flash_attn_2_available
18 |
19 | from .llama_trt_model import LlamaTRTModel
20 | from ..models.processors import CustomRepetitionPenaltyLogitsProcessorRepeat
21 | from ..commons import logger
22 |
23 |
24 | class GPT(nn.Module):
25 | def __init__(
26 | self,
27 | gpt_config: dict,
28 | num_audio_tokens: int = 626,
29 | num_text_tokens: int = 21178,
30 | num_vq=4,
31 | use_flash_attn=False,
32 | **kwargs
33 | ):
34 | super().__init__()
35 | self.logger = logger.get_logger(self.__class__.__name__)
36 | self.num_vq = num_vq
37 | self.num_audio_tokens = num_audio_tokens
38 |
39 | self.use_flash_attn = use_flash_attn
40 | self.gpt_config = gpt_config
41 | self.gpt, self.llama_config = self._build_llama(gpt_config, **kwargs)
42 | self.is_te_llama = False
43 | self.model_dim = int(gpt_config["hidden_size"])
44 | self.emb_code = nn.ModuleList(
45 | [
46 | nn.Embedding(
47 | num_audio_tokens,
48 | self.model_dim
49 | )
50 | for _ in range(num_vq)
51 | ],
52 | )
53 | self.emb_text = nn.Embedding(
54 | num_text_tokens, self.model_dim
55 | )
56 |
57 | self.head_text = weight_norm(
58 | nn.Linear(
59 | self.model_dim,
60 | num_text_tokens,
61 | bias=False,
62 | ),
63 | name="weight",
64 | )
65 | self.head_code = nn.ModuleList(
66 | [
67 | weight_norm(
68 | nn.Linear(
69 | self.model_dim,
70 | num_audio_tokens,
71 | bias=False,
72 | ),
73 | name="weight",
74 | )
75 | for _ in range(self.num_vq)
76 | ],
77 | )
78 | self.model_path = kwargs.get("model_path", None)
79 | if self.model_path:
80 | self.logger.info(f"loading GPT pretrained model: {self.model_path}")
81 | self.from_pretrained(self.model_path)
82 |
83 | def from_pretrained(self, file_path: str):
84 | self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True), strict=False)
85 |
86 | class Context:
87 | def __init__(self):
88 | self._interrupt = False
89 |
90 | def set(self, v: bool):
91 | self._interrupt = v
92 |
93 | def get(self) -> bool:
94 | return self._interrupt
95 |
96 | def _build_llama(
97 | self,
98 | config: dict,
99 | **kwargs
100 | ):
101 |
102 | if self.use_flash_attn and is_flash_attn_2_available():
103 | llama_config = LlamaConfig(
104 | **config,
105 | attn_implementation="flash_attention_2",
106 | )
107 | self.logger.info(
108 | "enabling flash_attention_2 may make gpt be even slower"
109 | )
110 | else:
111 | llama_config = LlamaConfig(**config)
112 |
113 | model = LlamaTRTModel(predict_type="trt", model_path=kwargs.get("trt_model_path"),
114 | output_max_shapes=kwargs.get("output_max_shapes"))
115 |
116 | return model, llama_config
117 |
118 | def __call__(
119 | self, input_ids: torch.Tensor, text_mask: torch.Tensor
120 | ) -> torch.Tensor:
121 | """
122 | get_emb
123 | """
124 | return super().__call__(input_ids, text_mask)
125 |
126 | def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor:
127 | """
128 | get_emb
129 | """
130 |
131 | emb_text: torch.Tensor = self.emb_text(
132 | input_ids[text_mask].narrow(1, 0, 1).squeeze_(1)
133 | )
134 |
135 | text_mask_inv = text_mask.logical_not()
136 | masked_input_ids: torch.Tensor = input_ids[text_mask_inv]
137 |
138 | emb_code = [
139 | self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)
140 | ]
141 | emb_code = torch.stack(emb_code, 2).sum(2)
142 |
143 | emb = torch.zeros(
144 | (input_ids.shape[:-1]) + (emb_text.shape[-1],),
145 | device=emb_text.device,
146 | dtype=emb_text.dtype,
147 | )
148 | emb[text_mask] = emb_text
149 | emb[text_mask_inv] = emb_code.to(emb.dtype)
150 | return emb
151 |
152 | @dataclass(repr=False, eq=False)
153 | class _GenerationInputs:
154 | position_ids: torch.Tensor
155 | cache_position: torch.Tensor
156 | use_cache: bool
157 | input_ids: Optional[torch.Tensor] = None
158 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
159 | attention_mask: Optional[torch.Tensor] = None
160 | inputs_embeds: Optional[torch.Tensor] = None
161 |
162 | def to(self, device: torch.device, dtype: torch.dtype):
163 | if self.attention_mask is not None:
164 | self.attention_mask = self.attention_mask.to(device, dtype=dtype)
165 | if self.position_ids is not None:
166 | self.position_ids = self.position_ids.to(device, dtype=dtype)
167 | if self.inputs_embeds is not None:
168 | self.inputs_embeds = self.inputs_embeds.to(device, dtype=dtype)
169 | if self.cache_position is not None:
170 | self.cache_position = self.cache_position.to(device, dtype=dtype)
171 |
172 | def _prepare_generation_inputs(
173 | self,
174 | input_ids: torch.Tensor,
175 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
176 | attention_mask: Optional[torch.Tensor] = None,
177 | inputs_embeds: Optional[torch.Tensor] = None,
178 | cache_position: Optional[torch.Tensor] = None,
179 | position_ids: Optional[torch.Tensor] = None,
180 | use_cache=True,
181 | ) -> _GenerationInputs:
182 | # With static cache, the `past_key_values` is None
183 |         # TODO joao: standardize interface for the different Cache classes and remove this if
184 | has_static_cache = False
185 |
186 | past_length = 0
187 | if past_key_values is not None:
188 | if isinstance(past_key_values, Cache):
189 | past_length = (
190 | int(cache_position[0])
191 | if cache_position is not None
192 | else past_key_values.get_seq_length()
193 | )
194 | max_cache_length = past_key_values.get_max_length()
195 | cache_length = (
196 | past_length
197 | if max_cache_length is None
198 | else min(max_cache_length, past_length)
199 | )
200 | # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
201 | else:
202 | cache_length = past_length = past_key_values[0][0].shape[2]
203 | max_cache_length = None
204 |
205 | # Keep only the unprocessed tokens:
206 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
207 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
208 | # input)
209 | if (
210 | attention_mask is not None
211 | and attention_mask.shape[1] > input_ids.shape[1]
212 | ):
213 | start = attention_mask.shape[1] - past_length
214 | input_ids = input_ids.narrow(1, -start, start)
215 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
216 | # input_ids based on the past_length.
217 | elif past_length < input_ids.shape[1]:
218 | input_ids = input_ids.narrow(
219 | 1, past_length, input_ids.size(1) - past_length
220 | )
221 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
222 |
223 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
224 | if (
225 | max_cache_length is not None
226 | and attention_mask is not None
227 | and cache_length + input_ids.shape[1] > max_cache_length
228 | ):
229 | attention_mask = attention_mask.narrow(
230 | 1, -max_cache_length, max_cache_length
231 | )
232 |
233 | if attention_mask is not None and position_ids is None:
234 | # create position_ids on the fly for batch generation
235 | position_ids = attention_mask.long().cumsum(-1) - 1
236 | position_ids.masked_fill_(attention_mask.eq(0), 1)
237 | if past_key_values is not None:
238 | position_ids = position_ids.narrow(
239 | 1, -input_ids.shape[1], input_ids.shape[1]
240 | )
241 |
242 | input_length = (
243 | position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
244 | )
245 | if cache_position is None:
246 | cache_position = torch.arange(
247 | past_length, past_length + input_length, device=input_ids.device
248 | )
249 | else:
250 | cache_position = cache_position.narrow(0, -input_length, input_length)
251 |
252 | if has_static_cache:
253 | past_key_values = None
254 |
255 | model_inputs = self._GenerationInputs(
256 | position_ids=position_ids,
257 | cache_position=cache_position,
258 | use_cache=use_cache,
259 | )
260 |
261 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
262 | if inputs_embeds is not None and past_key_values is None:
263 | model_inputs.inputs_embeds = inputs_embeds
264 | else:
265 | # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
266 | # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
267 | # TODO: use `next_tokens` directly instead.
268 | model_inputs.input_ids = input_ids.contiguous()
269 |
270 | model_inputs.past_key_values = past_key_values
271 | model_inputs.attention_mask = attention_mask
272 |
273 | return model_inputs
274 |
275 | @dataclass(repr=False, eq=False)
276 | class GenerationOutputs:
277 | ids: List[torch.Tensor]
278 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
279 | hiddens: List[torch.Tensor]
280 |
281 | def _prepare_generation_outputs(
282 | self,
283 | inputs_ids: torch.Tensor,
284 | start_idx: int,
285 | end_idx: torch.Tensor,
286 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
287 | hiddens: List[torch.Tensor],
288 | infer_text: bool,
289 | ) -> GenerationOutputs:
290 | inputs_ids = [
291 | inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
292 | ]
293 | if infer_text:
294 | inputs_ids = [i.narrow(1, 0, 1).squeeze_(1) for i in inputs_ids]
295 |
296 | if len(hiddens) > 0:
297 | hiddens = torch.stack(hiddens, 1)
298 | hiddens = [
299 | hiddens[idx].narrow(0, 0, i) for idx, i in enumerate(end_idx.int())
300 | ]
301 |
302 | return self.GenerationOutputs(
303 | ids=inputs_ids,
304 | attentions=attentions,
305 | hiddens=hiddens,
306 | )
307 |
308 | @torch.no_grad()
309 | def generate(
310 | self,
311 | emb: torch.Tensor,
312 | inputs_ids: torch.Tensor,
313 | temperature: torch.Tensor,
314 | eos_token: Union[int, torch.Tensor],
315 | attention_mask: Optional[torch.Tensor] = None,
316 | max_new_token=2048,
317 | min_new_token=0,
318 | logits_warpers: List[LogitsWarper] = [],
319 | logits_processors: List[CustomRepetitionPenaltyLogitsProcessorRepeat] = [],
320 | infer_text=False,
321 | return_attn=False,
322 | return_hidden=False,
323 | stream=False,
324 | show_tqdm=True,
325 | ensure_non_empty=True,
326 | stream_batch=24,
327 | context=Context(),
328 | ):
329 |
330 | attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
331 | hiddens = []
332 | stream_iter = 0
333 |
334 | start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
335 | inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
336 | )
337 | finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()
338 |
339 | old_temperature = temperature
340 |
341 | temperature = (
342 | temperature.unsqueeze(0)
343 | .expand(inputs_ids.shape[0], -1)
344 | .contiguous()
345 | .view(-1, 1)
346 | )
347 |
348 | attention_mask_cache = torch.ones(
349 | (
350 | inputs_ids.shape[0],
351 | inputs_ids.shape[1] + max_new_token,
352 | ),
353 | dtype=torch.bool,
354 | device=inputs_ids.device,
355 | )
356 | if attention_mask is not None:
357 | attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
358 | attention_mask
359 | )
360 |
361 | progress = inputs_ids.size(1)
362 | # pre-allocate inputs_ids
363 | inputs_ids_buf = torch.zeros(
364 | inputs_ids.size(0),
365 | progress + max_new_token,
366 | inputs_ids.size(2),
367 | dtype=inputs_ids.dtype,
368 | device=inputs_ids.device,
369 | )
370 | inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
371 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
372 |
373 | pbar: Optional[tqdm] = None
374 |
375 | if show_tqdm:
376 | pbar = tqdm(
377 | total=max_new_token,
378 | desc="text" if infer_text else "code",
379 | bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
380 | )
381 |
382 | past_key_values = None
383 | # need to create kv cache first
384 | self.gpt.create_kv_cache(inputs_ids.size(0))
385 | for i in range(max_new_token):
386 |
387 | model_input = self._prepare_generation_inputs(
388 | inputs_ids,
389 | past_key_values,
390 | attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
391 | use_cache=not self.is_te_llama,
392 | )
393 |
394 | if i > 0:
395 | inputs_ids_emb = model_input.input_ids.to(emb.device)
396 | if infer_text:
397 | emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
398 | else:
399 | code_emb = [
400 | self.emb_code[i](inputs_ids_emb[:, :, i])
401 | for i in range(self.num_vq)
402 | ]
403 | emb = torch.stack(code_emb, 3).sum(3)
404 | model_input.inputs_embeds = emb
405 | model_input.to(emb.device, dtype=emb.dtype)
406 | hidden_states = self.gpt.predict(model_input.inputs_embeds, model_input.attention_mask,
407 | model_input.position_ids)
408 | hidden_states = hidden_states.to(emb.device, dtype=emb.dtype)
409 | past_key_values = self.gpt.get_cur_kv_caches()
410 | attentions.append(None)
411 | if return_hidden:
412 | hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
413 |
414 | with P.cached():
415 | if infer_text:
416 | logits: torch.Tensor = self.head_text(hidden_states)
417 | else:
418 | # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
419 | logits = torch.empty(
420 | hidden_states.size(0),
421 | hidden_states.size(1),
422 | self.num_audio_tokens,
423 | self.num_vq,
424 | dtype=emb.dtype,
425 | device=emb.device,
426 | )
427 | for num_vq_iter in range(self.num_vq):
428 | x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
429 | logits[..., num_vq_iter] = x
430 |
431 | # logits = logits[:, -1].float()
432 | logits = logits.narrow(1, -1, 1).squeeze_(1)
433 |
434 | if not infer_text:
435 | # logits = rearrange(logits, "b c n -> (b n) c")
436 | logits = logits.permute(0, 2, 1)
437 | logits = logits.reshape(-1, logits.size(2))
438 | # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
439 | inputs_ids_sliced = inputs_ids.narrow(
440 | 1,
441 | start_idx,
442 | inputs_ids.size(1) - start_idx,
443 | ).permute(0, 2, 1)
444 | logits_token = inputs_ids_sliced.reshape(
445 | inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
446 | -1,
447 | ).to(emb.device)
448 | else:
449 | logits_token = (
450 | inputs_ids.narrow(
451 | 1,
452 | start_idx,
453 | inputs_ids.size(1) - start_idx,
454 | )
455 | .narrow(2, 0, 1)
456 | .to(emb.device)
457 | )
458 |
459 | logits /= temperature
460 |
461 | for logitsProcessors in logits_processors:
462 | logits = logitsProcessors(logits_token, logits)
463 |
464 | for logitsWarpers in logits_warpers:
465 | logits = logitsWarpers(logits_token, logits)
466 |
467 | if i < min_new_token:
468 | logits[:, eos_token] = -torch.inf
469 |
470 | scores = F.softmax(logits, dim=-1)
471 | idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
472 |
473 | if not infer_text:
474 | # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
475 | idx_next = idx_next.view(-1, self.num_vq)
476 | finish_or = idx_next.eq(eos_token).any(1)
477 | finish.logical_or_(finish_or)
478 | inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
479 | else:
480 | finish_or = idx_next.eq(eos_token).any(1)
481 | finish.logical_or_(finish_or)
482 | inputs_ids_buf.narrow(1, progress, 1).copy_(
483 | idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
484 | )
485 |
486 | if i == 0 and finish.any():
487 | self.logger.info(
488 | "unexpected end at index %s" % str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
489 | )
490 | if ensure_non_empty:
491 | if show_tqdm:
492 | pbar.close()
493 | self.logger.info("regenerate in order to ensure non-empty")
494 | new_gen = self.generate(
495 | emb,
496 | inputs_ids,
497 | old_temperature,
498 | eos_token,
499 | attention_mask,
500 | max_new_token,
501 | min_new_token,
502 | logits_warpers,
503 | logits_processors,
504 | infer_text,
505 | return_attn,
506 | return_hidden,
507 | stream,
508 | show_tqdm,
509 | ensure_non_empty,
510 | stream_batch,
511 | context,
512 | )
513 | for result in new_gen:
514 | yield result
515 | return
516 |
517 | progress += 1
518 | inputs_ids = inputs_ids_buf.narrow(1, 0, progress)
519 |
520 | not_finished = finish.logical_not().to(end_idx.device)
521 | end_idx.add_(not_finished.int())
522 | stream_iter += not_finished.any().int()
523 | if stream:
524 | if stream_iter > 0 and stream_iter % stream_batch == 0:
525 | self.logger.info("yield stream result, end: %d", end_idx)
526 | yield self._prepare_generation_outputs(
527 | inputs_ids,
528 | start_idx,
529 | end_idx,
530 | attentions,
531 | hiddens,
532 | infer_text,
533 | )
534 |
535 | if finish.all() or context.get():
536 | break
537 |
538 | if pbar is not None:
539 | pbar.update(1)
540 |
541 | if pbar is not None:
542 | pbar.close()
543 |
544 | if not finish.all():
545 | if context.get():
546 | self.logger.info("generation is interrupted")
547 | else:
548 | self.logger.info(
549 | f"incomplete result. hit max_new_token: {max_new_token}"
550 | )
551 |
552 | yield self._prepare_generation_outputs(
553 | inputs_ids,
554 | start_idx,
555 | end_idx,
556 | attentions,
557 | hiddens,
558 | infer_text,
559 | )
560 |
--------------------------------------------------------------------------------
/chattts_plus/trt_models/llama_trt_model.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import time
3 | import numpy as np
4 | import torch
5 | from torch.cuda import nvtx
6 | from .base_model import BaseModel
7 | from .predictor import numpy_to_torch_dtype_dict
8 |
9 |
10 | class LlamaTRTModel(BaseModel):
11 | """
12 | Llama TensorRT Model
13 | """
14 |
15 | def __init__(self, **kwargs):
16 | super(LlamaTRTModel, self).__init__(**kwargs)
17 | self.predict_type = kwargs.get("predict_type", "trt")
18 | self.cudaStream = torch.cuda.current_stream().cuda_stream
19 | self.max_seq_len = kwargs.get("max_seq_len", 2048)
20 | for i, inp in enumerate(self.predictor.inputs):
21 | if inp["name"] == "past_key_values":
22 | self.kv_cache_dtype = numpy_to_torch_dtype_dict[inp['dtype']]
23 | self.kv_cache_shape = inp['shape']
24 |
25 | def create_kv_cache(self, batch_size=1):
26 | self.kv_caches = torch.empty(self.kv_cache_shape[0], self.kv_cache_shape[1], batch_size, self.kv_cache_shape[3],
27 | self.max_seq_len + 1, self.kv_cache_shape[5]).to(self.device,
28 | dtype=self.kv_cache_dtype)
29 | self.kv_ind = 1
30 |
31 | def clear_kv_caches(self):
32 | self.kv_ind = 1
33 |
34 | def get_cur_kv_caches(self):
35 | return self.kv_caches[:, :, :, :, 1:self.kv_ind]
36 |
37 | def input_process(self, *data, **kwargs):
38 | return data
39 |
40 | def output_process(self, *data, **kwargs):
41 | return data[0]
42 |
43 | def predict_trt(self, *data, **kwargs):
44 | nvtx.range_push("forward")
45 | feed_dict = {}
46 | cur_input_shape = None
47 | for i, inp in enumerate(self.predictor.inputs):
48 | if inp["name"] != "past_key_values":
49 | if inp["name"] == "inputs_embeds":
50 | cur_input_shape = data[i].shape
51 | if isinstance(data[i], torch.Tensor):
52 | feed_dict[inp['name']] = data[i].to(device=self.device,
53 | dtype=numpy_to_torch_dtype_dict[inp['dtype']])
54 | else:
55 | feed_dict[inp['name']] = torch.from_numpy(data[i]).to(device=self.device,
56 | dtype=numpy_to_torch_dtype_dict[inp['dtype']])
57 | else:
58 | feed_dict[inp['name']] = self.kv_caches[:, :, :, :, :self.kv_ind]
59 | preds_dict = self.predictor.predict(feed_dict, self.cudaStream)
60 | outs = []
61 | for i, out in enumerate(self.predictor.outputs):
62 | if out["name"] == "cur_key_values":
63 | out_shape = self.kv_cache_shape[:]
64 | out_shape[2] = cur_input_shape[0]
65 | out_shape[4] = cur_input_shape[1]
66 | out_tensor = preds_dict[out["name"]][:np.prod(out_shape)].reshape(*out_shape)
67 | new_kv_len = out_tensor.shape[4]
68 | self.kv_caches[:, :, :, :, self.kv_ind:self.kv_ind + new_kv_len] = out_tensor.clone()
69 | self.kv_ind += new_kv_len
70 | else:
71 | out_shape = cur_input_shape[:]
72 | out_tensor = preds_dict[out["name"]][:np.prod(out_shape)].reshape(*out_shape)
73 | outs.append(out_tensor.clone())
74 | nvtx.range_pop()
75 | return outs
76 |
77 | def predict(self, *data, **kwargs):
78 | data = self.input_process(*data, **kwargs)
79 | preds = self.predict_trt(*data, **kwargs)
80 | outputs = self.output_process(*preds, **kwargs)
81 | return outputs
82 |
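83 | # Usage sketch (illustrative; gpt_trt.GPT calls create_kv_cache and predict in this same order during generation):
84 | #   llama = LlamaTRTModel(predict_type="trt", model_path="checkpoints/chattts_llama.trt",
85 | #                         output_max_shapes={...})  # shapes as in configs/infer/chattts_plus_trt.yaml
86 | #   llama.create_kv_cache(batch_size=1)                  # allocate the static KV-cache buffer
87 | #   hidden = llama.predict(inputs_embeds, attention_mask, position_ids)  # appends new K/V to the cache
88 | #   llama.clear_kv_caches()                              # reset the cache pointer between utterances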
--------------------------------------------------------------------------------
/chattts_plus/trt_models/predictor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | from torch.cuda import nvtx
5 | from collections import OrderedDict
6 |
7 | try:
8 | import tensorrt as trt
9 | import ctypes
10 | except ModuleNotFoundError:
11 | print("No TensorRT Found")
12 |
13 | numpy_to_torch_dtype_dict = {
14 | np.uint8: torch.uint8,
15 | np.int8: torch.int8,
16 | np.int16: torch.int16,
17 | np.int32: torch.int32,
18 | np.int64: torch.int64,
19 | np.float16: torch.float16,
20 | np.float32: torch.float32,
21 | np.float64: torch.float64,
22 | np.complex64: torch.complex64,
23 | np.complex128: torch.complex128,
24 | }
25 | if np.version.full_version >= "1.24.0":
26 | numpy_to_torch_dtype_dict[np.bool_] = torch.bool
27 | else:
28 | numpy_to_torch_dtype_dict[np.bool] = torch.bool
29 |
30 |
31 | class TensorRTPredictor:
32 | """
33 |     Implements inference for a serialized TensorRT engine.
34 | """
35 |
36 | def __init__(self, **kwargs):
37 | """
38 |         :param model_path: Path (passed via kwargs) to the serialized TensorRT engine to load from disk.
39 | """
40 | # Load TRT engine
41 | self.logger = trt.Logger(trt.Logger.ERROR)
42 | trt.init_libnvinfer_plugins(self.logger, "")
43 | engine_path = kwargs.get("model_path", None)
44 | self.debug = kwargs.get("debug", False)
45 | assert engine_path, f"model:{engine_path} must exist!"
46 | with open(engine_path, "rb") as f, trt.Runtime(self.logger) as runtime:
47 | assert runtime
48 | self.engine = runtime.deserialize_cuda_engine(f.read())
49 | assert self.engine
50 | # self.context = self.engine.create_execution_context()
51 | # assert self.context
52 |
53 | # Setup I/O bindings
54 | self.inputs = []
55 | self.outputs = []
56 | self.tensors = OrderedDict()
57 |
58 |         # TODO: support dynamic-shape inputs
59 | for idx in range(self.engine.num_io_tensors):
60 | name = self.engine[idx]
61 | is_input = self.engine.get_tensor_mode(name).name == "INPUT"
62 | shape = self.engine.get_tensor_shape(name)
63 | dtype = trt.nptype(self.engine.get_tensor_dtype(name))
64 | binding = {
65 | "index": idx,
66 | "name": name,
67 | "dtype": dtype,
68 | "shape": list(shape)
69 | }
70 | if is_input:
71 | self.inputs.append(binding)
72 | else:
73 | self.outputs.append(binding)
74 |
75 | assert len(self.inputs) > 0
76 | assert len(self.outputs) > 0
77 | self.allocate_max_buffers(**kwargs)
78 | self.activate()
79 |
80 | def activate(self, reuse_device_memory=None):
81 | if reuse_device_memory:
82 | self.context = self.engine.create_execution_context_without_device_memory()
83 | else:
84 | self.context = self.engine.create_execution_context()
85 |
86 | def deactivate(self):
87 | if self.context:
88 | del self.context
89 | self.context = None
90 |
91 | def allocate_max_buffers(self, device="cuda", **kwargs):
92 | nvtx.range_push("allocate_max_buffers")
93 |         # Only the batch dimension is handled dynamically for now; for other dynamic dimensions, pass output_max_shapes.
94 | output_max_shapes = kwargs.get("output_max_shapes", {})
95 |
96 | batch_size = 1
97 | for idx in range(self.engine.num_io_tensors):
98 | binding = self.engine[idx]
99 | shape = self.engine.get_tensor_shape(binding)
100 | is_input = self.engine.get_tensor_mode(binding).name == "INPUT"
101 | if binding in output_max_shapes:
102 | shape = output_max_shapes[binding]
103 | else:
104 | if -1 in shape:
105 | if is_input:
106 | shape = self.engine.get_tensor_profile_shape(binding, 0)[-1]
107 | batch_size = shape[0]
108 | else:
109 | shape[0] = batch_size
110 | dtype = trt.nptype(self.engine.get_tensor_dtype(binding))
111 | tensor = torch.empty(
112 | np.prod(list(shape)), dtype=numpy_to_torch_dtype_dict[dtype]
113 | ).to(device=device)
114 | self.tensors[binding] = tensor
115 | nvtx.range_pop()
116 |
117 | def input_spec(self):
118 | """
119 |         Get the specs for the input tensors of the network. Useful to prepare memory allocations.
120 |         :return: A list of (name, shape, numpy dtype) tuples, one per input tensor.
121 | """
122 | specs = []
123 | for i, o in enumerate(self.inputs):
124 | specs.append((o["name"], o['shape'], o['dtype']))
125 | if self.debug:
126 | print(f"trt input {i} -> {o['name']} -> {o['shape']}")
127 | return specs
128 |
129 | def output_spec(self):
130 | """
131 | Get the specs for the output tensors of the network. Useful to prepare memory allocations.
132 |         :return: A list of (name, shape, numpy dtype) tuples, one per output tensor.
133 | """
134 | specs = []
135 | for i, o in enumerate(self.outputs):
136 | specs.append((o["name"], o['shape'], o['dtype']))
137 | if self.debug:
138 | print(f"trt output {i} -> {o['name']} -> {o['shape']}")
139 | return specs
140 |
141 | def adjust_buffer(self, feed_dict):
142 | nvtx.range_push("adjust_buffer")
143 | for name, buf in feed_dict.items():
144 | input_tensor = self.tensors[name]
145 | current_shape = list(buf.shape)
146 | if len(current_shape) == 0:
147 | current_shape = (1,)
148 | tensor_len = np.prod(current_shape)
149 | input_tensor[:tensor_len].copy_(buf.reshape(-1))
150 | self.context.set_input_shape(name, current_shape)
151 | nvtx.range_pop()
152 |
153 | def predict(self, feed_dict, stream):
154 | """
155 |         Execute inference for one forward pass of the engine.
156 |         :param feed_dict: Mapping from input tensor names to torch tensors; `stream` is the CUDA stream handle used for async execution.
157 |         :return: The internal name-to-tensor dict holding the pre-allocated (flattened) output buffers.
158 | """
159 | nvtx.range_push("set_tensors")
160 | self.adjust_buffer(feed_dict)
161 | for name, tensor in self.tensors.items():
162 | self.context.set_tensor_address(name, tensor.data_ptr())
163 | nvtx.range_pop()
164 | nvtx.range_push("execute")
165 | noerror = self.context.execute_async_v3(stream)
166 | if not noerror:
167 | raise ValueError("ERROR: inference failed.")
168 | nvtx.range_pop()
169 | return self.tensors
170 |
171 | def __del__(self):
172 | if self.engine is not None:
173 | del self.engine
174 | self.engine = None
175 | if self.context is not None:
176 | del self.context
177 | self.context = None
178 | del self.inputs
179 | del self.outputs
180 | del self.tensors
181 |
182 |
183 | def get_predictor(**kwargs):
184 | predict_type = kwargs.get("predict_type", "trt")
185 | if predict_type == "trt":
186 | return TensorRTPredictor(**kwargs)
187 | else:
188 | raise NotImplementedError
189 |
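190 | # Usage sketch (illustrative; the tensor names must match the I/O names baked into the engine,
191 | # here assumed to be the ones used by LlamaTRTModel):
192 | #   predictor = get_predictor(predict_type="trt", model_path="checkpoints/chattts_llama.trt")
193 | #   stream = torch.cuda.current_stream().cuda_stream
194 | #   outs = predictor.predict({"inputs_embeds": x, "attention_mask": m, "position_ids": p,
195 | #                             "past_key_values": kv}, stream)  # name -> flattened, pre-allocated buffers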
--------------------------------------------------------------------------------
/configs/accelerate/deepspeed_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | gradient_accumulation_steps: 2
5 | gradient_clipping: 1.0
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: false
9 | zero_stage: 2
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | enable_cpu_affinity: false
13 | machine_rank: 0
14 | main_training_function: main
15 | mixed_precision: bf16
16 | num_machines: 1
17 | num_processes: 1
18 | rdzv_backend: static
19 | same_network: true
20 | tpu_env: []
21 | tpu_use_cluster: false
22 | tpu_use_sudo: false
23 | use_cpu: false
--------------------------------------------------------------------------------
/configs/infer/chattts_plus.yaml:
--------------------------------------------------------------------------------
1 | MODELS:
2 | tokenizer:
3 | name: "Tokenizer"
4 | infer_type: "pytorch"
5 | kwargs:
6 | model_path: "checkpoints/asset/tokenizer.pt"
7 | dvae_encode:
8 | name: "DVAE"
9 | infer_type: "pytorch"
10 | kwargs:
11 | model_path: "checkpoints/asset/DVAE_full.pt"
12 | dim: 512
13 | decoder_config:
14 | idim: 512
15 | odim: 512
16 | hidden: 256
17 | n_layer: 12
18 | bn_dim: 128
19 | encoder_config:
20 | idim: 512
21 | odim: 1024
22 | hidden: 256
23 | n_layer: 12
24 | bn_dim: 128
25 | vq_config:
26 | dim: 1024
27 | levels:
28 | - 5
29 | - 5
30 | - 5
31 | - 5
32 | G: 2
33 | R: 2
34 | dvae_decode:
35 | name: "DVAE"
36 | infer_type: "pytorch"
37 | kwargs:
38 | model_path: "checkpoints/asset/Decoder.pt"
39 | dim: 384
40 | decoder_config:
41 | idim: 384
42 | odim: 384
43 | hidden: 512
44 | n_layer: 12
45 | bn_dim: 128
46 | vocos:
47 | name: "Vocos"
48 | infer_type: "pytorch"
49 | kwargs:
50 | model_path: "checkpoints/asset/Vocos.pt"
51 | feature_extractor_config:
52 | sample_rate: 24000
53 | n_fft: 1024
54 | hop_length: 256
55 | n_mels: 100
56 | padding: "center"
57 | backbone_config:
58 | input_channels: 100
59 | dim: 512
60 | intermediate_dim: 1536
61 | num_layers: 8
62 | head_config:
63 | dim: 512
64 | n_fft: 1024
65 | hop_length: 256
66 | padding: "center"
67 | gpt:
68 | name: "GPT"
69 | infer_type: "pytorch"
70 | kwargs:
71 | model_path: "checkpoints/asset/GPT.pt"
72 | gpt_config:
73 | hidden_size: 768
74 | intermediate_size: 3072
75 | num_attention_heads: 12
76 | num_hidden_layers: 20
77 | use_cache: False
78 | max_position_embeddings: 4096
79 | spk_emb_dim: 192
80 | spk_KL: False
81 | num_audio_tokens: 626
82 | num_vq: 4
83 |
84 |
85 |
--------------------------------------------------------------------------------
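A minimal sketch of how this config is consumed, mirroring tests/test_pipelines.py; the input text and the speaker embedding path are placeholders:

    import torch
    from omegaconf import OmegaConf
    from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
    from chattts_plus.commons.utils import InferCodeParams, RefineTextParams

    infer_cfg = OmegaConf.load("configs/infer/chattts_plus.yaml")
    pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda"))
    params_infer_code = InferCodeParams(prompt="[speed_3]", temperature=.0003,
                                        max_new_token=2048, top_P=0.7, top_K=20)
    params_refine_text = RefineTextParams(prompt='[oral_2][laugh_0][break_4]',
                                          top_P=0.7, top_K=20, temperature=0.3, max_new_token=384)
    wavs = []
    for wavs_ in pipeline.infer(["hello world"],
                                params_refine_text=params_refine_text,
                                params_infer_code=params_infer_code,
                                use_decoder=True, stream=False, skip_refine_text=False,
                                do_text_normalization=True, do_homophone_replacement=True,
                                do_text_optimization=True,
                                speaker_emb_path="assets/speakers/2222.pt"):
        wavs.extend(wavs_)

See tests/test_pipelines.py for the full set of variants (LoRA, zero-shot, TensorRT).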
/configs/infer/chattts_plus_trt.yaml:
--------------------------------------------------------------------------------
1 | MODELS:
2 | tokenizer:
3 | name: "Tokenizer"
4 | infer_type: "pytorch"
5 | kwargs:
6 | model_path: "checkpoints/asset/tokenizer.pt"
7 | dvae_encode:
8 | name: "DVAE"
9 | infer_type: "pytorch"
10 | kwargs:
11 | model_path: "checkpoints/asset/DVAE_full.pt"
12 | dim: 512
13 | decoder_config:
14 | idim: 512
15 | odim: 512
16 | hidden: 256
17 | n_layer: 12
18 | bn_dim: 128
19 | encoder_config:
20 | idim: 512
21 | odim: 1024
22 | hidden: 256
23 | n_layer: 12
24 | bn_dim: 128
25 | vq_config:
26 | dim: 1024
27 | levels:
28 | - 5
29 | - 5
30 | - 5
31 | - 5
32 | G: 2
33 | R: 2
34 | dvae_decode:
35 | name: "DVAE"
36 | infer_type: "pytorch"
37 | kwargs:
38 | model_path: "checkpoints/asset/Decoder.pt"
39 | dim: 384
40 | decoder_config:
41 | idim: 384
42 | odim: 384
43 | hidden: 512
44 | n_layer: 12
45 | bn_dim: 128
46 | vocos:
47 | name: "Vocos"
48 | infer_type: "pytorch"
49 | kwargs:
50 | model_path: "checkpoints/asset/Vocos.pt"
51 | feature_extractor_config:
52 | sample_rate: 24000
53 | n_fft: 1024
54 | hop_length: 256
55 | n_mels: 100
56 | padding: "center"
57 | backbone_config:
58 | input_channels: 100
59 | dim: 512
60 | intermediate_dim: 1536
61 | num_layers: 8
62 | head_config:
63 | dim: 512
64 | n_fft: 1024
65 | hop_length: 256
66 | padding: "center"
67 | gpt:
68 | name: "GPT"
69 | infer_type: "trt"
70 | kwargs:
71 | model_path: "checkpoints/asset/GPT.pt"
72 | trt_model_path: "checkpoints/chattts_llama.trt"
73 | output_max_shapes:
74 | hidden_states:
75 | - 4
76 | - 2048
77 | - 768
78 | cur_key_values:
79 | - 20
80 | - 2
81 | - 4
82 | - 12
83 | - 2048
84 | - 64
85 | gpt_config:
86 | hidden_size: 768
87 | intermediate_size: 3072
88 | num_attention_heads: 12
89 | num_hidden_layers: 20
90 | use_cache: False
91 | max_position_embeddings: 4096
92 | spk_emb_dim: 192
93 | spk_KL: False
94 | num_audio_tokens: 626
95 | num_vq: 4
96 |
97 |
98 |
--------------------------------------------------------------------------------
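Compared with configs/infer/chattts_plus.yaml, the only change is the gpt entry: infer_type switches to "trt", trt_model_path points at the serialized TensorRT engine, and output_max_shapes bounds the preallocated output buffers used by the TensorRT predictor in chattts_plus/trt_models/predictor.py.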
/configs/train/train_speaker_embedding.yaml:
--------------------------------------------------------------------------------
1 | MODELS:
2 | tokenizer:
3 | name: "Tokenizer"
4 | infer_type: "pytorch"
5 | kwargs:
6 | model_path: "checkpoints/asset/tokenizer.pt"
7 | dvae_encode:
8 | name: "DVAE"
9 | infer_type: "pytorch"
10 | kwargs:
11 | model_path: "checkpoints/asset/DVAE_full.pt"
12 | coef: ''
13 | dim: 512
14 | decoder_config:
15 | idim: 512
16 | odim: 512
17 | hidden: 256
18 | n_layer: 12
19 | bn_dim: 128
20 | encoder_config:
21 | idim: 512
22 | odim: 1024
23 | hidden: 256
24 | n_layer: 12
25 | bn_dim: 128
26 | vq_config:
27 | dim: 1024
28 | levels:
29 | - 5
30 | - 5
31 | - 5
32 | - 5
33 | G: 2
34 | R: 2
35 | dvae_decode:
36 | name: "DVAE"
37 | infer_type: "pytorch"
38 | kwargs:
39 | model_path: "checkpoints/asset/Decoder.pt"
40 | coef: ''
41 | dim: 384
42 | decoder_config:
43 | idim: 384
44 | odim: 384
45 | hidden: 512
46 | n_layer: 12
47 | bn_dim: 128
48 | gpt:
49 | name: "GPT"
50 | infer_type: "pytorch"
51 | kwargs:
52 | model_path: "checkpoints/asset/GPT.pt"
53 | gpt_config:
54 | hidden_size: 768
55 | intermediate_size: 3072
56 | num_attention_heads: 12
57 | num_hidden_layers: 20
58 | use_cache: False
59 | max_position_embeddings: 4096
60 | spk_emb_dim: 192
61 | spk_KL: False
62 | num_audio_tokens: 626
63 | num_vq: 4
64 |
65 | DATA:
66 | train_bs: 4
67 | meta_infos:
68 | - "data/xionger/slicer_opt.list"
69 | sample_rate: 24000
70 | num_vq: 4
71 |
72 | solver:
73 | gradient_accumulation_steps: 1
74 | mixed_precision: 'fp16'
75 | gradient_checkpointing: false
76 | max_train_steps: 40000
77 | max_grad_norm: 1.0
78 | # lr
79 | learning_rate: 1e-2
80 | scale_lr: false
81 | lr_warmup_steps: 10
82 | lr_scheduler: 'constant'
83 |
84 | # optimizer
85 | use_8bit_adam: false
86 |
87 | weight_dtype: 'fp16'
88 | output_dir: './outputs'
89 | exp_name: "xionger_speaker_emb"
90 | speaker_embeds_path: "outputs/xionger_speaker_emb-1732894509.8451874/checkpoints/step-4000/speaker_embeds.pkl"
91 | checkpointing_steps: 100
92 | use_empty_speaker: false
--------------------------------------------------------------------------------
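This config appears to drive the same train_lora.py entry point with use_empty_speaker set to false, in which case the script optimizes per-speaker embeddings with SGD instead of LoRA weights (see the speaker_embeds branch in train_lora.py). A hedged invocation, assuming the same launch pattern as the LoRA config:

    accelerate launch train_lora.py --config configs/train/train_speaker_embedding.yaml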
/configs/train/train_voice_clone_lora.yaml:
--------------------------------------------------------------------------------
1 | MODELS:
2 | tokenizer:
3 | name: "Tokenizer"
4 | infer_type: "pytorch"
5 | kwargs:
6 | model_path: "checkpoints/asset/tokenizer.pt"
7 | dvae_encode:
8 | name: "DVAE"
9 | infer_type: "pytorch"
10 | kwargs:
11 | model_path: "checkpoints/asset/DVAE_full.pt"
12 | coef: ''
13 | dim: 512
14 | decoder_config:
15 | idim: 512
16 | odim: 512
17 | hidden: 256
18 | n_layer: 12
19 | bn_dim: 128
20 | encoder_config:
21 | idim: 512
22 | odim: 1024
23 | hidden: 256
24 | n_layer: 12
25 | bn_dim: 128
26 | vq_config:
27 | dim: 1024
28 | levels:
29 | - 5
30 | - 5
31 | - 5
32 | - 5
33 | G: 2
34 | R: 2
35 | dvae_decode:
36 | name: "DVAE"
37 | infer_type: "pytorch"
38 | kwargs:
39 | model_path: "checkpoints/asset/Decoder.pt"
40 | coef: ''
41 | dim: 384
42 | decoder_config:
43 | idim: 384
44 | odim: 384
45 | hidden: 512
46 | n_layer: 12
47 | bn_dim: 128
48 | gpt:
49 | name: "GPT"
50 | infer_type: "pytorch"
51 | kwargs:
52 | model_path: "checkpoints/asset/GPT.pt"
53 | gpt_config:
54 | hidden_size: 768
55 | intermediate_size: 3072
56 | num_attention_heads: 12
57 | num_hidden_layers: 20
58 | use_cache: False
59 | max_position_embeddings: 4096
60 | spk_emb_dim: 192
61 | spk_KL: False
62 | num_audio_tokens: 626
63 | num_vq: 4
64 |
65 | DATA:
66 | train_bs: 4
67 | meta_infos:
68 | - "data/leijun/asr_opt/denoise_opt_new.list"
69 | sample_rate: 24000
70 | num_vq: 4
71 |
72 | LORA:
73 | lora_r: 8
74 | lora_alpha: 16
75 | lora_dropout: 0.01
76 | lora_target_modules:
77 | - "q_proj"
78 | - "v_proj"
79 | - "k_proj"
80 | - "o_proj"
81 | # - "gate_proj"
82 | # - "up_proj"
83 | # - "down_proj"
84 |
85 | solver:
86 | gradient_accumulation_steps: 1
87 | mixed_precision: 'fp16'
88 | gradient_checkpointing: false
89 | max_train_steps: 4000
90 | max_grad_norm: 1.0
91 | # lr
92 | learning_rate: 5e-5
93 | min_learning_rate: 1e-5
94 | scale_lr: false
95 | lr_warmup_steps: 10
96 | lr_scheduler: 'constant'
97 |
98 | # optimizer
99 | use_8bit_adam: false
100 | adam_beta1: 0.9
101 | adam_beta2: 0.95
102 | adam_weight_decay: 1e-3
103 |
104 | weight_dtype: 'fp16'
105 | output_dir: './outputs'
106 | exp_name: "leijun_lora"
107 | lora_model_path: ""
108 | checkpointing_steps: 100
109 | use_empty_speaker: true
110 |
--------------------------------------------------------------------------------
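This is the config referenced in train_lora.py's docstring; the launch command given there is:

    accelerate launch train_lora.py --config configs/train/train_voice_clone_lora.yaml

After training, a checkpoint directory written under output_dir (for example outputs/leijun_lora-<timestamp>/checkpoints/step-2000) can be passed as lora_path to ChatTTSPlusPipeline.infer, as shown in tests/test_pipelines.py.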
/demos/notebooklm-podcast/extract_files_to_texts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/11/3
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: extract_files_to_texts.py
7 |
8 | import datetime
9 | import os
10 | import pdb
11 |
12 | import fitz
13 | from tqdm import tqdm
14 | import os
15 | import cv2
16 | from paddleocr import PPStructure, save_structure_res
17 | from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
18 | from copy import deepcopy
19 |
20 |
21 | def pdf2images(pdf_file):
22 | image_dir = os.path.splitext(pdf_file)[0] + "_images"
23 | os.makedirs(image_dir, exist_ok=True)
24 |
25 | pdfDoc = fitz.open(pdf_file)
26 | totalPage = pdfDoc.page_count
27 | for pg in tqdm(range(totalPage)):
28 | page = pdfDoc[pg]
29 | rotate = int(0)
30 | zoom_x = 2
31 | zoom_y = 2
32 | mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
33 | pix = page.get_pixmap(matrix=mat, alpha=False)
34 | pix.save(os.path.join(image_dir, f"page_{pg + 1:03d}.png"))
35 | print(f"save images of pdf at: {image_dir}")
36 | return image_dir
37 |
38 |
39 | def extract_pdf_to_txt(pdf_file, save_txt_file=None, lang='ch'):
40 | # save_folder = os.path.splitext(pdf_file)[0] + "_txts"
41 | # os.makedirs(save_folder, exist_ok=True)
42 | table_engine = PPStructure(recovery=True, lang=lang, show_log=False)
43 | if save_txt_file is None:
44 | save_txt_file = os.path.splitext(pdf_file)[0] + ".txt"
45 | pdf_image_dir = pdf2images(pdf_file)
46 | text = []
47 | imgs = sorted(os.listdir(pdf_image_dir))
48 | for img_name in tqdm(imgs, total=len(imgs)):
49 | img = cv2.imread(os.path.join(pdf_image_dir, img_name))
50 | result = table_engine(img)
51 | # save_structure_res(result, save_folder, os.path.splitext(img_name)[0])
52 | h, w, _ = img.shape
53 | res = sorted_layout_boxes(result, w)
54 | # convert_info_docx(img, res, save_folder, os.path.splitext(img_name)[0])
55 | for line in res:
56 | line.pop('img')
57 | for pra in line['res']:
58 | if isinstance(pra, str):
59 | text.append(pra)
60 | else:
61 | text.append(pra['text'])
62 | text.append('\n')
63 | with open(save_txt_file, 'w', encoding='utf-8') as f:
64 | f.write('\n'.join(text))
65 | print(f"save txt of pdf at: {save_txt_file}")
66 | return save_txt_file
67 |
68 |
69 | def extract_image_to_txt(image_file, save_txt_file=None, lang='ch'):
70 |     save_folder = os.path.splitext(image_file)[0] + "_txts"
71 | os.makedirs(save_folder, exist_ok=True)
72 | table_engine = PPStructure(recovery=True, lang=lang, show_log=False)
73 | if save_txt_file is None:
74 |         save_txt_file = os.path.splitext(image_file)[0] + ".txt"
75 | if os.path.isdir(image_file):
76 | imgs = [os.path.join(image_file, img_) for img_ in os.listdir(image_file) if
77 |                     img_.split('.')[-1].lower() in ["jpg", "png", "jpeg"]]
78 | imgs = sorted(imgs)
79 | else:
80 | imgs = [image_file]
81 | text = []
82 | for img_path in tqdm(imgs, total=len(imgs)):
83 | img = cv2.imread(img_path)
84 | result = table_engine(img)
85 | # save_structure_res(result, save_folder, os.path.splitext(img_name)[0])
86 | h, w, _ = img.shape
87 | res = sorted_layout_boxes(result, w)
88 | # convert_info_docx(img, res, save_folder, os.path.splitext(img_name)[0])
89 | for line in res:
90 | line.pop('img')
91 | for pra in line['res']:
92 | text.append(pra['text'])
93 | text.append('\n')
94 | with open(save_txt_file, 'w', encoding='utf-8') as f:
95 | f.write('\n'.join(text))
96 | print(f"save txt of pdf at: {save_txt_file}")
97 | return save_txt_file
98 |
99 |
100 | if __name__ == '__main__':
101 | pdf_file = "../../data/pdfs/AnimateAnyone.pdf"
102 | # pdf2images(pdf_file)
103 | extract_pdf_to_txt(pdf_file, lang='en')
104 |
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/llm_api.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/11/3
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: llm_api.py
7 |
8 | import os
9 | import json
10 | import copy
11 | import base64
12 | import pdb
13 |
14 | from openai import OpenAI
15 |
16 |
17 | def encode_image(image_path):
18 | with open(image_path, "rb") as image_file:
19 | return base64.b64encode(image_file.read()).decode('utf-8')
20 |
21 |
22 | def init_script_writer_prompt(speaker1_gender="woman", speaker2_gender="man"):
23 | operation_history = []
24 | sysetm_prompt = f"""
25 | You are a renowned podcast scriptwriter, having worked as a ghostwriter for personalities like Joe Rogan, Lex Fridman, Ben Shapiro, and Tim Ferris. Your task is to create a dialogue based on the text provided by the user,
26 | incorporating interjections such as "umm," "hmmm," and "right" from the second speaker. The conversation should be highly engaging; occasional digressions are acceptable,
27 | but the discussion should primarily revolve around the main topic.
28 |
29 | The dialogue involves two speakers:
30 |
31 | Speaker 1: This person, a {speaker1_gender}, leads the conversation, teaching Speaker 2 and providing vivid anecdotes and analogies during explanations. They are a charismatic teacher known for their compelling storytelling.
32 |
33 | Speaker 2: This person, a {speaker2_gender}, keeps the conversation flowing with follow-up questions, exhibiting high levels of excitement or confusion. Their inquisitive nature leads them to ask interesting confirmation questions.
34 |
35 | Ensure that any tangents introduced by Speaker 2 are intriguing or entertaining. Interruptions during explanations and interjections like "hmm" and "umm" from Speaker 2 should be present throughout the dialogue.
36 |
37 | Your output should resemble a genuine podcast, with every minute nuance documented in as much detail as possible. Begin with a captivating overview to welcome the listeners, keeping it intriguing and bordering on clickbait.
38 |
39 | Always commence your response with Speaker 1. Do not provide separate episode titles; instead, let Speaker 1 incorporate it into their dialogue. There should be no chapter titles; strictly provide the dialogues.
40 | """
41 | operation_history.append(["system", [{"type": "text", "text": sysetm_prompt}]])
42 | return operation_history
43 |
44 |
45 | def init_script_rewriter_prompt():
46 | operation_history = []
47 | system_prompt = """
48 | As an Oscar-winning international screenwriter, you've collaborated with many award-winning podcasters. Your current task is to rewrite the following podcast transcript for an AI Text-To-Speech Pipeline. A rudimentary AI initially wrote this transcript, and it now requires your expertise to make it engaging.
49 |
50 | Two different voice engines will simulate Speaker 1 and Speaker 2.
51 |
52 | Speaker 1: This character leads the conversation and educates Speaker 2. They are known for providing captivating teachings, enriched with compelling anecdotes and analogies.
53 |
54 | Speaker 2: This character maintains the conversation's flow by asking follow-up questions. They exhibit high levels of excitement or confusion and have a curious mindset, often asking intriguing confirmation questions.
55 |
56 | Speaker 2's tangents should be wild or interesting. Interruptions during explanations and interjections like "hmm" and "umm" from Speaker 2 should be present throughout the dialogue.
57 |
58 | Note: The Text-To-Speech engine for Speaker 1 struggles with "umms" and "hmms," so maintain a straight text for this speaker. For Speaker 2, feel free to use "umm," "hmm," [sigh], and [laughs] to convey expressions.
59 |
60 | The output should resemble a genuine podcast, with every minute detail documented meticulously. Begin with a captivating overview to welcome the listeners, keeping it intriguing and bordering on clickbait.
61 |
62 | Your response should start directly with Speaker 1 and strictly be returned as a list of tuples. No additional text should be included outside of the list.
63 |
64 | Example of response:
65 | [
66 | ("Speaker 1", "Welcome to our podcast, where we explore the latest advancements in AI and technology. I'm your host, and today we're joined by a renowned expert in the field of AI. We're going to dive into the exciting world of Llama 3.2, the latest release from Meta AI."),
67 | ("Speaker 2", "Hi, I'm excited to be here! So, what is Llama 3.2?"),
68 | ("Speaker 1", "Ah, great question! Llama 3.2 is an open-source AI model that allows developers to fine-tune, distill, and deploy AI models anywhere. It's a significant update from the previous version, with improved performance, efficiency, and customization options."),
69 | ("Speaker 2", "That sounds amazing! What are some of the key features of Llama 3.2?")
70 | ]
71 | """
72 | operation_history.append(["system", [{"type": "text", "text": system_prompt}]])
73 | return operation_history
74 |
75 |
76 | def inference_openai_chat(messages, model, api_url, token, max_tokens=2048, temperature=0.4, seed=1234):
77 | client = OpenAI(
78 | base_url=api_url,
79 | api_key=token,
80 | )
81 |
82 | data = {
83 | "model": model,
84 | "messages": [],
85 | "max_tokens": max_tokens,
86 | 'temperature': temperature,
87 | "seed": seed
88 | }
89 |
90 | for role, content in messages:
91 | data["messages"].append({"role": role, "content": content})
92 |
93 | completion = client.chat.completions.create(
94 | **data
95 | )
96 | return completion.choices[0].message.content
97 |
98 |
99 | def add_response(role, prompt, chat_history, image=None):
100 | new_chat_history = copy.deepcopy(chat_history)
101 | if image:
102 | base64_image = encode_image(image)
103 | content = [
104 | {
105 | "type": "text",
106 | "text": prompt
107 | },
108 | {
109 | "type": "image_url",
110 | "image_url": {
111 | "url": f"data:image/jpeg;base64,{base64_image}"
112 | }
113 | },
114 | ]
115 | else:
116 | content = [
117 | {
118 | "type": "text",
119 | "text": prompt
120 | },
121 | ]
122 | new_chat_history.append([role, content])
123 | return new_chat_history
124 |
125 |
126 | if __name__ == '__main__':
127 | import utils
128 | import pickle
129 |
130 | base_url = ""
131 | api_token = ""
132 | gpt_model = "gpt-4o-mini"
133 | pdf_txt_file = "../../data/pdfs/AnimateAnyone.txt"
134 | script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script.pkl"
135 | re_script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script-rewrite.pkl"
136 |
137 | chat_writer = init_script_writer_prompt()
138 | pdf_texts = utils.read_file_to_string(pdf_txt_file)
139 | chat_writer = add_response("user", pdf_texts, chat_writer)
140 | output_writer_texts = []
141 | output_writer = inference_openai_chat(chat_writer, gpt_model, base_url, api_token, max_tokens=8192)
142 | print(output_writer)
143 | chat_writer = add_response("assistant", output_writer, chat_writer)
144 |
145 | with open(script_pkl, 'wb') as file:
146 | pickle.dump(output_writer, file)
147 |
148 | with open(script_pkl, 'rb') as file:
149 | script_texts = pickle.load(file)
150 | chat_rewriter = init_script_rewriter_prompt()
151 | chat_rewriter = add_response("user", script_texts, chat_rewriter)
152 | output_rewriter = inference_openai_chat(chat_rewriter, gpt_model, base_url, api_token, max_tokens=8192)
153 | print(output_rewriter)
154 |
155 | with open(re_script_pkl, 'wb') as file:
156 | pickle.dump(output_rewriter, file)
157 |
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/requirements.txt:
--------------------------------------------------------------------------------
1 | openai
2 | PyMuPDF
3 | paddlepaddle
4 | paddleocr
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/en_man_5200.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_5200.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/en_man_8200.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_8200.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/en_man_9400.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_9400.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/en_man_9500.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_man_9500.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/en_woman_1200.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_woman_1200.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/en_woman_4600.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_woman_4600.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/en_woman_5600.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/en_woman_5600.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/zh_man_1888.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_man_1888.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/zh_man_2155.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_man_2155.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/zh_man_54.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_man_54.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/zh_woman_1528.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_woman_1528.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/zh_woman_492.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_woman_492.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/speaker_pt/zh_woman_621.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/warmshao/ChatTTSPlus/5ada59126fe2e7b61c78f75dd2b7190d44cb5d8f/demos/notebooklm-podcast/speaker_pt/zh_woman_621.pt
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/tts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/11/3
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: tts.py
7 | import pdb
8 | import random
9 | import torch
10 | import math
11 | import numpy as np
12 | import ast
13 | import torchaudio
14 | from omegaconf import OmegaConf
15 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
16 | from chattts_plus.commons import utils as c_utils
17 |
18 |
19 | cfg = "../../configs/infer/chattts_plus_trt.yaml"
20 | infer_cfg = OmegaConf.load(cfg)
21 | pipe = ChatTTSPlusPipeline(infer_cfg, device=c_utils.get_inference_device())
22 |
23 |
24 | def generate_audio(
25 | text,
26 | speaker_emb_path,
27 | **kwargs
28 | ):
29 | if not text:
30 | return None
31 | audio_save_path = kwargs.get("audio_save_path", None)
32 | params_infer_code = c_utils.InferCodeParams(
33 | prompt="[speed_3]",
34 | temperature=.0003,
35 | max_new_token=2048,
36 | top_P=0.7,
37 | top_K=20
38 | )
39 | params_refine_text = c_utils.RefineTextParams(
40 | prompt='[oral_2][laugh_0][break_4]',
41 | top_P=0.7,
42 | top_K=20,
43 | temperature=0.3,
44 | max_new_token=384
45 | )
46 | infer_seed = kwargs.get("infer_seed", 1234)
47 | with c_utils.TorchSeedContext(infer_seed):
48 | pipe_res_gen = pipe.infer(
49 | text,
50 | params_refine_text=params_refine_text,
51 | params_infer_code=params_infer_code,
52 | use_decoder=True,
53 | stream=False,
54 | skip_refine_text=True,
55 | do_text_normalization=True,
56 | do_homophone_replacement=True,
57 | do_text_optimization=False,
58 | speaker_emb_path=speaker_emb_path,
59 | speaker_audio_path=None,
60 | speaker_audio_text=None
61 | )
62 | wavs = []
63 | for wavs_ in pipe_res_gen:
64 | wavs.extend(wavs_)
65 | wavs = torch.cat(wavs).cpu().float().unsqueeze(0)
66 | if audio_save_path:
67 | torchaudio.save(audio_save_path, wavs, 24000)
68 | return wavs
69 |
70 |
71 | if __name__ == '__main__':
72 | import pickle
73 | import os
74 | from tqdm import tqdm
75 |
76 | base_url = "https://api2.aigcbest.top/v1"
77 | api_token = ""
78 | gpt_model = ""
79 | pdf_txt_file = "../../data/pdfs/AnimateAnyone.txt"
80 | script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script.pkl"
81 | re_script_pkl = os.path.splitext(pdf_txt_file)[0] + "-script-rewrite.pkl"
82 | save_audio_path = os.path.splitext(pdf_txt_file)[0] + ".wav"
83 | with open(re_script_pkl, 'rb') as file:
84 | PODCAST_TEXT = pickle.load(file)
85 |
86 | final_audio = None
87 |
88 | i = 1
89 | speaker1_emb_path = "speaker_pt/en_woman_1200.pt"
90 | speaker2_emb_path = "speaker_pt/en_man_5200.pt"
91 | save_dir = os.path.splitext(pdf_txt_file)[0] + "_audios"
92 | os.makedirs(save_dir, exist_ok=True)
93 | wavs = []
94 | for speaker, text in tqdm(ast.literal_eval(PODCAST_TEXT), desc="Generating podcast segments", unit="segment"):
95 | # output_path = os.path.join(save_dir, f"_podcast_segment_{i:03d}.wav")
96 | output_path = None
97 | if speaker == "Speaker 1":
98 | wav_ = generate_audio(text, speaker1_emb_path, audio_save_path=output_path)
99 | else:
100 | wav_ = generate_audio(text, speaker2_emb_path, audio_save_path=output_path)
101 | wavs.append(wav_)
102 | i += 1
103 | wavs = torch.cat(wavs, dim=-1)
104 | torchaudio.save(save_audio_path, wavs, 24000)
105 | print(save_audio_path)
106 |
--------------------------------------------------------------------------------
/demos/notebooklm-podcast/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/11/3
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: utils.py
7 |
8 | def read_file_to_string(filename):
9 | # Try UTF-8 first (most common encoding for text files)
10 | try:
11 | with open(filename, 'r', encoding='utf-8') as file:
12 | content = file.read()
13 | return content
14 | except UnicodeDecodeError:
15 | # If UTF-8 fails, try with other common encodings
16 | encodings = ['latin-1', 'cp1252', 'iso-8859-1']
17 | for encoding in encodings:
18 | try:
19 | with open(filename, 'r', encoding=encoding) as file:
20 | content = file.read()
21 | print(f"Successfully read file using {encoding} encoding.")
22 | return content
23 | except UnicodeDecodeError:
24 | continue
25 |
26 | print(f"Error: Could not decode file '{filename}' with any common encoding.")
27 | return None
28 | except FileNotFoundError:
29 | print(f"Error: File '{filename}' not found.")
30 | return None
31 | except IOError:
32 | print(f"Error: Could not read file '{filename}'.")
33 | return None
34 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy<2.0.0
2 | numba
3 | torch>=2.1.0
4 | torchaudio
5 | tqdm
6 | vector_quantize_pytorch==1.17.8
7 | transformers>=4.41.1
8 | vocos
9 | IPython
10 | gradio
11 | pybase16384
12 | pynini==2.1.5; sys_platform == 'linux'
13 | WeTextProcessing; sys_platform == 'linux'
14 | nemo_text_processing; sys_platform == 'linux'
15 | av
16 | pydub
17 | pandas
18 | zh-normalization
19 | transformers
20 | peft
21 | accelerate
22 | sentencepiece
23 | omegaconf
24 | soundfile
--------------------------------------------------------------------------------
/scripts/conversions/convert_train_list.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/12/18
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: convert_train_list.py
7 | """
8 | python scripts/conversions/convert_train_list.py \
9 | data/leijun/asr_opt/denoise_opt.list \
10 | data/leijun/asr_opt/denoise_opt_new.list \
11 | leijun
12 | """
13 | import argparse
14 | import pdb
15 |
16 |
17 | def convert_list(input_file, output_file, speaker_name):
18 | with open(input_file, 'r', encoding='utf-8') as infile, open(output_file, 'w', encoding='utf-8') as outfile:
19 | for line in infile:
20 |             # Strip the trailing newline
21 | line = line.strip()
22 | if line:
23 |                 # Split the line into its fields
24 | parts = line.split('|')
25 | audio_path = parts[0]
26 | lang = parts[-2]
27 | text = parts[-1]
28 |                 # Build the new line without extra spaces
29 | new_line = f"{speaker_name}|{audio_path}|{lang}|{text}\n"
30 | outfile.write(new_line)
31 | print(output_file)
32 |
33 |
34 | if __name__ == "__main__":
35 | parser = argparse.ArgumentParser(description='Convert list file format.')
36 | parser.add_argument('input_file', type=str, help='Input list file')
37 | parser.add_argument('output_file', type=str, help='Output list file')
38 | parser.add_argument('speaker_name', type=str, help='Speaker name to add to each line')
39 |
40 | args = parser.parse_args()
41 |
42 | convert_list(args.input_file, args.output_file, args.speaker_name)
43 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/8/27 22:47
3 | # @Project : ChatTTSPlus
4 | # @FileName: setup.py
5 |
6 | import os
7 | from setuptools import setup, find_packages
8 |
9 | version = "v1.0.0"
10 |
11 | setup(
12 | name="chattts_plus",
13 | version=os.environ.get("CHATTTS_PLUS_VER", version).lstrip("v"),
14 | description="",
15 | long_description=open("README.md", encoding="utf8").read(),
16 | long_description_content_type="text/markdown",
17 | author="wenshao",
18 | author_email="wenshaoguo1026@gmail.com",
19 | url="https://github.com/warmshao/ChatTTSPlus",
20 | packages=[
21 | 'chattts_plus',
22 | 'chattts_plus.models',
23 | 'chattts_plus.pipelines',
24 | 'chattts_plus.commons',
25 | ],
26 | license="AGPLv3+",
27 | install_requires=[
28 | "numba",
29 | "numpy<2.0.0",
30 | "pybase16384",
31 | "torch>=2.1.0",
32 | "torchaudio",
33 | "tqdm",
34 | "transformers>=4.41.1",
35 | "vector_quantize_pytorch",
36 | "vocos"
37 | ]
38 | )
39 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/9/22 15:53
3 | # @Project : ChatTTSPlus
4 | # @FileName: test_models.py
5 | import os
6 | import pdb
7 | import torch
8 | import torchaudio
9 |
10 |
11 | def test_tokenizer():
12 | from chattts_plus.models.tokenizer import Tokenizer
13 | model_path = "checkpoints/asset/tokenizer.pt"
14 | if torch.cuda.is_available():
15 | device = torch.device("cuda")
16 | weight_type = torch.float16
17 | else:
18 | device = torch.device("cpu")
19 | weight_type = torch.float32
20 | tokenizer_ = Tokenizer(model_path)
21 |
22 | text = "hello world!"
23 | input_ids, attention_mask, text_mask = tokenizer_.encode([text], 4, None, device)
24 |
25 |
26 | def test_dvae_encode():
27 | from chattts_plus.models.dvae import DVAE
28 | if torch.cuda.is_available():
29 | device = torch.device("cuda")
30 | weight_type = torch.float16
31 | else:
32 | device = torch.device("cpu")
33 | weight_type = torch.float32
34 |
35 | audio_file = "data/xionger/slicer_opt/vocal_5.WAV_10.wav_0000251200_0000423680.wav"
36 | audio_wav, audio_sr_ = torchaudio.load(audio_file)
37 | audio_sr = 24000
38 | audio_wav = torchaudio.functional.resample(audio_wav, orig_freq=audio_sr_, new_freq=audio_sr)
39 | audio_wav = torch.mean(audio_wav, 0).to(device, dtype=weight_type)
40 |
41 | decoder_config = dict(
42 | idim=512,
43 | odim=512,
44 | hidden=256,
45 | n_layer=12,
46 | bn_dim=128,
47 | )
48 | encoder_config = dict(
49 | idim=512,
50 | odim=1024,
51 | hidden=256,
52 | n_layer=12,
53 | bn_dim=128,
54 | )
55 | vq_config = dict(
56 | dim=1024,
57 | levels=(5, 5, 5, 5),
58 | G=2,
59 | R=2,
60 | )
61 | model_path = "checkpoints/asset/DVAE_full.pt"
62 | dvae_encoder = DVAE(
63 | decoder_config=decoder_config,
64 | encoder_config=encoder_config,
65 | vq_config=vq_config,
66 | dim=decoder_config["idim"],
67 | model_path=model_path,
68 | coef=None,
69 | )
70 | dvae_encoder = dvae_encoder.eval().to(device, dtype=weight_type)
71 | audio_ids = dvae_encoder(audio_wav, "encode")
72 | pdb.set_trace()
73 |
74 |
75 | def test_dvae_decode():
76 | from chattts_plus.models.dvae import DVAE
77 | if torch.cuda.is_available():
78 | device = torch.device("cuda")
79 | weight_type = torch.float16
80 | else:
81 | device = torch.device("cpu")
82 | weight_type = torch.float32
83 | decoder_config = dict(
84 | idim=384,
85 | odim=384,
86 | hidden=512,
87 | n_layer=12,
88 | bn_dim=128
89 | )
90 | model_path = "checkpoints/asset/Decoder.pt"
91 | dvae_decoder = DVAE(
92 | decoder_config=decoder_config,
93 | dim=decoder_config["idim"],
94 | coef=None,
95 | model_path=model_path
96 | )
97 | dvae_decoder = dvae_decoder.eval().to(device, dtype=weight_type)
98 |
99 | vq_feats = torch.randn(1, 768, 388).to(device, dtype=weight_type)
100 | mel_feats = dvae_decoder(vq_feats)
101 | pdb.set_trace()
102 |
103 |
104 | def test_vocos():
105 | import vocos
106 | import vocos.feature_extractors
107 | import vocos.models
108 | import vocos.heads
109 |
110 | feature_extractor_cfg = dict(
111 | sample_rate=24000,
112 | n_fft=1024,
113 | hop_length=256,
114 | n_mels=100,
115 | padding="center",
116 | )
117 | backbone_cfg = dict(
118 | input_channels=100,
119 | dim=512,
120 | intermediate_dim=1536,
121 | num_layers=8
122 | )
123 | head_cfg = dict(
124 | dim=512,
125 | n_fft=1024,
126 | hop_length=256,
127 | padding="center"
128 | )
129 | feature_extractor = vocos.feature_extractors.MelSpectrogramFeatures(**feature_extractor_cfg)
130 | backbone = vocos.models.VocosBackbone(**backbone_cfg)
131 | head = vocos.heads.ISTFTHead(**head_cfg)
132 |
133 | device = torch.device("cuda")
134 | dtype = torch.float16
135 | if "mps" in str(device):
136 | device = torch.device("cpu")
137 | dtype = torch.float32
138 |
139 | vocos = vocos.Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head).to(
140 | device, dtype=dtype).eval()
141 | vocos_ckpt_path = "checkpoints/asset/Vocos.pt"
142 | vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
143 |
144 | mel_feats = torch.randn(1, 100, 388 * 2).to(device, dtype=dtype)
145 | audio_wavs = vocos.decode(mel_feats).cpu().float()
146 | result_dir = "./results/test_vocos"
147 | os.makedirs(result_dir, exist_ok=True)
148 | torchaudio.save(os.path.join(result_dir, "test.wav"), audio_wavs, sample_rate=24000)
149 |
150 |
151 | def test_gpt():
152 | from chattts_plus.models.gpt import GPT
153 |
154 | if torch.cuda.is_available():
155 | device = torch.device("cuda")
156 | weight_type = torch.float16
157 | else:
158 | device = torch.device("cpu")
159 | weight_type = torch.float32
160 | model_path = "checkpoints/asset/GPT.pt"
161 | gpt_cfg = dict(
162 | hidden_size=768,
163 | intermediate_size=3072,
164 | num_attention_heads=12,
165 | num_hidden_layers=20,
166 | use_cache=False,
167 | max_position_embeddings=4096,
168 | spk_emb_dim=192,
169 | spk_KL=False,
170 | num_audio_tokens=626,
171 | num_vq=4,
172 | )
173 | gpt = GPT(gpt_cfg, model_path=model_path)
174 |
175 | gpt = gpt.eval().to(device, dtype=weight_type)
176 | pdb.set_trace()
177 |
178 |
179 | if __name__ == '__main__':
180 | # test_tokenizer()
181 | test_dvae_encode()
182 | # test_dvae_decode()
183 | # test_vocos()
184 | # test_gpt()
185 |
--------------------------------------------------------------------------------
/tests/test_pipelines.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/9/22 12:48
3 | # @Project : ChatTTSPlus
4 | # @FileName: test_pipelines.py
5 | import os
6 | import pdb
7 |
8 |
9 | def test_chattts_plus_pipeline():
10 | import torch
11 | import time
12 | import torchaudio
13 |
14 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
15 | from chattts_plus.commons.utils import InferCodeParams, RefineTextParams
16 | from omegaconf import OmegaConf
17 |
18 | infer_cfg_path = "configs/infer/chattts_plus.yaml"
19 | infer_cfg = OmegaConf.load(infer_cfg_path)
20 |
21 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda"))
22 |
23 | params_infer_code = InferCodeParams(
24 | prompt="[speed_3]",
25 | temperature=.0003,
26 | max_new_token=2048,
27 | top_P=0.7,
28 | top_K=20
29 | )
30 | params_refine_text = RefineTextParams(
31 | prompt='[oral_2][laugh_3][break_4]',
32 | top_P=0.7,
33 | top_K=20,
34 | temperature=0.7,
35 | max_new_token=384
36 | )
37 | infer_text = ["我们针对对话式任务进行了优化,能够实现自然且富有表现力的合成语音"]
38 | t0 = time.time()
39 | # leijun: outputs/leijun_lora-1732802535.8597126/checkpoints/step-2000
40 | # xionger: outputs/xionger_lora-1732809910.2932503/checkpoints/step-600
41 | lora_path = "outputs/leijun_lora-1734532984.1128285/checkpoints/step-2000"
42 | pipe_res_gen = pipeline.infer(
43 | infer_text,
44 | params_refine_text=params_refine_text,
45 | params_infer_code=params_infer_code,
46 | use_decoder=True,
47 | stream=False,
48 | skip_refine_text=True,
49 | do_text_normalization=True,
50 | do_homophone_replacement=True,
51 | do_text_optimization=True,
52 | lora_path=lora_path,
53 | speaker_emb_path=''
54 | )
55 | wavs = []
56 | for wavs_ in pipe_res_gen:
57 | wavs.extend(wavs_)
58 | print("total infer time:{} sec".format(time.time() - t0))
59 | save_dir = "results/chattts_plus"
60 | os.makedirs(save_dir, exist_ok=True)
61 | audio_save_path = f"{save_dir}/{os.path.basename(lora_path)}-{time.time()}.wav"
62 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000)
63 | print(audio_save_path)
64 |
65 |
66 | def test_chattts_plus_trt_pipeline():
67 | import torch
68 | import time
69 | import torchaudio
70 |
71 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
72 | from chattts_plus.commons.utils import InferCodeParams, RefineTextParams
73 | from omegaconf import OmegaConf
74 |
75 | infer_cfg_path = "configs/infer/chattts_plus_trt.yaml"
76 | infer_cfg = OmegaConf.load(infer_cfg_path)
77 |
78 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda"))
79 |
80 | speaker_emb_path = "assets/speakers/2222.pt"
81 | params_infer_code = InferCodeParams(
82 | prompt="[speed_5]",
83 | temperature=.0003,
84 | max_new_token=2048,
85 | top_P=0.7,
86 | top_K=20
87 | )
88 | params_refine_text = RefineTextParams(
89 | prompt='[oral_2][laugh_0][break_4]',
90 | top_P=0.7,
91 | top_K=20,
92 | temperature=0.3,
93 | max_new_token=384
94 | )
95 | infer_text = [
96 | "一场雨后,天空和地面互换了身份,抬头万里暗淡,足下星河生辉。这句话真是绝了.你觉得呢.哈哈哈哈",
97 | "本邮件内容是根据招商银行客户提供的个人邮箱发送给其本人的电子邮件,如您并非抬头标明的收件人,请您即刻删除本邮件,勿以任何形式使用及传播本邮件内容,谢谢!"
98 | ]
99 | t0 = time.time()
100 | pipe_res_gen = pipeline.infer(
101 | infer_text,
102 | params_refine_text=params_refine_text,
103 | params_infer_code=params_infer_code,
104 | use_decoder=True,
105 | stream=False,
106 | skip_refine_text=False,
107 | do_text_normalization=True,
108 | do_homophone_replacement=True,
109 | do_text_optimization=True,
110 | speaker_emb_path=speaker_emb_path
111 | )
112 | wavs = []
113 | for wavs_ in pipe_res_gen:
114 | wavs.extend(wavs_)
115 | print("total infer time:{} sec".format(time.time() - t0))
116 | save_dir = "results/chattts_plus"
117 | os.makedirs(save_dir, exist_ok=True)
118 | audio_save_path = f"{save_dir}/{os.path.basename(speaker_emb_path)}-{time.time()}.wav"
119 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000)
120 | print(audio_save_path)
121 |
122 |
123 | def test_chattts_plus_zero_shot_pipeline():
124 | import torch
125 | import time
126 | import torchaudio
127 |
128 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
129 | from chattts_plus.commons.utils import InferCodeParams, RefineTextParams
130 | from omegaconf import OmegaConf
131 |
132 | infer_cfg_path = "configs/infer/chattts_plus.yaml"
133 | infer_cfg = OmegaConf.load(infer_cfg_path)
134 |
135 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda"))
136 |
137 | params_infer_code = InferCodeParams(
138 | prompt="[speed_5]",
139 | temperature=.0003,
140 | max_new_token=2048,
141 | top_P=0.7,
142 | top_K=20
143 | )
144 | params_refine_text = RefineTextParams(
145 | prompt='[oral_2][laugh_0][break_4]',
146 | top_P=0.7,
147 | top_K=20,
148 | temperature=0.3,
149 | max_new_token=384
150 | )
151 | infer_text = [
152 | "一场雨后,天空和地面互换了身份,抬头万里暗淡,足下星河生辉。这句话真是绝了.你觉得呢.哈哈哈哈",
153 | "本邮件内容是根据招商银行客户提供的个人邮箱发送给其本人的电子邮件,如您并非抬头标明的收件人,请您即刻删除本邮件,勿以任何形式使用及传播本邮件内容,谢谢!"
154 | ]
155 | speaker_audio_path = "data/xionger/slicer_opt/vocal_1.WAV_10.wav_0000000000_0000152640.wav"
156 | speaker_audio_text = "嘿嘿,最近我看了寄生虫,真的很推荐哦。"
157 | t0 = time.time()
158 | pipe_res_gen = pipeline.infer(
159 | infer_text,
160 | params_refine_text=params_refine_text,
161 | params_infer_code=params_infer_code,
162 | use_decoder=True,
163 | stream=False,
164 | skip_refine_text=False,
165 | do_text_normalization=True,
166 | do_homophone_replacement=True,
167 | do_text_optimization=True,
168 | speaker_emb_path=None,
169 | speaker_audio_path=speaker_audio_path,
170 | speaker_audio_text=speaker_audio_text
171 | )
172 | wavs = []
173 | for wavs_ in pipe_res_gen:
174 | wavs.extend(wavs_)
175 | print("total infer time:{} sec".format(time.time() - t0))
176 | save_dir = "results/chattts_plus"
177 | os.makedirs(save_dir, exist_ok=True)
178 | audio_save_path = f"{save_dir}/{os.path.basename(speaker_audio_path)}-{time.time()}.wav"
179 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000)
180 | print(audio_save_path)
181 |
182 |
183 | def test_chattts_plus_zero_shot_trt_pipeline():
184 | import torch
185 | import time
186 | import torchaudio
187 |
188 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
189 | from chattts_plus.commons.utils import InferCodeParams, RefineTextParams
190 | from omegaconf import OmegaConf
191 |
192 | infer_cfg_path = "configs/infer/chattts_plus_trt.yaml"
193 | infer_cfg = OmegaConf.load(infer_cfg_path)
194 |
195 | pipeline = ChatTTSPlusPipeline(infer_cfg, device=torch.device("cuda"))
196 |
197 | params_infer_code = InferCodeParams(
198 | prompt="[speed_5]",
199 | temperature=.0003,
200 | max_new_token=2048,
201 | top_P=0.7,
202 | top_K=20
203 | )
204 | params_refine_text = RefineTextParams(
205 | prompt='[oral_2][laugh_0][break_4]',
206 | top_P=0.7,
207 | top_K=20,
208 | temperature=0.3,
209 | max_new_token=384
210 | )
211 | infer_text = [
212 | "请您即刻删除本邮件,勿以任何形式使用及传播本邮件内容,谢谢!"
213 | ]
214 | speaker_audio_path = "data/xionger/slicer_opt/vocal_1.WAV_10.wav_0000000000_0000152640.wav"
215 | speaker_audio_text = "嘿嘿,最近我看了寄生虫,真的很推荐哦。"
216 | t0 = time.time()
217 | pipe_res_gen = pipeline.infer(
218 | infer_text,
219 | params_refine_text=params_refine_text,
220 | params_infer_code=params_infer_code,
221 | use_decoder=True,
222 | stream=False,
223 | skip_refine_text=False,
224 | do_text_normalization=True,
225 | do_homophone_replacement=True,
226 | do_text_optimization=True,
227 | speaker_emb_path=None,
228 | speaker_audio_path=speaker_audio_path,
229 | speaker_audio_text=speaker_audio_text
230 | )
231 | wavs = []
232 | for wavs_ in pipe_res_gen:
233 | wavs.extend(wavs_)
234 | print("total infer time:{} sec".format(time.time() - t0))
235 | save_dir = "results/chattts_plus"
236 | os.makedirs(save_dir, exist_ok=True)
237 | audio_save_path = f"{save_dir}/{os.path.basename(speaker_audio_path)}-{time.time()}.wav"
238 | torchaudio.save(audio_save_path, torch.cat(wavs).cpu().float().unsqueeze(0), 24000)
239 | print(audio_save_path)
240 |
241 |
242 | if __name__ == '__main__':
243 | test_chattts_plus_pipeline()
244 | # test_chattts_plus_trt_pipeline()
245 | # test_chattts_plus_zero_shot_pipeline()
246 | # test_chattts_plus_zero_shot_trt_pipeline()
247 |
--------------------------------------------------------------------------------
/train_lora.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2024/11/23
3 | # @Author : wenshao
4 | # @Email : wenshaoguo1026@gmail.com
5 | # @Project : ChatTTSPlus
6 | # @FileName: train_lora.py
7 | """
8 | accelerate launch train_lora.py --config configs/train/train_voice_clone_lora.yaml
9 | """
10 | import logging
11 | import math
12 | import os.path
13 | import pdb
14 | import numpy as np
15 | from datetime import timedelta
16 | from tqdm import tqdm
17 | import peft
18 | import pickle
19 | from accelerate import InitProcessGroupKwargs
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | import torch.utils.checkpoint
24 | import transformers
25 | from accelerate import Accelerator
26 | from accelerate.utils import DistributedDataParallelKwargs
27 | from omegaconf import OmegaConf
28 | import warnings
29 | from peft import LoraConfig, get_peft_model
30 | import time
31 | import pybase16384 as b14
32 | from huggingface_hub import hf_hub_download
33 | from transformers.trainer_pt_utils import LabelSmoother
34 | from vector_quantize_pytorch.residual_fsq import GroupedResidualFSQ
35 | from einops import rearrange
36 | from peft import PeftConfig, PeftModel
37 | from chattts_plus.commons.logger import get_logger
38 | from chattts_plus.commons import constants
39 | from chattts_plus.models.tokenizer import Tokenizer
40 | from chattts_plus.models.dvae import DVAE
41 | from chattts_plus.models.gpt import GPT
42 | from chattts_plus.datasets.base_dataset import BaseDataset
43 | from chattts_plus.datasets.collator import BaseCollator
44 | from chattts_plus.commons import norm
45 |
46 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index
47 | AUDIO_PAD_TOKEN_ID: int = 0
48 |
49 | warnings.filterwarnings("ignore")
50 |
51 |
52 | def get_mel_attention_mask(
53 | waveform_attention_mask: torch.Tensor, # (batch_size, time)
54 | mel_len: int,
55 | ):
56 | batch_size = waveform_attention_mask.size(0)
57 | mel_attention_mask = torch.ones(
58 | (batch_size, mel_len),
59 | device=waveform_attention_mask.device,
60 | )
61 | indices = waveform_attention_mask.int().sum(dim=1) # (batch_size,)
62 | indices = torch.ceil(indices * mel_len / waveform_attention_mask.size(1)).int()
63 | for i in range(batch_size):
64 | mel_attention_mask[i, indices[i]:] = 0
65 | return mel_attention_mask # (batch_size, mel_len)
66 |
67 |
68 | def main(cfg):
69 | output_dir = os.path.join(cfg.output_dir, f"{cfg.exp_name}-{time.time()}")
70 | os.makedirs(output_dir, exist_ok=True)
71 | checkpoints_dir = os.path.join(output_dir, "checkpoints")
72 | os.makedirs(checkpoints_dir, exist_ok=True)
73 | log_dir = os.path.join(output_dir, "logs")
74 | logger = get_logger("Lora Training", log_file=os.path.join(log_dir, "train.log"))
75 | logger.info(cfg)
76 | kwargs_handlers = [DistributedDataParallelKwargs(find_unused_parameters=False)]
77 | accelerator = Accelerator(
78 | gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
79 | mixed_precision=cfg.solver.mixed_precision,
80 | kwargs_handlers=kwargs_handlers,
81 | )
82 | if accelerator.is_local_main_process:
83 | from torch.utils.tensorboard import SummaryWriter
84 | tf_writer = SummaryWriter(log_dir=os.path.join(output_dir, "tf_logs"))
85 |
86 | # load model
87 | if cfg.weight_dtype == "fp16":
88 | weight_dtype = torch.float16
89 | elif cfg.weight_dtype == "fp32":
90 | weight_dtype = torch.float32
91 | else:
92 | raise ValueError(
93 | f"Do not support weight dtype: {cfg.weight_dtype} during training"
94 | )
95 | logger.info(f"weight_dtype: {str(weight_dtype)}")
96 | logger.info("loading tokenizer >>>")
97 | tokenizer_kwargs = cfg.MODELS["tokenizer"]["kwargs"]
98 | model_path_org = tokenizer_kwargs["model_path"]
99 | model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", ""))
100 | if not os.path.exists(model_path_new):
101 | logger.info(f"{model_path_new} not exists! Need to download from HuggingFace")
102 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset",
103 | filename=os.path.basename(model_path_new),
104 | local_dir=constants.CHECKPOINT_DIR)
105 | logger.info(f"download {model_path_new} from 2Noise/ChatTTS")
106 | tokenizer_kwargs["model_path"] = model_path_new
107 | tokenizer = Tokenizer(**tokenizer_kwargs)
108 |
109 | # load DVAE encoder
110 | logger.info("loading DVAE encode >>>")
111 | dvae_kwargs = cfg.MODELS["dvae_encode"]["kwargs"]
112 | if not dvae_kwargs["coef"]:
113 | coef_ = torch.rand(100)
114 | coef = b14.encode_to_string(coef_.numpy().astype(np.float32).tobytes())
115 | dvae_kwargs["coef"] = coef
116 | logger.info(f"Set DAVE Encode Coef: {dvae_kwargs['coef']}")
117 | else:
118 | coef = dvae_kwargs["coef"]
119 | model_path_org = dvae_kwargs["model_path"]
120 | model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", ""))
121 | if not os.path.exists(model_path_new):
122 | logger.info(f"{model_path_new} not exists! Need to download from HuggingFace")
123 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset",
124 | filename=os.path.basename(model_path_new),
125 | local_dir=constants.CHECKPOINT_DIR)
126 | logger.info(f"download {model_path_new} from 2Noise/ChatTTS")
127 | dvae_kwargs["model_path"] = model_path_new
128 | dvae_encoder = DVAE(**dvae_kwargs)
129 | dvae_encoder.eval().to(accelerator.device, dtype=weight_dtype)
130 | dvae_encoder.requires_grad_(False)
131 |
132 | # load DVAE decoder
133 | # logger.info("loading DVAE decode >>>")
134 | # dvae_kwargs = cfg.MODELS["dvae_decode"]["kwargs"]
135 | # if not dvae_kwargs["coef"]:
136 | # dvae_kwargs["coef"] = coef
137 | # logger.info(f"Set DAVE Decode Coef: {dvae_kwargs['coef']}")
138 | # model_path_org = dvae_kwargs["model_path"]
139 | # model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", ""))
140 | # if not os.path.exists(model_path_new):
141 | # logger.info(f"{model_path_new} not exists! Need to download from HuggingFace")
142 | # hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset",
143 | # filename=os.path.basename(model_path_new),
144 | # local_dir=constants.CHECKPOINT_DIR)
145 | # logger.info(f"download {model_path_new} from 2Noise/ChatTTS")
146 | # dvae_kwargs["model_path"] = model_path_new
147 | # dvae_decoder = DVAE(**dvae_kwargs)
148 | # dvae_decoder.eval().to(accelerator.device, dtype=weight_dtype)
149 | # dvae_decoder.requires_grad_(False)
150 |
151 | # Load GPT
152 | logger.info("loading GPT model >>>")
153 | gpt_kwargs = cfg.MODELS["gpt"]["kwargs"]
154 | model_path_org = gpt_kwargs["model_path"]
155 | model_path_new = os.path.join(constants.CHECKPOINT_DIR, model_path_org.replace("checkpoints/", ""))
156 | if not os.path.exists(model_path_new):
157 | logger.info(f"{model_path_new} not exists! Need to download from HuggingFace")
158 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset",
159 | filename=os.path.basename(model_path_new),
160 | local_dir=constants.CHECKPOINT_DIR)
161 | logger.info(f"download {model_path_new} from 2Noise/ChatTTS")
162 | gpt_kwargs["model_path"] = model_path_new
163 | gpt_model = GPT(**gpt_kwargs)
164 | gpt_model.to(accelerator.device, dtype=weight_dtype)
165 | gpt_model.requires_grad_(False)
166 |
167 | # speaker embedding
168 | spk_stat_path = os.path.join(constants.CHECKPOINT_DIR, "asset/spk_stat.pt")
169 | if not os.path.exists(spk_stat_path):
170 | logger.warning(f"{spk_stat_path} not exists! Need to download from HuggingFace")
171 | hf_hub_download(repo_id="2Noise/ChatTTS", subfolder="asset",
172 | filename=os.path.basename(spk_stat_path),
173 | local_dir=constants.CHECKPOINT_DIR)
174 | logger.info(f"download {spk_stat_path} from 2Noise/ChatTTS")
175 | logger.info(f"loading speaker stat: {spk_stat_path}")
176 | assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}"
177 | spk_stat: torch.Tensor = torch.load(
178 | spk_stat_path,
179 | weights_only=True,
180 | mmap=True
181 | ).to(accelerator.device, dtype=weight_dtype)
182 | speaker_std, speaker_mean = spk_stat.chunk(2)
183 |
184 | # dataset
185 | normalizer_json = os.path.join(constants.CHECKPOINT_DIR, "homophones_map.json")
186 | if not os.path.exists(normalizer_json):
187 | logger.warning(f"{normalizer_json} not exists! Need to download from HuggingFace")
188 | hf_hub_download(repo_id="warmshao/ChatTTSPlus",
189 | filename=os.path.basename(normalizer_json),
190 | local_dir=constants.CHECKPOINT_DIR)
191 | logger.info(f"download {normalizer_json} from warmshao/ChatTTSPlus")
192 | logger.info(f"loading normalizer: {normalizer_json}")
193 | normalizer = norm.Normalizer(normalizer_json)
194 | train_dataset = BaseDataset(
195 | meta_infos=cfg.DATA.meta_infos,
196 | sample_rate=cfg.DATA.sample_rate,
197 | num_vq=cfg.DATA.num_vq,
198 | tokenizer=tokenizer,
199 | normalizer=normalizer,
200 | use_empty_speaker=cfg.use_empty_speaker
201 | )
202 |
203 | # Lora
204 | if cfg.use_empty_speaker:
205 | logger.info("Setting Lora model >>>")
206 | lora_cfg = OmegaConf.to_container(cfg.LORA, resolve=True)
207 | lora_config = LoraConfig(
208 | r=lora_cfg['lora_r'],
209 | lora_alpha=lora_cfg['lora_alpha'],
210 | target_modules=lora_cfg['lora_target_modules'],
211 | lora_dropout=lora_cfg['lora_dropout']
212 | )
213 | peft_model = get_peft_model(gpt_model.gpt, lora_config)
214 | peft_model.config.use_cache = False
215 | if cfg.lora_model_path and os.path.exists(cfg.lora_model_path):
216 | logger.info(f"loading lora weight: {cfg.lora_model_path} >>>")
217 | state_dict = None
218 | if cfg.lora_model_path.endswith(".safetensors"):
219 | from safetensors.torch import load_file
220 | state_dict = load_file("model.safetensors")
221 | elif cfg.lora_model_path.endswith(".pth") or cfg.lora_model_path.endswith(".pt"):
222 | state_dict = torch.load(cfg.lora_model_path)
223 | elif os.path.isdir(cfg.lora_model_path):
224 | state_dict = peft.load_peft_weights(cfg.lora_model_path)
225 | else:
226 | logger.error(f"cannot load {cfg.lora_model_path}")
227 | if state_dict is not None:
228 | peft.set_peft_model_state_dict(peft_model, state_dict)
229 | gpt_model.gpt = peft_model
230 | else:
231 | if cfg.speaker_embeds_path:
232 | with open(cfg.speaker_embeds_path, "rb") as fin:
233 | speaker_embeds = pickle.load(fin)
234 | for speaker in speaker_embeds:
235 | spk_emb = torch.from_numpy(speaker_embeds[speaker]).to(accelerator.device, dtype=torch.float32)
236 | spk_emb = torch.nn.Parameter(spk_emb)
237 | spk_emb.requires_grad_(True)
238 | speaker_embeds[speaker] = spk_emb
239 | else:
240 | speaker_embeds = dict()
241 | for speaker in train_dataset.speakers:
242 | if speaker not in speaker_embeds:
243 | dim: int = speaker_std.shape[-1]
244 | spk_emb = torch.randn(dim, device=speaker_std.device, dtype=torch.float32)
245 | spk_emb = torch.nn.Parameter(spk_emb)
246 | spk_emb.requires_grad_(True)
247 | speaker_embeds[speaker] = spk_emb
248 | train_dataloader = torch.utils.data.DataLoader(
249 | train_dataset, batch_size=cfg.DATA.train_bs, shuffle=True,
250 | num_workers=min(cfg.DATA.train_bs, 4), drop_last=True, collate_fn=BaseCollator()
251 | )
252 |
253 | if cfg.solver.scale_lr:
254 | learning_rate = (
255 | cfg.solver.learning_rate
256 | * cfg.solver.gradient_accumulation_steps
257 |             * cfg.DATA.train_bs
258 | * accelerator.num_processes
259 | )
260 | else:
261 | learning_rate = cfg.solver.learning_rate
262 |
263 | # Initialize the optimizer
264 | if cfg.solver.use_8bit_adam:
265 | try:
266 | import bitsandbytes as bnb
267 | except ImportError:
268 | raise ImportError(
269 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
270 | )
271 |
272 | optimizer_cls = bnb.optim.AdamW8bit
273 | else:
274 | optimizer_cls = torch.optim.AdamW
275 |
276 | trainable_params = list(filter(lambda p: p.requires_grad, gpt_model.gpt.parameters()))
277 |     logger.info(f"Number of trainable parameter tensors: {len(trainable_params)}")
278 | if cfg.use_empty_speaker:
279 | optimizer = optimizer_cls(
280 | trainable_params,
281 | lr=learning_rate,
282 | betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
283 | weight_decay=cfg.solver.adam_weight_decay
284 | )
285 | else:
286 | optimizer = torch.optim.SGD(
287 | list(speaker_embeds.values()),
288 | lr=learning_rate
289 | )
290 |
291 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
292 | num_update_steps_per_epoch = math.ceil(
293 | len(train_dataloader) / cfg.solver.gradient_accumulation_steps
294 | )
295 | # Afterwards we recalculate our number of training epochs
296 | num_train_epochs = math.ceil(
297 | cfg.solver.max_train_steps / num_update_steps_per_epoch
298 | )
299 | # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
300 | # optimizer, num_train_epochs, cfg.solver.min_learning_rate
301 | # )
302 | lr_scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
303 |
304 | # Prepare everything with our `accelerator`.
305 | (
306 | gpt_model,
307 | optimizer,
308 | train_dataloader,
309 | lr_scheduler,
310 | ) = accelerator.prepare(
311 | gpt_model,
312 | optimizer,
313 | train_dataloader,
314 | lr_scheduler,
315 | )
316 |
317 | # Train!
318 | total_batch_size = (
319 | cfg.DATA.train_bs
320 | * accelerator.num_processes
321 | * cfg.solver.gradient_accumulation_steps
322 | )
323 |
324 | logger.info("***** Running training *****")
325 | logger.info(f" Num examples = {len(train_dataset)}")
326 | logger.info(f" Num Epochs = {num_train_epochs}")
327 | logger.info(f" Instantaneous batch size per device = {cfg.DATA.train_bs}")
328 | logger.info(
329 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
330 | )
331 | logger.info(
332 | f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
333 | )
334 | logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
335 | global_step = 0
336 | first_epoch = 0
337 |
338 | # Only show the progress bar once on each machine.
339 | progress_bar = tqdm(
340 | range(global_step, cfg.solver.max_train_steps),
341 | disable=not accelerator.is_local_main_process,
342 | )
343 | progress_bar.set_description("Steps")
344 |
345 | for epoch in range(first_epoch, num_train_epochs):
346 | train_loss = 0.0
347 | for step, batch in enumerate(train_dataloader):
348 | with accelerator.accumulate(gpt_model):
349 | text_input_ids: torch.Tensor = batch[
350 | "text_input_ids"
351 | ] # (batch_size, text_len, num_vq)
352 | text_attention_mask: torch.Tensor = batch[
353 | "text_mask"
354 | ] # (batch_size, text_len)
355 | audio_wavs: torch.Tensor = batch["audio_wavs"] # (batch_size, time)
356 | audio_wavs_mask: torch.Tensor = batch["audio_mask"] # (batch_size, time)
357 |
358 | batch_size = text_input_ids.size(0)
359 | text_input_ids = text_input_ids.to(accelerator.device, non_blocking=True)
360 | text_attention_mask = text_attention_mask.to(accelerator.device, dtype=weight_dtype, non_blocking=True)
361 | audio_wavs = audio_wavs.to(accelerator.device, dtype=weight_dtype, non_blocking=True)
362 | audio_wavs_mask = audio_wavs_mask.to(accelerator.device, dtype=weight_dtype, non_blocking=True)
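    |                 # Encode waveforms into discrete DVAE audio codes (no gradients needed);
    |                 # these codes are the GPT's prediction targets, with padded frames masked out below.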
363 | with torch.no_grad():
364 | # mel_specs = dvae_encoder.preprocessor_mel(audio_wavs)
365 | dvae_audio_input_ids = dvae_encoder(audio_wavs, mode="encode").permute(0, 2, 1).clone()
366 | mel_attention_mask = get_mel_attention_mask(audio_wavs_mask,
367 | mel_len=dvae_audio_input_ids.size(1) * 2)
368 | # if mel_attention_mask.shape[1] > mel_specs.shape[2]:
369 | # mel_attention_mask = mel_attention_mask[:, :mel_specs.shape[2]]
370 | # else:
371 | # mel_specs = mel_specs[:, :, :mel_attention_mask.shape[1]]
372 | # mel_specs = mel_specs * mel_attention_mask.unsqueeze(1)
373 | audio_attention_mask = mel_attention_mask[:, ::2]
374 | dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID
375 |
376 | # add audio eos token
377 | extended_audio_attention_mask = torch.cat(
378 | [
379 | audio_attention_mask,
380 | torch.zeros(
381 | (batch_size, 1),
382 | dtype=audio_attention_mask.dtype,
383 | device=audio_attention_mask.device,
384 | ),
385 | ],
386 | dim=1,
387 | ) # (batch_size, mel_len+1)
388 | extended_audio_input_ids = torch.cat(
389 | [
390 | dvae_audio_input_ids,
391 | AUDIO_PAD_TOKEN_ID
392 | * torch.ones(
393 | (batch_size, 1, gpt_model.num_vq),
394 | dtype=dvae_audio_input_ids.dtype,
395 | device=dvae_audio_input_ids.device,
396 | ),
397 | ],
398 | dim=1,
399 | ) # (batch_size, mel_len+1, num_vq)
400 | indices = audio_attention_mask.int().sum(dim=1) # (batch_size,)
401 | AUDIO_EOS_TOKEN_ID = int(gpt_model.emb_code[0].num_embeddings - 1)
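    |                     # write the EOS code right after each sample's last valid audio frame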
402 | for i in range(batch_size):
403 | extended_audio_attention_mask[i, indices[i]] = 1
404 | extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID
405 |
406 | # combine text and audio
407 | input_ids = torch.cat( # (batch_size, text_len + mel_len + 1, num_vq)
408 | [
409 | text_input_ids,
410 | extended_audio_input_ids, # (batch_size, mel_len, num_vq)
411 | ],
412 | dim=1,
413 | )
414 | attention_mask = torch.cat( # (batch_size, text_len + mel_len + 1)
415 | [text_attention_mask, extended_audio_attention_mask],
416 | dim=1,
417 | )
418 | text_mask = torch.cat( # (batch_size, text_len + mel_len + 1)
419 | [
420 | torch.ones_like(text_attention_mask, dtype=bool),
421 | torch.zeros_like(extended_audio_attention_mask, dtype=bool),
422 | ],
423 | dim=1,
424 | )
425 |
426 | # set labels
427 | labels = input_ids.detach().clone() # (batch_size, text_len + mel_len + 1, num_vq)
428 | labels[~attention_mask.bool()] = IGNORE_TOKEN_ID
429 | # (batch_size, text_len + mel_len, 768)
430 | inputs_embeds = gpt_model.forward(input_ids=input_ids, text_mask=text_mask)
431 | text_len = text_input_ids.size(1)
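    |                 # When training speaker embeddings (not LoRA), replace the [spk_emb] placeholder
    |                 # token's embedding in the text prefix with the de-normalized, L2-normalized speaker vector.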
432 | if not cfg.use_empty_speaker:
433 | for i, speaker in enumerate(batch['speaker']):
434 | spk_emb = speaker_embeds[speaker].mul(speaker_std.detach()).add(speaker_mean.detach())
435 | spk_emb = F.normalize(spk_emb, p=2.0, dim=0, eps=1e-12).unsqueeze_(0)
436 | cond = text_input_ids[i].narrow(-1, 0, 1).eq(tokenizer.spk_emb_ids)
437 | inputs_embeds[i, :text_len] = torch.where(cond, spk_emb.to(inputs_embeds.dtype),
438 | inputs_embeds[i, :text_len])
439 | outputs = gpt_model.gpt.forward(inputs_embeds=inputs_embeds.to(dtype=weight_dtype),
440 | attention_mask=attention_mask.to(dtype=weight_dtype))
441 | hidden_states = outputs.last_hidden_state
442 |
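    |                 # Next-token shift: the hidden state at position t-1 predicts the audio code at
    |                 # position t, so slicing [text_len-1 : -1] lines up with labels[:, text_len:].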
443 | audio_hidden_states = hidden_states[
444 | :, text_len - 1: -1
445 | ] # (batch_size, mel_len+1, 768)
446 | audio_labels = labels[:, text_len:]
447 | audio_logits = torch.stack(
448 | [
449 | gpt_model.head_code[i](audio_hidden_states)
450 | for i in range(gpt_model.num_vq)
451 | ],
452 | dim=2,
453 | ) # (batch_size, mel_len+1, num_vq, num_class_audio)
454 |
455 | audio_loss = torch.nn.functional.cross_entropy(
456 | audio_logits.flatten(0, 2), audio_labels.flatten(0, 2), ignore_index=IGNORE_TOKEN_ID
457 | )
458 | loss = audio_loss
459 |
460 | with torch.no_grad():
461 | # Get predictions
462 | predictions = audio_logits.flatten(0, 2).argmax(dim=-1) # (batch_size * mel_len * num_vq)
463 | labels_flat = audio_labels.flatten(0, 2) # (batch_size * mel_len * num_vq)
464 | # Create mask for valid tokens (not IGNORE_TOKEN_ID)
465 | valid_mask = (labels_flat != IGNORE_TOKEN_ID)
466 |
467 | # Calculate accuracy only on valid tokens
468 | correct = (predictions[valid_mask] == labels_flat[valid_mask]).float()
469 | accuracy = correct.mean() if valid_mask.any() else torch.tensor(0.0).to(correct.device)
470 |
471 | # Gather the accuracy across all processes
472 | avg_accuracy = accelerator.gather(accuracy.repeat(cfg.DATA.train_bs)).mean()
473 |
474 | # Gather the losses across all processes for logging (if we use distributed training).
475 | avg_loss = accelerator.gather(loss.repeat(cfg.DATA.train_bs)).mean()
476 | train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
477 | train_accuracy = avg_accuracy.item()
478 |
479 | # Backpropagate
480 | accelerator.backward(loss)
481 | # if accelerator.sync_gradients:
482 | # accelerator.clip_grad_norm_(
483 | # trainable_params,
484 | # cfg.solver.max_grad_norm,
485 | # )
486 | optimizer.step()
487 | lr_scheduler.step()
488 | optimizer.zero_grad()
489 |
490 | if accelerator.sync_gradients:
491 | progress_bar.update(1)
492 | global_step += 1
493 | accelerator.log({"train_loss": train_loss, "train_acc": train_accuracy}, step=global_step)
494 | if accelerator.is_main_process:
495 | tf_writer.add_scalar('train_loss', train_loss, global_step)
496 | tf_writer.add_scalar('train_acc', train_accuracy, global_step)
497 | tf_writer.add_scalar('train_audio_loss', audio_loss.detach().item(), global_step)
498 | train_loss = 0.0
499 |
500 | if global_step == 1 or global_step % cfg.checkpointing_steps == 0:
501 | if accelerator.is_main_process:
502 | step_checkpoint_dir = os.path.join(checkpoints_dir, f"step-{global_step}")
503 | os.makedirs(step_checkpoint_dir, exist_ok=True)
504 | if not cfg.use_empty_speaker:
505 | for spk_name in speaker_embeds:
506 |                                     spk_emb = speaker_embeds[spk_name].detach().mul(speaker_std).add(speaker_mean)
507 | spk_emb = tokenizer._encode_spk_emb(spk_emb)
508 | output_path = os.path.join(step_checkpoint_dir, f"{spk_name}.pt")
509 | torch.save(spk_emb, output_path)
510 |
511 | speaker_embeds_w = {}
512 | for speaker in speaker_embeds:
513 | speaker_embeds_w[speaker] = speaker_embeds[speaker].detach().float().cpu().data.numpy()
514 | with open(os.path.join(step_checkpoint_dir, "speaker_embeds.pkl"), "wb") as fw:
515 | pickle.dump(speaker_embeds_w, fw)
516 | else:
517 | unwrap_net = accelerator.unwrap_model(gpt_model)
518 | unwrap_net.gpt.save_pretrained(step_checkpoint_dir)
519 |
520 | logs = {
521 | "loss": loss.detach().item(),
522 | "audio_loss": audio_loss.detach().item(),
523 | "step_acc": train_accuracy,
524 | "lr": lr_scheduler.get_last_lr()[0]
525 | }
526 | progress_bar.set_postfix(**logs)
527 |
528 | if global_step >= cfg.solver.max_train_steps:
529 | break
530 |
531 | accelerator.wait_for_everyone()
532 | accelerator.end_training()
533 |
534 |
535 | if __name__ == "__main__":
536 | import argparse
537 |
538 | parser = argparse.ArgumentParser()
539 | parser.add_argument("--config", type=str, default="./configs/train/train_voice_clone_lora.yaml")
540 | args = parser.parse_args()
541 | config = OmegaConf.load(args.config)
542 | main(config)
543 |
--------------------------------------------------------------------------------
/update.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | git fetch origin
3 | git reset --hard origin/master
4 | "venv\Scripts\pip.exe" install -r requirements.txt
5 | pause
--------------------------------------------------------------------------------
/webui.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | ".\venv\python.exe" webui.py --cfg .\configs\infer\chattts_plus_trt.yaml --server_port 7890
3 | pause
--------------------------------------------------------------------------------
/webui.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import pdb
3 | import uuid
4 |
5 | if sys.platform == "darwin":
6 | os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
7 | import random
8 | import argparse
9 | import gradio as gr
10 | from omegaconf import OmegaConf
11 | import numpy as np
12 | import math
13 | import torch
14 | import subprocess
15 | import shutil
16 |
17 | from chattts_plus.pipelines.chattts_plus_pipeline import ChatTTSPlusPipeline
18 | from chattts_plus.commons import utils
19 | from chattts_plus.commons import constants
20 |
21 | # ChatTTSPlus pipeline
22 | pipe: ChatTTSPlusPipeline = None
23 |
24 | js_func = """
25 | function refresh() {
26 | const url = new URL(window.location);
27 |
28 | if (url.searchParams.get('__theme') !== 'dark') {
29 | url.searchParams.set('__theme', 'dark');
30 | window.location.href = url.href;
31 | }
32 | }
33 | """
34 |
35 | seed_min = 1
36 | seed_max = 4294967295
37 |
38 |
39 | def generate_seed():
40 | return gr.update(value=random.randint(seed_min, seed_max))
41 |
42 |
43 | def update_spk_emb_path(file):
44 | spk_emb_path = file.name
45 | return spk_emb_path
46 |
47 |
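   | # Copy uploaded LoRA files into a unique folder under the checkpoint dir; both
   | # adapter_config.json and adapter_model.safetensors are required, otherwise return an empty path.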
48 | def update_spk_lora_path(files):
49 | spk_lora_path = os.path.join(constants.CHECKPOINT_DIR, "lora", str(uuid.uuid4()))
50 | os.makedirs(spk_lora_path, exist_ok=True)
51 | is_valid = 0
52 | for file in files:
53 | if file.name.endswith(".json"):
54 | shutil.copy(file.name, os.path.join(spk_lora_path, "adapter_config.json"))
55 | is_valid += 1
56 | elif file.name.endswith(".safetensors"):
57 | shutil.copy(file.name, os.path.join(spk_lora_path, "adapter_model.safetensors"))
58 | is_valid += 1
59 | if is_valid == 2:
60 | return spk_lora_path
61 | else:
62 | return ''
63 |
64 |
65 | def list_pt_files_in_dir(directory):
66 | if os.path.isdir(directory):
67 | pt_files = [f for f in os.listdir(directory) if f.endswith('.pt')]
68 | return gr.Dropdown(label="Select Speaker Embedding", choices=pt_files) if pt_files else gr.Dropdown(
69 | label="Select Speaker Embedding", choices=[])
70 | return gr.Dropdown(label="Select Speaker Embedding", choices=[])
71 |
72 |
73 | def set_spk_emb_path_from_dir(directory, selected_file):
74 | spk_emb_path = os.path.join(directory, selected_file)
75 | return spk_emb_path
76 |
77 |
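   | # Scale a float waveform to int16 PCM, normalizing by the (ceiled) peak so samples stay in range.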
78 | def float_to_int16(audio: np.ndarray) -> np.ndarray:
79 | am = int(math.ceil(float(np.abs(audio).max())) * 32768)
80 | am = 32767 * 32768 // am
81 | return np.multiply(audio, am).astype(np.int16)
82 |
83 |
84 | def refine_text(
85 | text,
86 | prompt,
87 | temperature,
88 | top_P,
89 | top_K,
90 | text_seed_input,
91 | refine_text_flag,
92 | ):
93 | global pipe
94 |
95 | if not refine_text_flag:
96 | return text
97 |
98 | with utils.TorchSeedContext(text_seed_input):
99 | params_refine_text = utils.RefineTextParams(
100 | prompt=prompt,
101 | top_P=top_P,
102 | top_K=top_K,
103 | temperature=temperature,
104 | max_new_token=384
105 | )
106 | text_gen = pipe.infer(
107 | text,
108 | skip_refine_text=False,
109 | refine_text_only=True,
110 | do_text_normalization=True,
111 | do_homophone_replacement=True,
112 | do_text_optimization=True,
113 | params_refine_text=params_refine_text
114 | )
115 | texts = []
116 | for text_ in text_gen:
117 | texts.extend(text_)
118 |
119 | return "\n".join(texts)
120 |
121 |
122 | def generate_audio(
123 | text,
124 | prompt,
125 | temperature,
126 | top_P,
127 | top_K,
128 | spk_emb_path,
129 | stream,
130 | audio_seed_input,
131 | sample_text_input,
132 | sample_audio_input,
133 | spk_lora_path
134 | ):
135 | global pipe
136 |
137 | if not text:
138 | return None
139 |
140 | params_infer_code = utils.InferCodeParams(
141 | prompt=prompt,
142 | temperature=temperature,
143 | top_P=top_P,
144 | top_K=top_K,
145 | max_new_token=2048,
146 | )
147 |
148 | with utils.TorchSeedContext(audio_seed_input):
149 | wav_gen = pipe.infer(
150 | text,
151 | skip_refine_text=True,
152 | do_text_normalization=False,
153 | do_homophone_replacement=False,
154 | do_text_optimization=False,
155 | params_infer_code=params_infer_code,
156 | stream=stream,
157 | speaker_audio_path=sample_audio_input,
158 | speaker_audio_text=sample_text_input,
159 | speaker_emb_path=spk_emb_path,
160 | lora_path=spk_lora_path
161 | )
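    |         # In streaming mode, yield each audio chunk as it is generated;
    |         # otherwise concatenate all chunks into a single waveform.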
162 | if stream:
163 | for gen in wav_gen:
164 | audio = gen[0].cpu().float().numpy()
165 | if audio is not None and len(audio) > 0:
166 | yield 24000, float_to_int16(audio).T
167 | else:
168 | wavs = []
169 | for wavs_ in wav_gen:
170 | wavs.extend(wavs_)
171 | wavs = torch.cat(wavs).cpu().float().numpy()
172 | yield 24000, float_to_int16(wavs).T
173 |
174 |
175 | def update_active_tab(tab_name):
176 | return gr.State(tab_name)
177 |
178 |
179 | # Reset function that clears the Speaker Embedding Path and the dropdown selection
180 | def reset_spk_emb_path():
181 | return "", []
182 |
183 |
184 | # Reset function that clears Sample Audio and Sample Text
185 | def reset_sample_inputs():
186 |     return None, ""  # None clears the audio input; the empty string clears the textbox
187 |
188 |
189 | def reset_lora_inputs():
190 |     return None, ""  # None clears the uploaded LoRA files; the empty string clears the LoRA path textbox
191 |
192 |
193 | def main(args):
194 | with gr.Blocks(theme=gr.themes.Soft(font=[gr.themes.GoogleFont("Plus Jakarta Sans")]), js=js_func) as demo:
195 | gr.Markdown("# ChatTTSPlus WebUI")
196 | with gr.Row():
197 | with gr.Column():
198 | text_input = gr.Textbox(
199 | label="Input Text",
200 | lines=4,
201 | max_lines=4,
202 | placeholder="Please Input Text...",
203 | interactive=True,
204 | )
205 | activate_tag_name = gr.State(value="Speaker Embedding")
206 |
207 | with gr.Tabs() as tabs:
208 | with gr.Tab("Speaker Embedding"):
209 | with gr.Column():
210 | with gr.Row(equal_height=True):
211 | with gr.Column():
212 | spk_emb_dir = gr.Textbox(label="Input Speaker Embedding Directory",
213 | placeholder="Please input speaker embedding directory",
214 | value=os.path.abspath(
215 | os.path.join(constants.PROJECT_DIR, "assets/speakers")))
216 | reload_chat_button = gr.Button("Reload", scale=1)
217 | pt_files_dropdown = gr.Dropdown(label="Select Speaker Embedding")
218 |
219 | upload_emb_file = gr.File(label="Upload Speaker Embedding File (.pt)")
220 |
221 | spk_emb_path = gr.Textbox(
222 | label="Speaker Embedding Path",
223 | max_lines=3,
224 | show_copy_button=True,
225 | interactive=True,
226 | scale=2,
227 | )
228 | spk_emb_reset = gr.Button("Reset", scale=1)
229 |
230 | upload_emb_file.upload(update_spk_emb_path, inputs=upload_emb_file, outputs=spk_emb_path)
231 | reload_chat_button.click(
232 | list_pt_files_in_dir, inputs=spk_emb_dir, outputs=pt_files_dropdown
233 | )
234 | pt_files_dropdown.select(
235 | set_spk_emb_path_from_dir, inputs=[spk_emb_dir, pt_files_dropdown], outputs=spk_emb_path
236 | )
237 |
238 |                     # Clicking Reset clears the Speaker Embedding Path and the dropdown selection
239 | spk_emb_reset.click(
240 | reset_spk_emb_path, inputs=None, outputs=[spk_emb_path, pt_files_dropdown]
241 | )
242 |
243 | with gr.Tab("Speaker Audio (ZeroShot)"):
244 | with gr.Column():
245 | with gr.Row(equal_height=True):
246 | sample_audio_input = gr.Audio(
247 | value=None,
248 | type="filepath",
249 | interactive=True,
250 | show_label=False,
251 | waveform_options=gr.WaveformOptions(
252 | sample_rate=24000,
253 | ),
254 | )
255 | sample_text_input = gr.Textbox(
256 | label="Sample Text (ZeroShot)",
257 | lines=4,
258 | max_lines=4,
259 | placeholder="If Sample Audio and Sample Text are available, the Speaker Embedding will be disabled.",
260 | interactive=True,
261 | )
262 | sample_reset = gr.Button("Reset", scale=1)
263 |                         # Clicking Reset clears Sample Audio and Sample Text
264 | sample_reset.click(
265 | reset_sample_inputs, inputs=None, outputs=[sample_audio_input, sample_text_input]
266 | )
267 |
268 | with gr.Tab("Speaker Lora"):
269 | with gr.Column():
270 |                         lora_files = gr.Files(label="LoRA files: adapter_config.json and adapter_model.safetensors")
271 | spk_lora_path = gr.Textbox(
272 | label="Speaker Lora Path",
273 | max_lines=3,
274 | show_copy_button=True,
275 | interactive=True,
276 | scale=2,
277 | )
278 | lora_files.upload(update_spk_lora_path, inputs=lora_files, outputs=spk_lora_path)
279 | lora_reset = gr.Button("Reset", scale=1)
280 | lora_reset.click(
281 | reset_lora_inputs, inputs=None, outputs=[lora_files, spk_lora_path]
282 | )
283 |
284 | with gr.Row(equal_height=True):
285 | refine_text_checkbox = gr.Checkbox(
286 | label="Refine text", interactive=True, value=True
287 | )
288 | text_prompt = gr.Text(
289 | interactive=True,
290 | value="[oral_2][laugh_0][break_4]",
291 | label="text_prompt"
292 | )
293 | text_temperature_slider = gr.Number(
294 | minimum=0.00001,
295 | maximum=1.0,
296 | value=0.3,
297 | step=0.05,
298 | label="Text Temperature",
299 | interactive=True,
300 | )
301 | text_top_p_slider = gr.Number(
302 | minimum=0.1,
303 | maximum=0.9,
304 | value=0.7,
305 | step=0.05,
306 | label="text_top_P",
307 | interactive=True,
308 | )
309 | text_top_k_slider = gr.Number(
310 | minimum=1,
311 | maximum=30,
312 | value=20,
313 | step=1,
314 | label="text_top_K",
315 | interactive=True,
316 | )
317 | text_seed_input = gr.Number(
318 | label="Text Seed",
319 | interactive=True,
320 | value=1,
321 | minimum=seed_min,
322 | maximum=seed_max,
323 | )
324 | generate_text_seed = gr.Button("\U0001F3B2", interactive=True)
325 |
326 | with gr.Row(equal_height=True):
327 | audio_prompt = gr.Text(
328 | interactive=True,
329 | value="[speed_4]",
330 | label="audio_prompt"
331 | )
332 | audio_temperature_slider = gr.Number(
333 | minimum=0.00001,
334 | maximum=1.0,
335 | step=0.0001,
336 | value=0.0003,
337 | label="Audio Temperature",
338 | interactive=True,
339 | )
340 | audio_top_p_slider = gr.Number(
341 | minimum=0.1,
342 | maximum=0.9,
343 | value=0.7,
344 | step=0.05,
345 | label="audio_top_P",
346 | interactive=True,
347 | )
348 | audio_top_k_slider = gr.Number(
349 | minimum=1,
350 | maximum=100,
351 | value=20,
352 | step=1,
353 | label="audio_top_K",
354 | interactive=True,
355 | )
356 | audio_seed_input = gr.Number(
357 | label="Audio Seed",
358 | interactive=True,
359 | value=1,
360 | minimum=seed_min,
361 | maximum=seed_max,
362 | )
363 | generate_audio_seed = gr.Button("\U0001F3B2", interactive=True)
364 |
365 | with gr.Row():
366 | auto_play_checkbox = gr.Checkbox(
367 | label="Auto Play", value=False, scale=1, interactive=True
368 | )
369 | stream_mode_checkbox = gr.Checkbox(
370 | label="Stream Mode",
371 | value=False,
372 | scale=1,
373 | interactive=True,
374 | )
375 | generate_button = gr.Button(
376 | "Generate", scale=2, variant="primary", interactive=True
377 | )
378 | interrupt_button = gr.Button(
379 | "Interrupt",
380 | scale=2,
381 | variant="stop",
382 | visible=False,
383 | interactive=False,
384 | )
385 |
386 | text_output = gr.Textbox(
387 | label="Output Text",
388 | interactive=False,
389 | show_copy_button=True,
390 | lines=4,
391 | )
392 |
393 | generate_audio_seed.click(generate_seed, outputs=audio_seed_input)
394 | generate_text_seed.click(generate_seed, outputs=text_seed_input)
395 |
396 | @gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox])
397 | def make_audio(autoplay, stream):
398 | audio_output = gr.Audio(
399 | label="Output Audio",
400 | value=None,
401 | format="wav",
402 | autoplay=autoplay,
403 | streaming=stream,
404 | interactive=False,
405 | show_label=True,
406 | waveform_options=gr.WaveformOptions(
407 | sample_rate=24000,
408 | ),
409 | show_download_button=True
410 | )
411 |
412 | generate_button.click(
413 | fn=refine_text,
414 | inputs=[
415 | text_input,
416 | text_prompt,
417 | text_temperature_slider,
418 | text_top_p_slider,
419 | text_top_k_slider,
420 | text_seed_input,
421 | refine_text_checkbox,
422 | ],
423 | outputs=text_output,
424 | ).then(
425 | generate_audio,
426 | inputs=[
427 | text_output,
428 | audio_prompt,
429 | audio_temperature_slider,
430 | audio_top_p_slider,
431 | audio_top_k_slider,
432 | spk_emb_path,
433 | stream_mode_checkbox,
434 | audio_seed_input,
435 | sample_text_input,
436 | sample_audio_input,
437 | spk_lora_path
438 | ],
439 | outputs=audio_output,
440 | )
441 |
442 | demo.launch(
443 | server_name=args.server_name,
444 | server_port=args.server_port,
445 | inbrowser=True,
446 | show_api=False,
447 | )
448 |
449 |
450 | if __name__ == "__main__":
451 | parser = argparse.ArgumentParser(description="ChatTTS demo Launch")
452 | parser.add_argument(
453 | "--cfg", type=str, default="configs/infer/chattts_plus.yaml", help="config of chattts plus"
454 | )
455 | parser.add_argument(
456 | "--server_name", type=str, default="0.0.0.0", help="server name"
457 | )
458 | parser.add_argument("--server_port", type=int, default=7890, help="server port")
459 | args = parser.parse_args()
460 |
461 | infer_cfg = OmegaConf.load(args.cfg)
462 | pipe = ChatTTSPlusPipeline(infer_cfg, device=utils.get_inference_device())
463 | main(args)
464 |
--------------------------------------------------------------------------------