├── .gitignore
├── README.md
├── WeClone-audio
├── README.md
└── src
│ ├── Llasa
│ ├── infer.py
│ └── text_to_speech.py
│ ├── SparkTTS.py
│ ├── __init__.py
│ ├── get_sample_audio.py
│ ├── infer.py
│ ├── sample.wav
│ └── server未完工
│ ├── .env.example
│ ├── handle_text.py
│ ├── requirements.txt
│ ├── server.py
│ ├── tts_handler.py
│ └── utils.py
├── data
├── example_chat.csv
├── res_csv
│ ├── pt
│ │ └── dataset_info.json
│ └── sft
│ │ ├── dataset_info-with-his.json
│ │ └── dataset_info.json
└── test_data.json
├── ds_config.json
├── img
├── 1.png
├── 2.png
├── 3.png
├── 4.jpg
└── 5.png
├── make_dataset
├── blocked_words.json
├── csv_to_json-单句回答.py
├── csv_to_json-单句多轮.py
└── csv_to_json.py
├── pyproject.toml
├── requirements.txt
├── settings.json
└── src
├── __init__.py
├── api_service.py
├── cli_demo.py
├── evaluate.py
├── export_model.py
├── template.py
├── test_model.py
├── train_pt.py
├── train_sft.py
├── utils
├── __init__.py
├── config.py
└── utils.py
├── web_demo.py
└── wechat_bot
├── handler
└── text.py
└── main.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | weclone_archive-my/
3 | **/pycache/
4 | events.out.tfevents.*
5 | 归档/
6 | *.pt
7 | *.npz
8 | *nohup.out
9 | *log.txt
10 | *cookie.bin
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | pip-wheel-metadata/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
142 |
143 | data/csv.zip
144 | LLaMA-Factory
145 | chatglm3-6b
146 | cache
147 | archive
148 | model_output*
149 | data/test
150 | .vscode
151 | *-my.*
152 | *.csv
153 | test.*
154 | *users.json
155 | WeClone-audio/src/output*.wav
156 | WeClone-audio/uv.lock
157 | Spark-TTS-0.5B/
158 | uv.lock
159 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | ## 核心功能✨
4 | - 💬 使用微信聊天记录微调LLM
5 | - 🎙️ 使用微信语音消息➕0.5B大模型实现高质量声音克隆 👉[WeClone-audio](https://github.com/xming521/WeClone/tree/master/WeClone-audio)
6 | - 🔗 绑定到微信、QQ、Telegram、企微、飞书机器人,实现自己的数字分身
7 |
8 | ## 特性与说明📋
9 |
10 | > [!TIP]
11 | > 新特性:[WeClone-audio](https://github.com/xming521/WeClone/tree/master/WeClone-audio) 模块,支持对微信语音进行克隆。
12 |
13 |
14 | > [!IMPORTANT]
15 | > 微调LLM最终效果很大程度取决于聊天数据的数量和质量
16 |
17 | ### 硬件要求
18 |
19 | 目前项目默认使用chatglm3-6b模型,LoRA方法对sft阶段微调,大约需要16GB显存。也可以使用[LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md#%E6%A8%A1%E5%9E%8B)支持的其他模型和方法,占用显存更少,需要自行修改模板的system提示词等相关配置。
20 |
21 | 需要显存的估算值:
22 | | 方法 | 精度 | 7B | 14B | 30B | 70B | `x`B |
23 | | ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- |
24 | | Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB |
25 | | Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB |
26 | | Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB |
27 | | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB |
28 | | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB |
29 | | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB |
30 |
31 |
32 | ### 环境搭建
33 | 建议使用 [uv](https://docs.astral.sh/uv/),这是一个非常快速的 Python 环境管理器。安装uv后,您可以使用以下命令创建一个新的Python环境并安装依赖项,注意这不包含xcodec(音频克隆)功能的依赖:
34 | ```bash
35 | git clone https://github.com/xming521/WeClone.git
36 | cd WeClone
37 | uv venv .venv --python=3.9
38 | source .venv/bin/activate
39 | uv pip install --group main -e .
40 | ```
41 |
42 | > [!NOTE]
43 | > 训练以及推理相关配置统一在文件[settings.json](settings.json)
44 |
45 |
46 | ### 数据准备
47 |
48 | 请使用[PyWxDump](https://github.com/xaoyaoo/PyWxDump)提取微信聊天记录。下载软件并解密数据库后,点击聊天备份,导出类型为CSV,可以导出多个联系人或群聊,然后将导出的位于`wxdump_tmp/export` 的 `csv` 文件夹放在`./data`目录即可,也就是不同人聊天记录的文件夹一起放在 `./data/csv`。 示例数据位于[data/example_chat.csv](data/example_chat.csv)。
49 |
50 | ### 数据预处理
51 |
52 | 项目默认去除了数据中的手机号、身份证号、邮箱、网址。还提供了一个禁用词词库[blocked_words](make_dataset/blocked_words.json),可以自行添加需要过滤的词句(会默认去掉包括禁用词的整句)。
53 | 执行 `./make_dataset/csv_to_json.py` 脚本对数据进行处理。
54 |
55 | 在同一人连续回答多句的情况下,有三种处理方式:
56 | | 文件 | 处理方式 |
57 | | --- | --- |
58 | | csv_to_json.py | 用逗号连接 |
59 | | csv_to_json-单句回答.py(已废弃) | 只选择最长的回答作为最终数据 |
60 | | csv_to_json-单句多轮.py | 放在了提示词的'history'中 |
61 |
62 | ### 模型下载
63 |
64 | 首选在Hugging Face下载[ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) 模型。如果您在 Hugging Face 模型的下载中遇到了问题,可以通过下述方法使用魔搭社区,后续训练推理都需要先执行`export USE_MODELSCOPE_HUB=1`来使用魔搭社区的模型。
65 | 由于模型较大,下载过程比较漫长请耐心等待。
66 |
67 | ```bash
68 | export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1`
69 | git lfs install
70 | git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git
71 | ```
72 | 魔搭社区的`modeling_chatglm.py`文件需要更换为Hugging Face的
73 |
74 | ### 配置参数并微调模型
75 |
76 | - (可选)修改 [settings.json](settings.json)选择本地下载好的其他模型。
77 |
78 | - 修改`per_device_train_batch_size`以及`gradient_accumulation_steps`来调整显存占用。
79 | - 可以根据自己数据集的数量和质量修改`num_train_epochs`、`lora_rank`、`lora_dropout`等参数。
80 |
81 | #### 单卡训练
82 |
83 | 运行 `src/train_sft.py` 进行sft阶段微调,本人loss只降到了3.5左右,降低过多可能会过拟合,我使用了大概2万条整合后的有效数据。
84 |
85 | ```bash
86 | python src/train_sft.py
87 | ```
88 |
89 | #### 多卡训练
90 |
91 | ```bash
92 | uv pip install deepspeed
93 | deepspeed --num_gpus=使用显卡数量 src/train_sft.py
94 | ```
95 |
96 |
97 | ### 使用浏览器demo简单推理
98 |
99 | ```bash
100 | python ./src/web_demo.py
101 | ```
102 |
103 | ### 使用接口进行推理
104 |
105 | ```bash
106 | python ./src/api_service.py
107 | ```
108 |
109 | ### 使用常见聊天问题测试
110 |
111 | ```bash
112 | python ./src/api_service.py
113 | python ./src/test_model.py
114 | ```
115 |
116 | ### 部署到聊天机器人
117 |
118 | #### AstrBot方案
119 | [AstrBot](https://github.com/AstrBotDevs/AstrBot) 是易上手的多平台 LLM 聊天机器人及开发框架 ✨ 平台支持 QQ、QQ频道、Telegram、微信、企微、飞书。
120 |
121 | 使用步骤:
122 | 1. 部署 AstrBot
123 | 2. 在 AstrBot 中部署消息平台
124 | 3. 执行 `python ./src/api_service.py ` 启动api服务
125 | 4. 在 AstrBot 中新增服务提供商,类型选择OpenAI,API Base URL 根据AstrBot部署方式填写(例如docker部署可能为http://172.17.0.1:8005/v1) ,模型填写gpt-3.5-turbo
126 | 5. 微调后不支持工具调用,请先关掉默认的工具,消息平台发送指令: `/tool off reminder`,否则会没有微调后的效果。
127 | 6. 根据微调时使用的default_system,在 AstrBot 中设置系统提示词。
128 | 
129 |
130 |
131 |
132 |
133 |
134 | itchat方案(已弃用)
135 |
136 | > [!IMPORTANT]
137 | > 微信有封号风险,建议使用小号,并且必须绑定银行卡才能使用
138 |
139 | ```bash
140 | python ./src/api_service.py # 先启动api服务
141 | python ./src/wechat_bot/main.py
142 | ```
143 |
144 | 默认在终端显示二维码,扫码登录即可。可以私聊或者在群聊中@机器人使用。
145 |
146 |
147 | ### 截图
148 |
149 | 
150 | 
151 | 
152 | 
153 |
154 |
155 |
156 | # 免责声明
157 | > [!CAUTION]
158 | > 请勿用于非法用途,否则后果自负。
159 |
160 | 1. 使用目的
161 |
162 | * 本项目仅供学习交流使用,**请勿用于非法用途**,**请勿用于非法用途**,**请勿用于非法用途**,否则后果自负。
163 | * 用户理解并同意,任何违反法律法规、侵犯他人合法权益的行为,均与本项目及其开发者无关,后果由用户自行承担。
164 |
165 | 2. 使用期限
166 |
167 | * 您应该在下载保存使用本项目的24小时内,删除本项目的源代码和程序;超出此期限的任何使用行为,一概与本项目及其开发者无关。
168 |
169 | 3. 操作规范
170 |
171 | * 本项目仅允许在授权情况下使用数据训练,严禁用于非法目的,否则自行承担所有相关责任;用户如因违反此规定而引发的任何法律责任,将由用户自行承担,与本项目及其开发者无关。
172 | * 严禁用于窃取他人隐私,严禁用于窃取他人隐私,严禁用于窃取他人隐私,否则自行承担所有相关责任。
173 |
174 | 4. 免责声明接受
175 |
176 | * 下载、保存、进一步浏览源代码或者下载安装、编译使用本程序,表示你同意本警告,并承诺遵守它;
177 |
178 | 5. 禁止用于非法测试或渗透
179 |
180 | * 禁止利用本项目的相关技术从事非法测试或渗透,禁止利用本项目的相关代码或相关技术从事任何非法工作,如因此产生的一切不良后果与本项目及其开发者无关。
181 | * 任何因此产生的不良后果,包括但不限于数据泄露、系统瘫痪、侵犯隐私等,均与本项目及其开发者无关,责任由用户自行承担。
182 |
183 | 6. 免责声明修改
184 |
185 | * 本免责声明可能根据项目运行情况和法律法规的变化进行修改和调整。用户应定期查阅本页面以获取最新版本的免责声明,使用本项目时应遵守最新版本的免责声明。
186 |
187 | 7. 其他
188 |
189 | * 除本免责声明规定外,用户在使用本项目过程中应遵守相关的法律法规和道德规范。对于因用户违反相关规定而引发的任何纠纷或损失,本项目及其开发者不承担任何责任。
190 |
191 | * 请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。
192 |
193 |
194 | 请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。
195 |
196 |
197 |
198 |
199 |
200 | ## ⭐ Star History
201 | > [!TIP]
202 | > 如果本项目对您有帮助,或者您关注本项目的未来发展,请给项目 Star,谢谢
203 |
204 |
205 |
206 | [](https://www.star-history.com/#xming521/WeClone&Date)
207 |
208 |
209 |
210 |
211 | 克隆我们,保留那灵魂的芬芳
212 |
--------------------------------------------------------------------------------
/WeClone-audio/README.md:
--------------------------------------------------------------------------------
1 | # WeClone-audio 模块
2 |
3 | WeClone-audio 是一个使用微信语音消息克隆声音的模块,使用模型实现高质量语音合成。
4 | ### 显存需求
5 | **Spark-TTS** 推荐
6 | - **0.5B 模型**: 约 4GB 显存
7 |
8 | **Llasa**
9 | - **3B 模型**: 约 16GB 显存
10 | - **1B 模型**: 约 9GB 显存
11 |
12 |
13 |
14 |
15 | ## 1. 导出微信语音数据
16 |
17 | ### 1.1 准备工作
18 | - 使用 [PyWxDump](https://github.com/xaoyaoo/PyWxDump) 提取微信聊天记录
19 | - 下载软件并解密数据库
20 | - 点击聊天备份,导出类型选择"解密文件"
21 |
22 | ### 1.2 环境配置
23 | 语音导出仅支持Windows环境
24 | WeClone Audio使用uv作为包管理器。
25 | ```bash
26 | # 为 PyWxDump 创建 Python 环境和安装依赖
27 | #
28 | uv venv .venv-wx --python=3.9
29 | source .venv-wx/bin/activate
30 | # 安装 wx 依赖组
31 | uv pip install --group wx -e .
32 | ```
33 |
34 | ### 1.3 导出语音文件
35 | ```bash
36 | python ./WeClone-audio/get_sample_audio.py --db-path "导出数据库路径" --MsgSvrID "导出聊天记录的MsgSvrID字段"
37 | ```
38 |
39 | ## 2. 语音合成推理
40 | ### Spark-TTS模型
41 |
42 | **环境安装**
43 | 可不创建新环境,直接安装依赖组到WeClone共主环境
44 |
45 | ```bash
46 | uv venv .venv-sparktts --python=3.9
47 | source .venv-sparktts/bin/activate
48 | uv pip install --group sparktts -e .
49 |
50 | cd WeClone-audio/src
51 | git clone https://github.com/SparkAudio/Spark-TTS.git
52 | ```
53 |
54 | **模型下载**
55 |
56 | 通过python下载:
57 | ```python
58 | from huggingface_hub import snapshot_download
59 |
60 | snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B")
61 | ```
62 |
63 | 或通过git下载:
64 | ```sh
65 | cd WeClone-audio
66 | mkdir -p pretrained_models
67 |
68 | # Make sure you have git-lfs installed (https://git-lfs.com)
69 | git lfs install
70 | git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B
71 | ```
72 | 使用代码推理
73 | ```python
74 | import os
75 | import SparkTTS
76 | import soundfile as sf
77 | import torch
78 |
79 | from SparkTTS import SparkTTS
80 |
81 | model = SparkTTS("WeClone-audio/pretrained_models/Spark-TTS-0.5B", "cuda")
82 |
83 |
84 | with torch.no_grad():
85 | wav = model.inference(
86 | text="晚上好啊,小可爱们,该睡觉了哦",
87 | prompt_speech_path=os.path.join(os.path.dirname(__file__), "sample.wav"),
88 | prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。",
89 | )
90 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), wav, samplerate=16000)
91 | ```
92 | ### Llasa模型
93 | ### 2.1 环境配置
94 | ```bash
95 | # 创建并配置推理环境
96 | ## 可不创建新环境,与LLaMA-Factory环境共用
97 | uv venv .venv-xcodec --python=3.9
98 | source .venv-xcodec/bin/activate
99 | uv pip install --group xcodec -e .
100 | # 退出环境
101 | deactivate
102 |
103 | # 系统依赖安装(如果需要)
104 | sudo apt install python3-dev
105 | sudo apt install build-essential
106 | ```
107 |
108 | ### 2.2 使用代码推理
109 | 如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。
110 | ```python
111 | import os
112 | import soundfile as sf
113 | from text_to_speech import TextToSpeech
114 |
115 |
116 | sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本
117 | sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") # 示例音频路径
118 | tts = TextToSpeech(sample_audio_path, sample_audio_text)
119 | target_text = "晚上好啊" # 生成目标文本
120 | result = tts.infer(target_text)
121 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) # 保存生成音频
122 | ```
123 |
124 |
--------------------------------------------------------------------------------
/WeClone-audio/src/Llasa/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import soundfile as sf
3 | from text_to_speech import TextToSpeech
4 |
5 |
6 | sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本
7 | sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") # 示例音频路径
8 | tts = TextToSpeech(sample_audio_path, sample_audio_text)
9 | target_text = "晚上好啊" # 生成目标文本
10 | result = tts.infer(target_text)
11 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) # 保存生成音频
12 |
13 |
--------------------------------------------------------------------------------
/WeClone-audio/src/Llasa/text_to_speech.py:
--------------------------------------------------------------------------------
1 | import os
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | import torch
4 | import soundfile as sf
5 | from xcodec2.modeling_xcodec2 import XCodec2Model
6 | import torchaudio
7 |
8 |
9 | class TextToSpeech:
10 | def __init__(self, sample_audio_path, sample_audio_text):
11 | self.sample_audio_text = sample_audio_text
12 | # 初始化模型
13 | llasa_3b = "HKUSTAudio/Llasa-3B"
14 | xcodec2 = "HKUSTAudio/xcodec2"
15 |
16 | self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
17 | self.llasa_3b_model = AutoModelForCausalLM.from_pretrained(
18 | llasa_3b,
19 | trust_remote_code=True,
20 | device_map="auto",
21 | )
22 | self.llasa_3b_model.eval()
23 |
24 | self.xcodec_model = XCodec2Model.from_pretrained(xcodec2)
25 | self.xcodec_model.eval().cuda()
26 |
27 | # 处理音频
28 | waveform, sample_rate = torchaudio.load(sample_audio_path)
29 | if len(waveform[0]) / sample_rate > 15:
30 | print("已将音频裁剪至前15秒。")
31 | waveform = waveform[:, : sample_rate * 15]
32 |
33 | # 检查音频是否为立体声
34 | if waveform.size(0) > 1:
35 | waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
36 | else:
37 | waveform_mono = waveform
38 |
39 | self.prompt_wav = torchaudio.transforms.Resample(
40 | orig_freq=sample_rate, new_freq=16000
41 | )(waveform_mono)
42 |
43 | # Encode the prompt wav
44 | vq_code_prompt = self.xcodec_model.encode_code(input_waveform=self.prompt_wav)
45 | vq_code_prompt = vq_code_prompt[0, 0, :]
46 | self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
47 | self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")
48 |
49 | def ids_to_speech_tokens(self, speech_ids):
50 | speech_tokens_str = []
51 | for speech_id in speech_ids:
52 | speech_tokens_str.append(f"<|s_{speech_id}|>")
53 | return speech_tokens_str
54 |
55 | def extract_speech_ids(self, speech_tokens_str):
56 | speech_ids = []
57 | for token_str in speech_tokens_str:
58 | if token_str.startswith("<|s_") and token_str.endswith("|>"):
59 | num_str = token_str[4:-2]
60 | num = int(num_str)
61 | speech_ids.append(num)
62 | else:
63 | print(f"Unexpected token: {token_str}")
64 | return speech_ids
65 |
66 | @torch.inference_mode()
67 | def infer(self, target_text):
68 | if len(target_text) == 0:
69 | return None
70 | elif len(target_text) > 300:
71 | print("文本过长,请保持在300字符以内。")
72 | target_text = target_text[:300]
73 |
74 | input_text = self.sample_audio_text + " " + target_text
75 |
76 | formatted_text = (
77 | f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
78 | )
79 |
80 | chat = [
81 | {
82 | "role": "user",
83 | "content": "Convert the text to speech:" + formatted_text,
84 | },
85 | {
86 | "role": "assistant",
87 | "content": "<|SPEECH_GENERATION_START|>"
88 | + "".join(self.speech_ids_prefix),
89 | },
90 | ]
91 |
92 | input_ids = self.tokenizer.apply_chat_template(
93 | chat, tokenize=True, return_tensors="pt", continue_final_message=True
94 | )
95 | input_ids = input_ids.to("cuda")
96 |
97 | outputs = self.llasa_3b_model.generate(
98 | input_ids,
99 | max_length=2048,
100 | eos_token_id=self.speech_end_id,
101 | do_sample=True,
102 | top_p=1,
103 | temperature=0.8,
104 | )
105 | generated_ids = outputs[0][input_ids.shape[1] - len(self.speech_ids_prefix): -1]
106 |
107 | speech_tokens = self.tokenizer.batch_decode(
108 | generated_ids, skip_special_tokens=True
109 | )
110 |
111 | speech_tokens = self.extract_speech_ids(speech_tokens)
112 | speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)
113 |
114 | gen_wav = self.xcodec_model.decode_code(speech_tokens)
115 | gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:]
116 |
117 | return (16000, gen_wav[0, 0, :].cpu().numpy())
118 |
119 |
120 | if __name__ == "__main__":
121 | # 如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。
122 | sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。"
123 | sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav")
124 |
125 | tts = TextToSpeech(sample_audio_path, sample_audio_text)
126 | target_text = "晚上好啊,吃了吗您"
127 | result = tts.infer(target_text)
128 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0])
129 | target_text = "我是老北京正黄旗!"
130 | result = tts.infer(target_text)
131 | sf.write(os.path.join(os.path.dirname(__file__), "output1.wav"), result[1], result[0])
132 |
--------------------------------------------------------------------------------
/WeClone-audio/src/SparkTTS.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 | from typing import Tuple
4 | from pathlib import Path
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 | import os
7 | import sys
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./Spark-TTS")))
9 | from sparktts.utils.file import load_config
10 | from sparktts.models.audio_tokenizer import BiCodecTokenizer
11 | from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP
12 |
13 |
14 | class SparkTTS:
15 | """
16 | Spark-TTS for text-to-speech generation.
17 | """
18 |
19 | def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
20 | """
21 | Initializes the SparkTTS model with the provided configurations and device.
22 |
23 | Args:
24 | model_dir (Path): Directory containing the model and config files.
25 | device (torch.device): The device (CPU/GPU) to run the model on.
26 | """
27 | self.device = device
28 | self.model_dir = model_dir
29 | self.configs = load_config(f"{model_dir}/config.yaml")
30 | self.sample_rate = self.configs["sample_rate"]
31 | self._initialize_inference()
32 |
33 | def _initialize_inference(self):
34 | """Initializes the tokenizer, model, and audio tokenizer for inference."""
35 | self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
36 | self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
37 | self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
38 | self.model.to(self.device)
39 |
40 | def process_prompt(
41 | self,
42 | text: str,
43 | prompt_speech_path: Path,
44 | prompt_text: str = None,
45 | ) -> Tuple[str, torch.Tensor]:
46 | """
47 | Process input for voice cloning.
48 |
49 | Args:
50 | text (str): The text input to be converted to speech.
51 | prompt_speech_path (Path): Path to the audio file used as a prompt.
52 | prompt_text (str, optional): Transcript of the prompt audio.
53 |
54 | Return:
55 | Tuple[str, torch.Tensor]: Input prompt; global tokens
56 | """
57 |
58 | global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
59 | prompt_speech_path
60 | )
61 | global_tokens = "".join(
62 | [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
63 | )
64 |
65 | # Prepare the input tokens for the model
66 | if prompt_text is not None:
67 | semantic_tokens = "".join(
68 | [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
69 | )
70 | inputs = [
71 | TASK_TOKEN_MAP["tts"],
72 | "<|start_content|>",
73 | prompt_text,
74 | text,
75 | "<|end_content|>",
76 | "<|start_global_token|>",
77 | global_tokens,
78 | "<|end_global_token|>",
79 | "<|start_semantic_token|>",
80 | semantic_tokens,
81 | ]
82 | else:
83 | inputs = [
84 | TASK_TOKEN_MAP["tts"],
85 | "<|start_content|>",
86 | text,
87 | "<|end_content|>",
88 | "<|start_global_token|>",
89 | global_tokens,
90 | "<|end_global_token|>",
91 | ]
92 |
93 | inputs = "".join(inputs)
94 |
95 | return inputs, global_token_ids
96 |
97 | def process_prompt_control(
98 | self,
99 | gender: str,
100 | pitch: str,
101 | speed: str,
102 | text: str,
103 | ):
104 | """
105 | Process input for voice creation.
106 |
107 | Args:
108 | gender (str): female | male.
109 | pitch (str): very_low | low | moderate | high | very_high
110 | speed (str): very_low | low | moderate | high | very_high
111 | text (str): The text input to be converted to speech.
112 |
113 | Return:
114 | str: Input prompt
115 | """
116 | assert gender in GENDER_MAP.keys()
117 | assert pitch in LEVELS_MAP.keys()
118 | assert speed in LEVELS_MAP.keys()
119 |
120 | gender_id = GENDER_MAP[gender]
121 | pitch_level_id = LEVELS_MAP[pitch]
122 | speed_level_id = LEVELS_MAP[speed]
123 |
124 | pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
125 | speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
126 | gender_tokens = f"<|gender_{gender_id}|>"
127 |
128 | attribte_tokens = "".join(
129 | [gender_tokens, pitch_label_tokens, speed_label_tokens]
130 | )
131 |
132 | control_tts_inputs = [
133 | TASK_TOKEN_MAP["controllable_tts"],
134 | "<|start_content|>",
135 | text,
136 | "<|end_content|>",
137 | "<|start_style_label|>",
138 | attribte_tokens,
139 | "<|end_style_label|>",
140 | ]
141 |
142 | return "".join(control_tts_inputs)
143 |
144 | @torch.no_grad()
145 | def inference(
146 | self,
147 | text: str,
148 | prompt_speech_path: Path = None,
149 | prompt_text: str = None,
150 | gender: str = None,
151 | pitch: str = None,
152 | speed: str = None,
153 | temperature: float = 0.8,
154 | top_k: float = 50,
155 | top_p: float = 0.95,
156 | ) -> torch.Tensor:
157 | """
158 | Performs inference to generate speech from text, incorporating prompt audio and/or text.
159 |
160 | Args:
161 | text (str): The text input to be converted to speech.
162 | prompt_speech_path (Path): Path to the audio file used as a prompt.
163 | prompt_text (str, optional): Transcript of the prompt audio.
164 | gender (str): female | male.
165 | pitch (str): very_low | low | moderate | high | very_high
166 | speed (str): very_low | low | moderate | high | very_high
167 | temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
168 | top_k (float, optional): Top-k sampling parameter. Default is 50.
169 | top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.
170 |
171 | Returns:
172 | torch.Tensor: Generated waveform as a tensor.
173 | """
174 | if gender is not None:
175 | prompt = self.process_prompt_control(gender, pitch, speed, text)
176 |
177 | else:
178 | prompt, global_token_ids = self.process_prompt(
179 | text, prompt_speech_path, prompt_text
180 | )
181 | model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)
182 |
183 | # Generate speech using the model
184 | generated_ids = self.model.generate(
185 | **model_inputs,
186 | max_new_tokens=3000,
187 | do_sample=True,
188 | top_k=top_k,
189 | top_p=top_p,
190 | temperature=temperature,
191 | )
192 |
193 | # Trim the output tokens to remove the input tokens
194 | generated_ids = [
195 | output_ids[len(input_ids):]
196 | for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
197 | ]
198 |
199 | # Decode the generated tokens into text
200 | predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
201 |
202 | # Extract semantic token IDs from the generated text
203 | pred_semantic_ids = (
204 | torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
205 | .long()
206 | .unsqueeze(0)
207 | )
208 |
209 | if gender is not None:
210 | global_token_ids = (
211 | torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
212 | .long()
213 | .unsqueeze(0)
214 | .unsqueeze(0)
215 | )
216 |
217 | # Convert semantic tokens back to waveform
218 | wav = self.audio_tokenizer.detokenize(
219 | global_token_ids.to(self.device).squeeze(0),
220 | pred_semantic_ids.to(self.device),
221 | )
222 |
223 | return wav
224 |
--------------------------------------------------------------------------------
/WeClone-audio/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/WeClone-audio/src/__init__.py
--------------------------------------------------------------------------------
/WeClone-audio/src/get_sample_audio.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from pywxdump.db import MediaHandler
4 |
5 | def main():
6 | parser = argparse.ArgumentParser(description='Extract audio from WeChat database')
7 | parser.add_argument('--db-path', type=str, required=True,
8 | help='Path to WeChat database file')
9 | parser.add_argument('--MsgSvrID', type=str, required=True,
10 | help='Message server ID of the audio')
11 | parser.add_argument('--save-path', type=str,
12 | default=os.path.join(os.path.dirname(__file__), 'sample.wav'),
13 | help='Path to save the audio file (default: sample.wav in script directory)')
14 | parser.add_argument('--rate', type=int, default=24000,
15 | help='Sample rate for audio conversion (default: 24000)')
16 |
17 | args = parser.parse_args()
18 |
19 | config = {
20 | "key": "test1",
21 | "type": "sqlite",
22 | "path": args.db_path,
23 | }
24 |
25 | t1 = MediaHandler(config)
26 | t1.get_audio(
27 | MsgSvrID=args.msg_id,
28 | is_play=True,
29 | is_wave=True,
30 | save_path=args.save_path,
31 | rate=args.rate,
32 | )
33 |
34 | if __name__ == "__main__":
35 | main()
36 |
--------------------------------------------------------------------------------
/WeClone-audio/src/infer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import SparkTTS
3 | import soundfile as sf
4 | import torch
5 |
6 | from SparkTTS import SparkTTS
7 |
8 | model = SparkTTS("WeClone-audio/pretrained_models/Spark-TTS-0.5B", "cuda")
9 |
10 |
11 | with torch.no_grad():
12 | wav = model.inference(
13 | text="晚上好啊,小可爱们,该睡觉了哦",
14 | prompt_speech_path=os.path.join(os.path.dirname(__file__), "sample.wav"),
15 | prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。",
16 | )
17 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), wav, samplerate=16000)
18 | print("生成成功!")
19 |
--------------------------------------------------------------------------------
/WeClone-audio/src/sample.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/WeClone-audio/src/sample.wav
--------------------------------------------------------------------------------
/WeClone-audio/src/server未完工/.env.example:
--------------------------------------------------------------------------------
1 | API_KEY=your_api_key_here
2 | PORT=5050
3 |
4 | DEFAULT_VOICE=en-US-AvaNeural
5 | DEFAULT_RESPONSE_FORMAT=mp3
6 | DEFAULT_SPEED=1.0
7 |
8 | DEFAULT_LANGUAGE=en-US
9 |
10 | REQUIRE_API_KEY=True
11 |
12 | REMOVE_FILTER=False
13 |
14 | EXPAND_API=True
--------------------------------------------------------------------------------
/WeClone-audio/src/server未完工/handle_text.py:
--------------------------------------------------------------------------------
1 | import re
2 | import emoji
3 |
4 | def prepare_tts_input_with_context(text: str) -> str:
5 | """
6 | Prepares text for a TTS API by cleaning Markdown and adding minimal contextual hints
7 | for certain Markdown elements like headers. Preserves paragraph separation.
8 |
9 | Args:
10 | text (str): The raw text containing Markdown or other formatting.
11 |
12 | Returns:
13 | str: Cleaned text with contextual hints suitable for TTS input.
14 | """
15 |
16 | # Remove emojis
17 | text = emoji.replace_emoji(text, replace='')
18 |
19 | # Add context for headers
20 | def header_replacer(match):
21 | level = len(match.group(1)) # Number of '#' symbols
22 | header_text = match.group(2).strip()
23 | if level == 1:
24 | return f"Title — {header_text}\n"
25 | elif level == 2:
26 | return f"Section — {header_text}\n"
27 | else:
28 | return f"Subsection — {header_text}\n"
29 |
30 | text = re.sub(r"^(#{1,6})\s+(.*)", header_replacer, text, flags=re.MULTILINE)
31 |
32 | # Announce links (currently commented out for potential future use)
33 | # text = re.sub(r"\[([^\]]+)\]\((https?:\/\/[^\)]+)\)", r"\1 (link: \2)", text)
34 |
35 | # Remove links while keeping the link text
36 | text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)
37 |
38 | # Describe inline code
39 | text = re.sub(r"`([^`]+)`", r"code snippet: \1", text)
40 |
41 | # Remove bold/italic symbols but keep the content
42 | text = re.sub(r"(\*\*|__|\*|_)", '', text)
43 |
44 | # Remove code blocks (multi-line) with a description
45 | text = re.sub(r"```([\s\S]+?)```", r"(code block omitted)", text)
46 |
47 | # Remove image syntax but add alt text if available
48 | text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", r"Image: \1", text)
49 |
50 | # Remove HTML tags
51 | text = re.sub(r"?[^>]+(>|$)", '', text)
52 |
53 | # Normalize line breaks
54 | text = re.sub(r"\n{2,}", '\n\n', text) # Ensure consistent paragraph separation
55 |
56 | # Replace multiple spaces within lines
57 | text = re.sub(r" {2,}", ' ', text)
58 |
59 | # Trim leading and trailing whitespace from the whole text
60 | text = text.strip()
61 |
62 | return text
63 |
--------------------------------------------------------------------------------
/WeClone-audio/src/server未完工/requirements.txt:
--------------------------------------------------------------------------------
1 | flask
2 | gevent
3 | python-dotenv
4 | edge-tts
5 | emoji
--------------------------------------------------------------------------------
/WeClone-audio/src/server未完工/server.py:
--------------------------------------------------------------------------------
1 | # server.py
2 |
3 | from flask import Flask, request, send_file, jsonify
4 | from gevent.pywsgi import WSGIServer
5 | from dotenv import load_dotenv
6 | import os
7 |
8 | from handle_text import prepare_tts_input_with_context
9 | from tts_handler import generate_speech, get_models, get_voices
10 | from utils import getenv_bool, require_api_key, AUDIO_FORMAT_MIME_TYPES
11 |
12 | app = Flask(__name__)
13 | load_dotenv()
14 |
15 | API_KEY = os.getenv('API_KEY', 'your_api_key_here')
16 | PORT = int(os.getenv('PORT', 5050))
17 |
18 | DEFAULT_VOICE = os.getenv('DEFAULT_VOICE', 'en-US-AvaNeural')
19 | DEFAULT_RESPONSE_FORMAT = os.getenv('DEFAULT_RESPONSE_FORMAT', 'mp3')
20 | DEFAULT_SPEED = float(os.getenv('DEFAULT_SPEED', 1.0))
21 |
22 | REMOVE_FILTER = getenv_bool('REMOVE_FILTER', False)
23 | EXPAND_API = getenv_bool('EXPAND_API', True)
24 |
25 | # DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'tts-1')
26 |
27 | @app.route('/v1/audio/speech', methods=['POST'])
28 | @app.route('/audio/speech', methods=['POST']) # Add this line for the alias
29 | @require_api_key
30 | def text_to_speech():
31 | data = request.json
32 | if not data or 'input' not in data:
33 | return jsonify({"error": "Missing 'input' in request body"}), 400
34 |
35 | text = data.get('input')
36 |
37 | if not REMOVE_FILTER:
38 | text = prepare_tts_input_with_context(text)
39 |
40 | # model = data.get('model', DEFAULT_MODEL)
41 | voice = data.get('voice', DEFAULT_VOICE)
42 |
43 | response_format = data.get('response_format', DEFAULT_RESPONSE_FORMAT)
44 | speed = float(data.get('speed', DEFAULT_SPEED))
45 |
46 | mime_type = AUDIO_FORMAT_MIME_TYPES.get(response_format, "audio/mpeg")
47 |
48 | # Generate the audio file in the specified format with speed adjustment
49 | output_file_path = generate_speech(text, voice, response_format, speed)
50 |
51 | # Return the file with the correct MIME type
52 | return send_file(output_file_path, mimetype=mime_type, as_attachment=True, download_name=f"speech.{response_format}")
53 |
54 | @app.route('/v1/models', methods=['GET', 'POST'])
55 | @app.route('/models', methods=['GET', 'POST'])
56 | @require_api_key
57 | def list_models():
58 | return jsonify({"data": get_models()})
59 |
60 | @app.route('/v1/voices', methods=['GET', 'POST'])
61 | @app.route('/voices', methods=['GET', 'POST'])
62 | @require_api_key
63 | def list_voices():
64 | specific_language = None
65 |
66 | data = request.args if request.method == 'GET' else request.json
67 | if data and ('language' in data or 'locale' in data):
68 | specific_language = data.get('language') if 'language' in data else data.get('locale')
69 |
70 | return jsonify({"voices": get_voices(specific_language)})
71 |
72 | @app.route('/v1/voices/all', methods=['GET', 'POST'])
73 | @app.route('/voices/all', methods=['GET', 'POST'])
74 | @require_api_key
75 | def list_all_voices():
76 | return jsonify({"voices": get_voices('all')})
77 |
78 | """
79 | Support for ElevenLabs and Azure AI Speech
80 | (currently in beta)
81 | """
82 |
83 | # http://localhost:5050/elevenlabs/v1/text-to-speech
84 | # http://localhost:5050/elevenlabs/v1/text-to-speech/en-US-AndrewNeural
85 | @app.route('/elevenlabs/v1/text-to-speech/', methods=['POST'])
86 | @require_api_key
87 | def elevenlabs_tts(voice_id):
88 | if not EXPAND_API:
89 | return jsonify({"error": f"Endpoint not allowed"}), 500
90 |
91 | # Parse the incoming JSON payload
92 | try:
93 | payload = request.json
94 | if not payload or 'text' not in payload:
95 | return jsonify({"error": "Missing 'text' in request body"}), 400
96 | except Exception as e:
97 | return jsonify({"error": f"Invalid JSON payload: {str(e)}"}), 400
98 |
99 | text = payload['text']
100 |
101 | if not REMOVE_FILTER:
102 | text = prepare_tts_input_with_context(text)
103 |
104 | voice = voice_id # ElevenLabs uses the voice_id in the URL
105 |
106 | # Use default settings for edge-tts
107 | response_format = 'mp3'
108 | speed = DEFAULT_SPEED # Optional customization via payload.get('speed', DEFAULT_SPEED)
109 |
110 | # Generate speech using edge-tts
111 | try:
112 | output_file_path = generate_speech(text, voice, response_format, speed)
113 | except Exception as e:
114 | return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500
115 |
116 | # Return the generated audio file
117 | return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3")
118 |
119 | # tts.speech.microsoft.com/cognitiveservices/v1
120 | # https://{region}.tts.speech.microsoft.com/cognitiveservices/v1
121 | # http://localhost:5050/azure/cognitiveservices/v1
122 | @app.route('/azure/cognitiveservices/v1', methods=['POST'])
123 | @require_api_key
124 | def azure_tts():
125 | if not EXPAND_API:
126 | return jsonify({"error": f"Endpoint not allowed"}), 500
127 |
128 | # Parse the SSML payload
129 | try:
130 | ssml_data = request.data.decode('utf-8')
131 | if not ssml_data:
132 | return jsonify({"error": "Missing SSML payload"}), 400
133 |
134 | # Extract the text and voice from SSML
135 | from xml.etree import ElementTree as ET
136 | root = ET.fromstring(ssml_data)
137 | text = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').text
138 | voice = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').get('name')
139 | except Exception as e:
140 | return jsonify({"error": f"Invalid SSML payload: {str(e)}"}), 400
141 |
142 | # Use default settings for edge-tts
143 | response_format = 'mp3'
144 | speed = DEFAULT_SPEED
145 |
146 | if not REMOVE_FILTER:
147 | text = prepare_tts_input_with_context(text)
148 |
149 | # Generate speech using edge-tts
150 | try:
151 | output_file_path = generate_speech(text, voice, response_format, speed)
152 | except Exception as e:
153 | return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500
154 |
155 | # Return the generated audio file
156 | return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3")
157 |
158 | print(f" Edge TTS (Free Azure TTS) Replacement for OpenAI's TTS API")
159 | print(f" ")
160 | print(f" * Serving OpenAI Edge TTS")
161 | print(f" * Server running on http://localhost:{PORT}")
162 | print(f" * TTS Endpoint: http://localhost:{PORT}/v1/audio/speech")
163 | print(f" ")
164 |
165 | if __name__ == '__main__':
166 | http_server = WSGIServer(('0.0.0.0', PORT), app)
167 | http_server.serve_forever()
168 |
--------------------------------------------------------------------------------
/WeClone-audio/src/server未完工/tts_handler.py:
--------------------------------------------------------------------------------
1 | import edge_tts
2 | import asyncio
3 | import tempfile
4 | import subprocess
5 | import os
6 | from pathlib import Path
7 |
8 | # Language default (environment variable)
9 | DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'en-US')
10 |
11 | # OpenAI voice names mapped to edge-tts equivalents
12 | voice_mapping = {
13 | 'alloy': 'en-US-AvaNeural',
14 | 'echo': 'en-US-AndrewNeural',
15 | 'fable': 'en-GB-SoniaNeural',
16 | 'onyx': 'en-US-EricNeural',
17 | 'nova': 'en-US-SteffanNeural',
18 | 'shimmer': 'en-US-EmmaNeural'
19 | }
20 |
21 | def is_ffmpeg_installed():
22 | """Check if FFmpeg is installed and accessible."""
23 | try:
24 | subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
25 | return True
26 | except (subprocess.CalledProcessError, FileNotFoundError):
27 | return False
28 |
29 | async def _generate_audio(text, voice, response_format, speed):
30 | """Generate TTS audio and optionally convert to a different format."""
31 | # Determine if the voice is an OpenAI-compatible voice or a direct edge-tts voice
32 | edge_tts_voice = voice_mapping.get(voice, voice) # Use mapping if in OpenAI names, otherwise use as-is
33 |
34 | # Generate the TTS output in mp3 format first
35 | temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
36 |
37 | # Convert speed to SSML rate format
38 | try:
39 | speed_rate = speed_to_rate(speed) # Convert speed value to "+X%" or "-X%"
40 | except Exception as e:
41 | print(f"Error converting speed: {e}. Defaulting to +0%.")
42 | speed_rate = "+0%"
43 |
44 | # Generate the MP3 file
45 | communicator = edge_tts.Communicate(text=text, voice=edge_tts_voice, rate=speed_rate)
46 | await communicator.save(temp_output_file.name)
47 |
48 | # If the requested format is mp3, return the generated file directly
49 | if response_format == "mp3":
50 | return temp_output_file.name
51 |
52 | # Check if FFmpeg is installed
53 | if not is_ffmpeg_installed():
54 | print("FFmpeg is not available. Returning unmodified mp3 file.")
55 | return temp_output_file.name
56 |
57 | # Create a new temporary file for the converted output
58 | converted_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{response_format}")
59 |
60 | # Build the FFmpeg command
61 | ffmpeg_command = [
62 | "ffmpeg",
63 | "-i", temp_output_file.name, # Input file
64 | "-c:a", {
65 | "aac": "aac",
66 | "mp3": "libmp3lame",
67 | "wav": "pcm_s16le",
68 | "opus": "libopus",
69 | "flac": "flac"
70 | }.get(response_format, "aac"), # Default to AAC if unknown
71 | "-b:a", "192k" if response_format != "wav" else None, # Bitrate not needed for WAV
72 | "-f", {
73 | "aac": "mp4", # AAC in MP4 container
74 | "mp3": "mp3",
75 | "wav": "wav",
76 | "opus": "ogg",
77 | "flac": "flac"
78 | }.get(response_format, response_format), # Default to matching format
79 | "-y", # Overwrite without prompt
80 | converted_output_file.name # Output file
81 | ]
82 |
83 | try:
84 | # Run FFmpeg command and ensure no errors occur
85 | subprocess.run(ffmpeg_command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
86 | except subprocess.CalledProcessError as e:
87 | raise RuntimeError(f"FFmpeg error during audio conversion: {e}")
88 |
89 | # Clean up the original temporary file
90 | Path(temp_output_file.name).unlink(missing_ok=True)
91 |
92 | return converted_output_file.name
93 |
94 | def generate_speech(text, voice, response_format, speed=1.0):
95 | return asyncio.run(_generate_audio(text, voice, response_format, speed))
96 |
97 | def get_models():
98 | return [
99 | {"id": "tts-1", "name": "Text-to-speech v1"},
100 | {"id": "tts-1-hd", "name": "Text-to-speech v1 HD"}
101 | ]
102 |
103 | async def _get_voices(language=None):
104 | # List all voices, filter by language if specified
105 | all_voices = await edge_tts.list_voices()
106 | language = language or DEFAULT_LANGUAGE # Use default if no language specified
107 | filtered_voices = [
108 | {"name": v['ShortName'], "gender": v['Gender'], "language": v['Locale']}
109 | for v in all_voices if language == 'all' or language is None or v['Locale'] == language
110 | ]
111 | return filtered_voices
112 |
113 | def get_voices(language=None):
114 | return asyncio.run(_get_voices(language))
115 |
116 | def speed_to_rate(speed: float) -> str:
117 | """
118 | Converts a multiplicative speed value to the edge-tts "rate" format.
119 |
120 | Args:
121 | speed (float): The multiplicative speed value (e.g., 1.5 for +50%, 0.5 for -50%).
122 |
123 | Returns:
124 | str: The formatted "rate" string (e.g., "+50%" or "-50%").
125 | """
126 | if speed < 0 or speed > 2:
127 | raise ValueError("Speed must be between 0 and 2 (inclusive).")
128 |
129 | # Convert speed to percentage change
130 | percentage_change = (speed - 1) * 100
131 |
132 | # Format with a leading "+" or "-" as required
133 | return f"{percentage_change:+.0f}%"
134 |
--------------------------------------------------------------------------------
/WeClone-audio/src/server未完工/utils.py:
--------------------------------------------------------------------------------
1 | # utils.py
2 |
3 | from flask import request, jsonify
4 | from functools import wraps
5 | import os
6 | from dotenv import load_dotenv
7 |
8 | load_dotenv()
9 |
10 | def getenv_bool(name: str, default: bool = False) -> bool:
11 | return os.getenv(name, str(default)).lower() in ("yes", "y", "true", "1", "t")
12 |
13 | API_KEY = os.getenv('API_KEY', 'your_api_key_here')
14 | REQUIRE_API_KEY = getenv_bool('REQUIRE_API_KEY', True)
15 |
16 | def require_api_key(f):
17 | @wraps(f)
18 | def decorated_function(*args, **kwargs):
19 | if not REQUIRE_API_KEY:
20 | return f(*args, **kwargs)
21 | auth_header = request.headers.get('Authorization')
22 | if not auth_header or not auth_header.startswith('Bearer '):
23 | return jsonify({"error": "Missing or invalid API key"}), 401
24 | token = auth_header.split('Bearer ')[1]
25 | if token != API_KEY:
26 | return jsonify({"error": "Invalid API key"}), 401
27 | return f(*args, **kwargs)
28 | return decorated_function
29 |
30 | # Mapping of audio format to MIME type
31 | AUDIO_FORMAT_MIME_TYPES = {
32 | "mp3": "audio/mpeg",
33 | "opus": "audio/ogg",
34 | "aac": "audio/aac",
35 | "flac": "audio/flac",
36 | "wav": "audio/wav",
37 | "pcm": "audio/L16"
38 | }
39 |
--------------------------------------------------------------------------------
/data/example_chat.csv:
--------------------------------------------------------------------------------
1 | id,MsgSvrID,type_name,is_sender,talker,room_name,content,CreateTime
2 | 1,953020244103908134,系统通知,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""你已添加了白桃乌龙,现在可以开始聊天了。""}",2023-08-12 23:22:41
3 | 2,7594861486645126963,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哦吼""}",2023-08-12 23:22:54
4 | 3,5795621731176683438,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""咱们老家""}",2023-08-12 23:23:02
5 | 4,3470072877112832166,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""离得近嘞""}",2023-08-12 23:23:05
6 | 5,4123958315588848926,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://wxapp.tc.qq.com/262/20304/stodownload?m=8f29c306e125f7514246f4f82887cc29&filekey=30350201010421301f020201060402534804108f29c306e125f7514246f4f82887cc290203008884040d00000004627466730000000131&hy=SH&storeid=32303232303531373136313532383030306137303837333237346335383237653661386430393030303030313036&bizid=1023"", ""msg"": ""表情""}",2023-08-12 23:23:26
7 | 6,1893107158254293225,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""是耶""}",2023-08-12 23:23:31
8 | 7,5967747793341844139,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""挺好""}",2023-08-12 23:23:37
9 | 8,365944262262687504,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哈哈哈""}",2023-08-12 23:23:41
10 | 9,4210929420150253626,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""困嘞困嘞""}",2023-08-12 23:24:26
11 | 10,3243868033522068109,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""俺先睡了""}",2023-08-12 23:24:30
12 | 11,7049787380472735099,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""明天再聊""}",2023-08-12 23:24:33
13 | 12,3092986777523361373,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""好der""}",2023-08-12 23:24:47
14 | 13,8642915817867986857,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://mmbiz.qpic.cn/mmemoticon/ajNVdqHZLLACNZkcWvHBNURJvCPe6PQ3BDf8deqLjHGricqHbCbV0iagpHmCat5Mo3/0"", ""msg"": ""表情""}",2023-08-12 23:24:50
15 | 14,422457693648526604,动画表情,1,我,xinglaifaxianzhe,"{""src"": ""http://wxapp.tc.qq.com/262/20304/stodownload?m=65a36156989a0bd09bd3d8b8fee3cb2f&filekey=30350201010421301f020201060402535a041065a36156989a0bd09bd3d8b8fee3cb2f0203008069040d00000004627466730000000132&hy=SZ&storeid=26438315e000ec0f96c8595840000010600004f50535a0b492a91e683f6a79&bizid=1023"", ""msg"": ""表情""}",2023-08-12 23:25:04
16 | 15,2311761750664470975,图片,1,我,xinglaifaxianzhe,"{""src"": ""FileStorage\\MsgAttach\\ec71e3b5cc65b1713665080d0f1939e5\\Image\\2023-08\\82332bbba85f6bb3cf85823037f3a99e.dat"", ""msg"": ""图片""}",2023-08-13 14:01:41
17 | 16,4293887349744541268,图片,1,我,xinglaifaxianzhe,"{""src"": ""FileStorage\\MsgAttach\\ec71e3b5cc65b1713665080d0f1939e5\\Image\\2023-08\\9c2e52835ddf872089b6fff3497d6f44.dat"", ""msg"": ""图片""}",2023-08-13 14:01:43
18 | 17,5053822429327172633,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""今天天气不错嘞""}",2023-08-13 14:01:46
19 | 18,2705637163783232848,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""就是有点hot""}",2023-08-13 14:02:07
20 | 19,9106315879092141637,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://wxapp.tc.qq.com/275/20304/stodownload?m=39ec76de39cf2f41e0f19fae9752f2e1&filekey=30340201010420301e0202011304025348041039ec76de39cf2f41e0f19fae9752f2e102024a31040d00000004627466730000000132&hy=SH&storeid=264795d250002b74e000000000000011300004f50534801fe5b01e648f3633&bizid=1023"", ""msg"": ""表情""}",2023-08-13 14:05:59
21 | 20,4919372022511043672,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哪天不hot""}",2023-08-13 14:06:04
22 | 21,7801804847925160451,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""下雨不hot""}",2023-08-13 14:10:41
23 | 22,4186719283759563990,图片,1,我,xinglaifaxianzhe,"{""src"": ""FileStorage\\MsgAttach\\ec71e3b5cc65b1713665080d0f1939e5\\Image\\2023-08\\08b1106d78b4dfba0755ff2aa09d2438.dat"", ""msg"": ""图片""}",2023-08-13 14:11:06
24 | 23,3521100079205966470,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""好像空格键""}",2023-08-13 14:11:07
25 | 24,1617187845718151808,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""很想按下去""}",2023-08-13 14:11:18
26 | 25,8345209503359387241,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://wxapp.tc.qq.com/262/20304/stodownload?m=004f3c0cf5259ab214d9dabacc337c60&filekey=30350201010421301f02020106040253480410004f3c0cf5259ab214d9dabacc337c600203059dc6040d00000004627466730000000131&hy=SH&storeid=32303232303831393136353732363030306462663664303030303030303030376464613030623030303030313036&bizid=1023"", ""msg"": ""表情""}",2023-08-13 14:13:06
27 | 26,8573385017731442447,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""不行你干脆回家学习""}",2023-08-13 14:13:12
28 | 27,7076070838462900480,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哈哈哈""}",2023-08-13 14:13:13
29 | 28,5591518902965619475,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""这是要去哪玩""}",2023-08-13 14:13:34
30 | 29,8086920793435315850,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""来看个电影""}",2023-08-13 14:13:51
31 | 30,1814934139736232801,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""看什么!""}",2023-08-13 14:14:49
32 | 31,3677613846620059746,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""emmm""}",2023-08-13 14:14:56
33 | 32,3409906469387839600,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""我也在纠结要不要出门看个电影""}",2023-08-13 14:15:03
34 | 33,5709723437718492173,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""好像是孤注一掷""}",2023-08-13 14:15:07
35 | 34,7775547367225903484,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""但是感觉像反诈电影""}",2023-08-13 14:15:13
36 | 35,6258879459298954092,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哈哈哈好像是""}",2023-08-13 14:15:14
--------------------------------------------------------------------------------
/data/res_csv/pt/dataset_info.json:
--------------------------------------------------------------------------------
1 | {"wechat-pt":{
2 | "file_name": "./pt-my.json",
3 | "columns": {
4 | "prompt": "c"
5 | }
6 | }}
--------------------------------------------------------------------------------
/data/res_csv/sft/dataset_info-with-his.json:
--------------------------------------------------------------------------------
1 | {
2 | "wechat-sft": {
3 | "file_name": "./sft-my.json",
4 | "columns": {
5 | "prompt": "instruction",
6 | "response": "output",
7 | "history": "history"
8 | }
9 | }
10 | }
--------------------------------------------------------------------------------
/data/res_csv/sft/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "wechat-sft": {
3 | "file_name": "./sft-my.json",
4 | "columns": {
5 | "prompt": "instruction",
6 | "response": "output"
7 | }
8 | }
9 | }
10 |
11 |
12 |
--------------------------------------------------------------------------------
/data/test_data.json:
--------------------------------------------------------------------------------
1 | {
2 | "questions": [
3 | [
4 | "吃了吗?",
5 | "吃的什么啊",
6 | "好吃吗",
7 | "多少钱啊",
8 | "可以请我吃吗"
9 | ],
10 | [
11 | "你多大了?"
12 | ],
13 | [
14 | "你有什么爱好吗?"
15 | ],
16 | [
17 | "你的理想是什么?",
18 | "你觉得你离你的理想还有多远?"
19 | ],
20 | [
21 | "你最近在忙什么?",
22 | "工作/学习顺利吗?",
23 | "有什么有趣的事情发生吗?"
24 | ],
25 | [
26 | "你喜欢看什么类型的电影?",
27 | "最近看过什么好看的电影吗?",
28 | "你最喜欢的电影是什么?"
29 | ],
30 | [
31 | "你平时喜欢听什么音乐?",
32 | "有推荐的歌手或乐队吗?",
33 | "最近有喜欢的歌曲吗?"
34 | ],
35 | [
36 | "你喜欢旅游吗?",
37 | "去过哪些地方?",
38 | "最喜欢的旅游地是哪里?"
39 | ],
40 | [
41 | "你喜欢读书吗?",
42 | "最近在读什么书?",
43 | "最喜欢的书是哪本?"
44 | ],
45 | [
46 | "你平时喜欢运动吗?",
47 | "喜欢做哪些运动?",
48 | "有固定去锻炼吗?"
49 | ],
50 | [
51 | "周末一般都做些什么?",
52 | "有没有什么特别的计划?",
53 | "周末喜欢宅在家还是出去玩?"
54 | ],
55 | [
56 | "你喜欢宠物吗?",
57 | "有养宠物吗?",
58 | "最喜欢什么动物?"
59 | ],
60 | [
61 | "你喜欢吃什么类型的食物?",
62 | "有推荐的餐厅吗?",
63 | "最喜欢的菜是什么?"
64 | ],
65 | [
66 | "你喜欢什么样的天气?",
67 | "最喜欢的季节是哪一个?",
68 | "你觉得今天的天气怎么样?"
69 | ],
70 | [
71 | "你有看电视剧的习惯吗?",
72 | "最近在追哪部剧?",
73 | "最喜欢的电视剧是哪部?"
74 | ],
75 | [
76 | "你喜欢玩游戏吗?",
77 | "最近在玩什么游戏?",
78 | "有推荐的好玩的游戏吗?"
79 | ],
80 | [
81 | "你会做饭吗?",
82 | "平时喜欢做哪些菜?",
83 | "有没有特别拿手的菜?"
84 | ],
85 | [
86 | "你喜欢购物吗?",
87 | "最近买了什么新东西?",
88 | "有推荐的购物网站或店铺吗?"
89 | ],
90 | [
91 | "你平时怎么放松自己?",
92 | "有特别的解压方式吗?",
93 | "最喜欢的放松活动是什么?"
94 | ],
95 | [
96 | "你喜欢和朋友出去玩吗?",
97 | "平时会和朋友去哪玩?",
98 | "最近有没有和朋友聚会的计划?"
99 | ],
100 | [
101 | "你喜欢喝咖啡还是茶?",
102 | "有没有特别喜欢的咖啡馆或茶馆?",
103 | "最喜欢的饮品是什么?"
104 | ],
105 | [
106 | "你有兄弟姐妹吗?",
107 | "和他们关系怎么样?",
108 | "经常联系吗?"
109 | ],
110 | [
111 | "你喜欢读什么类型的杂志?",
112 | "最近有看什么有趣的文章吗?",
113 | "有订阅的杂志吗?"
114 | ],
115 | [
116 | "你喜欢看体育比赛吗?",
117 | "最喜欢的运动项目是什么?",
118 | "有没有特别支持的球队或运动员?"
119 | ],
120 | [
121 | "你会说其他语言吗?",
122 | "最想学的语言是什么?",
123 | "学习语言有什么技巧吗?"
124 | ],
125 | [
126 | "你对科技产品感兴趣吗?",
127 | "最近有没有关注什么新科技?",
128 | "最喜欢的电子产品是什么?"
129 | ],
130 | [
131 | "你喜欢喝什么样的饮料?",
132 | "有没有自己调饮料的习惯?",
133 | "最喜欢的饮品品牌是什么?"
134 | ],
135 | [
136 | "你平时用社交媒体吗?",
137 | "常用哪些平台?",
138 | "在社交媒体上做什么?"
139 | ],
140 | [
141 | "你对艺术感兴趣吗?",
142 | "最喜欢的艺术家是谁?",
143 | "有去过哪些艺术展览?"
144 | ],
145 | [
146 | "你喜欢DIY吗?",
147 | "平时做些什么手工?",
148 | "有没有完成的作品可以分享?"
149 | ],
150 | [
151 | "你喜欢种植植物吗?",
152 | "有养什么植物?",
153 | "最喜欢的植物是什么?"
154 | ],
155 | [
156 | "你喜欢拍照吗?",
157 | "喜欢拍什么样的照片?",
158 | "有没有用什么特别的摄影设备?"
159 | ],
160 | [
161 | "你喜欢听播客吗?",
162 | "常听哪些主题的播客?",
163 | "有没有推荐的播客?"
164 | ],
165 | [
166 | "你对历史感兴趣吗?",
167 | "最喜欢哪个历史时期?",
168 | "有没有特别喜欢的历史人物?"
169 | ],
170 | [
171 | "你喜欢画画吗?",
172 | "平时画什么类型的画?",
173 | "有参加过画展吗?"
174 | ],
175 | [
176 | "你喜欢写作吗?",
177 | "平时写什么类型的文章?",
178 | "有没有发表过作品?"
179 | ],
180 | [
181 | "你喜欢钓鱼吗?",
182 | "平时去哪里钓鱼?",
183 | "有没有钓到过什么大鱼?"
184 | ],
185 | [
186 | "你喜欢露营吗?",
187 | "平时会去哪里露营?",
188 | "有没有什么难忘的露营经历?"
189 | ],
190 | [
191 | "你喜欢摄影吗?",
192 | "最喜欢拍什么题材?",
193 | "有没有特别喜欢的摄影师?"
194 | ],
195 | [
196 | "你喜欢喝酒吗?",
197 | "喜欢什么类型的酒?",
198 | "有没有推荐的酒吧或品牌?"
199 | ],
200 | [
201 | "你喜欢滑雪吗?",
202 | "平时去哪里滑雪?",
203 | "有没有什么滑雪技巧分享?"
204 | ],
205 | [
206 | "你喜欢海边还是山里?",
207 | "最喜欢去哪个地方度假?",
208 | "有没有什么特别推荐的景点?"
209 | ],
210 | [
211 | "你喜欢参加音乐节吗?",
212 | "参加过哪些音乐节?",
213 | "最喜欢的音乐节是哪一个?"
214 | ],
215 | [
216 | "你喜欢跑步吗?",
217 | "平时跑多长距离?",
218 | "有没有参加过马拉松?"
219 | ],
220 | [
221 | "你喜欢参加聚会吗?",
222 | "平时和朋友聚会做什么?",
223 | "有没有什么有趣的聚会游戏?"
224 | ],
225 | [
226 | "你喜欢收集东西吗?",
227 | "收集什么类型的物品?",
228 | "有没有什么特别的收藏?"
229 | ]
230 | ]
231 | }
--------------------------------------------------------------------------------
/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "zero_optimization": {
14 | "stage": 2,
15 | "allgather_partitions": true,
16 | "allgather_bucket_size": 5e8,
17 | "overlap_comm": true,
18 | "reduce_scatter": true,
19 | "reduce_bucket_size": 5e8,
20 | "contiguous_gradients": true
21 | },
22 | "gradient_accumulation_steps": "auto",
23 | "gradient_clipping": "auto",
24 | "steps_per_print": 2000,
25 | "train_batch_size": "auto",
26 | "train_micro_batch_size_per_gpu": "auto",
27 | "wall_clock_breakdown": false
28 | }
--------------------------------------------------------------------------------
/img/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/img/1.png
--------------------------------------------------------------------------------
/img/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/img/2.png
--------------------------------------------------------------------------------
/img/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/img/3.png
--------------------------------------------------------------------------------
/img/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/img/4.jpg
--------------------------------------------------------------------------------
/img/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/img/5.png
--------------------------------------------------------------------------------
/make_dataset/blocked_words.json:
--------------------------------------------------------------------------------
1 | {
2 | "blocked_words": [
3 | "例如 姓名",
4 | "例如 地址",
5 | "//....."
6 | ]
7 | }
--------------------------------------------------------------------------------
/make_dataset/csv_to_json-单句回答.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 |
5 | import pandas as pd
6 | from collections import deque
7 |
8 | csv_folder = './data/csv'
9 | # csv_folder = './data/test'
10 | print(f'当前处理目录{csv_folder}')
11 |
12 |
13 | def handle_pt_csv(csvfile):
14 | chat_df = pd.read_csv(csvfile)
15 | # 选择type_name为文本的行、is_sender为1的行
16 | chat_df = chat_df[chat_df['type_name'] == '文本']
17 | chat_df = chat_df[chat_df['is_sender'] == 1]
18 | # 对每一行的content进行处理 转为dict 再取'msg'字段
19 | chat_df['content'] = chat_df['content'].apply(lambda x: json.loads(x)['msg'])
20 | # 如果content 包含 手机号、身份证号、邮箱、网址则删除这行
21 | chat_df = chat_df[~chat_df['content'].str.contains('1\d{10}')]
22 | chat_df = chat_df[~chat_df['content'].str.contains('\d{18}')]
23 | chat_df = chat_df[~chat_df['content'].str.contains('\w+@\w+')]
24 | chat_df = chat_df[~chat_df['content'].str.contains('http')]
25 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\xa0')]
26 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\u')]
27 |
28 | # 纯content
29 | chat_df = chat_df['content']
30 | chat_df = chat_df.dropna()
31 |
32 | return chat_df
33 |
34 |
35 | def make_pt_dataset():
36 | csv_res = []
37 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件
38 | for chat_obj_folder in os.listdir(csv_folder):
39 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
40 | for csvfile in os.listdir(chat_obj_folder_path):
41 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
42 | chat_df = handle_pt_csv(csvfile_path)
43 | csv_res.append(chat_df)
44 |
45 | csv_res = pd.concat(csv_res)
46 | csv_res = csv_res.apply(lambda x: {'c': x}) # 设置数据集prompt键为c
47 |
48 | csv_res.to_json('./data/res_csv/pt-my.json', orient='records', force_ascii=False)
49 |
50 |
51 | def handle_sft_csv(csvfile):
52 | chat_df = pd.read_csv(csvfile)
53 | blocked_words = json.load(open('./make_dataset/blocked_words.json', encoding='utf-8'))['blocked_words']
54 | # 选择type_name为文本的行、is_sender为1的行
55 | # 需要保留的type_name字段名
56 | type_list = ['文本', '图片', '卡片式链接', '合并转发的聊天记录', '视频', '语言', '未知', '分享的小程序']
57 | chat_df = chat_df[chat_df['type_name'].isin(values=type_list)]
58 |
59 | # 对每一行的content进行处理 转为dict 再取'msg'字段
60 | chat_df['content'] = chat_df['content'].apply(func=lambda x: json.loads(x)['msg'])
61 | # 如果type_name为文本 并且content 包含 手机号、身份证号、邮箱、网址则删除这行 用循环删除
62 | for i in chat_df.index:
63 | if chat_df.loc[i, 'type_name'] == '文本':
64 | if ('1\d{10}' in chat_df.loc[i, 'content'] or
65 | '\d{18}' in chat_df.loc[i, 'content'] or
66 | '\w+@\w+' in chat_df.loc[i, 'content'] or
67 | 'http' in chat_df.loc[i, 'content'] or
68 | r'\\xa0' in chat_df.loc[i, 'content'] or
69 | r'\\u' in chat_df.loc[i, 'content']):
70 | chat_df = chat_df.drop(index=i)
71 | continue
72 | for blocked_word in blocked_words:
73 | if blocked_word in chat_df.loc[i, 'content']:
74 | chat_df = chat_df.drop(index=i)
75 | break
76 | else:
77 | # content赋值为空
78 | chat_df.loc[i, 'content'] = ''
79 |
80 | chat_df = chat_df[['is_sender', 'type_name', 'content', 'CreateTime']]
81 | chat_df = chat_df.dropna()
82 |
83 | # 时间格式 2021-07-07 10:27:23
84 | # 遍历行 相同is_sender的行合并content()遇到不同is_sender就重新开始
85 | # CreateTime字段保留最后的CreateTime
86 | chat_df['CreateTime'] = pd.to_datetime(chat_df['CreateTime'])
87 | type_list.remove('文本')
88 | skip_list = type_list
89 | res_df = []
90 | last_is_sender = chat_df.iloc[0]['is_sender']
91 | last_content: str = chat_df.iloc[0]['content']
92 | last_CreateTime = chat_df.iloc[0]['CreateTime']
93 | # 超时处理 半天没说话就重新开始
94 | # 注意这里只是处理了组装成一个句子 最后封装对话、配对在make_sft_dataset
95 |
96 | # 遇到图片 连接 直接封装成一个句子
97 | for i, row in chat_df.iterrows():
98 | if row['type_name'] in skip_list:
99 | if last_content != '':
100 | if last_content[-1] == ',':
101 | last_content = last_content[:-1] + '。'
102 | elif last_content[-1] not in ['。', '!', '?', '…', '.']:
103 | last_content += '。'
104 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
105 | last_CreateTime = row['CreateTime']
106 | last_content = ''
107 | # cut表示被skip字段截断
108 | res_df.append({'is_sender': row['is_sender'], 'content': 'cut', 'CreateTime': row['CreateTime']})
109 | continue
110 | if last_content == '': # 重新开始
111 | last_content = row['content']
112 | last_is_sender = row['is_sender']
113 | last_CreateTime = row['CreateTime']
114 | continue
115 | if row['is_sender'] == last_is_sender:
116 | if row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'):
117 | # 如果超时 前面的添加到res_df 并重新开始
118 | if last_content[-1] == ',':
119 | last_content = last_content[:-1] + '。'
120 | elif last_content[-1] not in ['。', '!', '?', '…', '.']:
121 | last_content += '。'
122 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
123 | last_content = row['content']
124 | last_CreateTime = row['CreateTime']
125 | continue
126 | # 如果content的结尾没有标点符号则添加逗号,最后结尾是句号
127 | if last_content[-1] not in ['。', '!', '?', '…', ',']:
128 | last_content += ','
129 | last_content = last_content + row['content']
130 | last_CreateTime = row['CreateTime']
131 | else:
132 | if last_content[-1] == ',':
133 | last_content = last_content[:-1] + '。'
134 | elif last_content[-1] not in ['。', '!', '?', '…', '.']:
135 | last_content += '。'
136 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
137 | last_is_sender = row['is_sender']
138 | last_content = row['content']
139 | last_CreateTime = row['CreateTime']
140 | res_df = pd.DataFrame(res_df)
141 | return res_df
142 |
143 |
144 | def make_sft_dataset():
145 |
146 | # [
147 | # {
148 | # "instruction": "用户指令(必填)",
149 | # "input": "用户输入(选填)",
150 | # "output": "模型回答(必填)",
151 | # "system": "系统提示词(选填)",
152 | # "history": [
153 | # ["第一轮指令(选填)", "第一轮回答(选填)"],
154 | # ["第二轮指令(选填)", "第二轮回答(选填)"]
155 | # ]
156 | # }
157 | # ]
158 |
159 | csv_concat = []
160 | csv_res = []
161 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件
162 | for chat_obj_folder in os.listdir(csv_folder):
163 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
164 | for csvfile in os.listdir(chat_obj_folder_path):
165 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
166 | chat_df = handle_sft_csv(csvfile_path)
167 | csv_concat.append(chat_df)
168 |
169 | csv_concat = pd.concat(csv_concat)
170 | # csv_res里is_sender必须是01 01 01 的顺序 csv_concat里不一定是01 01
171 | # 相差超过1小时的时间戳分为不同的对话
172 | # temp_res为一个长度为2的队列
173 | temp_res = deque(maxlen=2)
174 | # 6种情况
175 | # temp_res 为空 遇到 0入队 遇到1不处理 遇到cut不处理
176 | # temp_res 有0 遇到0清空队列再入队 遇到1相差超过1小时清空队列 没有相差一小时入队再全部出队 遇到cut清空队列
177 |
178 | # 选最长的做为问题的结果?
179 |
180 | for i, row in csv_concat.iterrows():
181 | if len(temp_res) == 0:
182 | if row['content'] == 'cut':
183 | continue
184 | if row['is_sender'] == 0:
185 | temp_res.append(row['content'])
186 | last_CreateTime = row['CreateTime']
187 | else:
188 | continue
189 | elif len(temp_res) == 1:
190 | if row['content'] == 'cut':
191 | temp_res.clear()
192 | last_CreateTime = row['CreateTime']
193 | elif row['is_sender'] == 0:
194 | # 遇到0 清空队列再入队
195 | temp_res.clear()
196 | temp_res.append(row['content'])
197 | last_CreateTime = row['CreateTime']
198 | else:
199 | if row['CreateTime'] - last_CreateTime > pd.Timedelta('1h'):
200 | # 相差超过1小时清空队列
201 | temp_res.clear()
202 | last_CreateTime = row['CreateTime']
203 | else:
204 | # 没有相差一小时入队再全部出队
205 | temp_res.append(row['content'])
206 | temp_output_list = temp_res[1].split(',')
207 | output = max(temp_output_list, key=len)# 只选选最长的回答作为最终数据
208 | if output[-1] == '。':
209 | output = output[:-1]
210 | csv_res.append({'instruction': temp_res[0], 'output': output})
211 | temp_res.clear()
212 | last_CreateTime = row['CreateTime']
213 |
214 | csv_res_df = pd.DataFrame(csv_res)
215 | print(f'数据量:{csv_res_df.shape[0]}')
216 | csv_res_df.to_json('./data/res_csv/sft/sft-my.json', orient='records', force_ascii=False)
217 |
218 |
219 | if __name__ == '__main__':
220 | # make_pt_dataset()
221 | make_sft_dataset()
222 |
--------------------------------------------------------------------------------
/make_dataset/csv_to_json-单句多轮.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 |
5 | import pandas as pd
6 | from collections import deque
7 |
8 | csv_folder = './data/csv'
9 | # csv_folder = './data/test'
10 | print(f'当前处理目录{csv_folder}')
11 |
12 |
13 | def handle_pt_csv(csvfile):
14 | chat_df = pd.read_csv(csvfile)
15 | # 选择type_name为文本的行、is_sender为1的行
16 | chat_df = chat_df[chat_df['type_name'] == '文本']
17 | chat_df = chat_df[chat_df['is_sender'] == 1]
18 | # 对每一行的content进行处理 转为dict 再取'msg'字段
19 | chat_df['content'] = chat_df['content'].apply(lambda x: json.loads(x)['msg'])
20 | # 如果content 包含 手机号、身份证号、邮箱、网址则删除这行
21 | chat_df = chat_df[~chat_df['content'].str.contains('1\d{10}')]
22 | chat_df = chat_df[~chat_df['content'].str.contains('\d{18}')]
23 | chat_df = chat_df[~chat_df['content'].str.contains('\w+@\w+')]
24 | chat_df = chat_df[~chat_df['content'].str.contains('http')]
25 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\xa0')]
26 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\u')]
27 |
28 | # 纯content
29 | chat_df = chat_df['content']
30 | chat_df = chat_df.dropna()
31 |
32 | return chat_df
33 |
34 |
35 | def make_pt_dataset():
36 | csv_res = []
37 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件
38 | for chat_obj_folder in os.listdir(csv_folder):
39 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
40 | for csvfile in os.listdir(chat_obj_folder_path):
41 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
42 | chat_df = handle_pt_csv(csvfile_path)
43 | csv_res.append(chat_df)
44 |
45 | csv_res = pd.concat(csv_res)
46 | csv_res = csv_res.apply(lambda x: {'c': x}) # 设置数据集prompt键为c
47 |
48 | csv_res.to_json('./data/res_csv/pt.json', orient='records', force_ascii=False)
49 |
50 |
51 | def handle_sft_csv(csvfile):
52 | chat_df = pd.read_csv(csvfile)
53 | blocked_words = json.load(open('./make_dataset/blocked_words.json'))['blocked_words']
54 | # 选择type_name为文本的行、is_sender为1的行
55 | # 需要保留的type_name字段名
56 | type_list = ['文本', '图片', '卡片式链接', '合并转发的聊天记录', '视频', '语言', '未知', '分享的小程序']
57 | chat_df = chat_df[chat_df['type_name'].isin(values=type_list)]
58 |
59 | # 对每一行的content进行处理 转为dict 再取'msg'字段
60 | chat_df['content'] = chat_df['content'].apply(func=lambda x: json.loads(x)['msg'])
61 | # 如果type_name为文本 并且content 包含 手机号、身份证号、邮箱、网址则删除这行 用循环删除
62 | for i in chat_df.index:
63 | if chat_df.loc[i, 'type_name'] == '文本':
64 | if ('1\d{10}' in chat_df.loc[i, 'content'] or
65 | '\d{18}' in chat_df.loc[i, 'content'] or
66 | '\w+@\w+' in chat_df.loc[i, 'content'] or
67 | 'http' in chat_df.loc[i, 'content'] or
68 | r'\\xa0' in chat_df.loc[i, 'content'] or
69 | r'\\u' in chat_df.loc[i, 'content']):
70 | chat_df = chat_df.drop(index=i)
71 | continue
72 | for blocked_word in blocked_words:
73 | if blocked_word in chat_df.loc[i, 'content']:
74 | chat_df = chat_df.drop(index=i)
75 | break
76 | else:
77 | # content赋值为空
78 | chat_df.loc[i, 'content'] = ''
79 |
80 | chat_df = chat_df[['is_sender', 'type_name', 'content', 'CreateTime']]
81 | chat_df = chat_df.dropna()
82 |
83 | # 时间格式 2021-07-07 10:27:23
84 | # 遍历行 相同is_sender的行合并content()遇到不同is_sender就重新开始
85 | # CreateTime字段保留最后的CreateTime
86 | chat_df['CreateTime'] = pd.to_datetime(chat_df['CreateTime'])
87 | type_list.remove('文本')
88 | skip_list = type_list
89 | res_df = []
90 | last_is_sender = chat_df.iloc[0]['is_sender']
91 | last_content: str = chat_df.iloc[0]['content']
92 | last_CreateTime = chat_df.iloc[0]['CreateTime']
93 | # 超时处理 半天没说话就重新开始
94 | # 注意这里只是处理了组装成一个句子 最后封装对话、配对在make_sft_dataset
95 |
96 | # 遇到图片 连接 直接封装成一个句子
97 | for i, row in chat_df.iterrows():
98 | if '跟最终成绩差1分' in row['content']:
99 | pass
100 | if row['type_name'] in skip_list:
101 | if last_content != '':
102 | if last_content[-1] == ',':
103 | last_content = last_content[:-1] + '。'
104 | elif last_content[-1] not in ['。', '!', '?', '…', '.']:
105 | last_content += '。'
106 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
107 | last_CreateTime = row['CreateTime']
108 | last_content = ''
109 | # cut表示被skip字段截断
110 | res_df.append({'is_sender': row['is_sender'], 'content': 'cut', 'CreateTime': row['CreateTime']})
111 | continue
112 | if last_content == '': # 重新开始
113 | last_content = row['content']
114 | last_is_sender = row['is_sender']
115 | last_CreateTime = row['CreateTime']
116 | continue
117 | if row['is_sender'] == last_is_sender:
118 | if row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'):
119 | # 如果超时 前面的添加到res_df 并重新开始
120 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
121 | res_df.append({'is_sender': last_is_sender, 'content': 'cut', 'CreateTime': last_CreateTime})
122 | last_content = row['content']
123 | last_CreateTime = row['CreateTime']
124 | continue
125 | # 如果content的结尾没有标点符号则添加逗号,最后结尾是句号
126 | if last_content[-1] not in ['。', '!', '?', '…', ',']:
127 | last_content += ','
128 | last_content = last_content + row['content']
129 | last_CreateTime = row['CreateTime']
130 | else:
131 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
132 | if row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'):
133 | res_df.append({'is_sender': last_is_sender, 'content': 'cut', 'CreateTime': last_CreateTime})
134 | last_is_sender = row['is_sender']
135 | last_content = row['content']
136 | last_CreateTime = row['CreateTime']
137 | res_df = pd.DataFrame(res_df)
138 | return res_df
139 |
140 |
141 | def make_sft_dataset():
142 |
143 | # [
144 | # {
145 | # "instruction": "用户指令(必填)",
146 | # "input": "用户输入(选填)",
147 | # "output": "模型回答(必填)",
148 | # "system": "系统提示词(选填)",
149 | # "history": [
150 | # ["第一轮指令(选填)", "第一轮回答(选填)"],
151 | # ["第二轮指令(选填)", "第二轮回答(选填)"]
152 | # ]
153 | # }
154 | # ]
155 |
156 | csv_concat = []
157 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件
158 | for chat_obj_folder in os.listdir(csv_folder):
159 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
160 | for csvfile in os.listdir(chat_obj_folder_path):
161 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
162 | chat_df = handle_sft_csv(csvfile_path)
163 | csv_concat.append(chat_df)
164 |
165 | csv_concat = pd.concat(csv_concat)
166 | # csv_res里is_sender必须是01 01 01 的顺序 csv_concat里不一定是01 01
167 | # 相差超过1小时的时间戳分为不同的对话
168 | # temp_res为一个长度为2的队列
169 | temp_chat_item = deque(maxlen=2)
170 | qa_list = []
171 |
172 | # 6种情况
173 | # temp_res 为空 遇到 0入队 遇到1不处理 遇到cut不处理
174 | # temp_res 有0 遇到0清空队列再入队 遇到1入队再全部出队 遇到cut清空队列
175 |
176 | # 选最长的做为问题的结果?
177 |
178 | for i, row in csv_concat.iterrows():
179 | if len(temp_chat_item) == 0:
180 | if row['content'] == 'cut':
181 | qa_list.append('cut')
182 | continue
183 | if row['is_sender'] == 0:
184 | temp_chat_item.append(row['content'])
185 | else:
186 | continue
187 | elif len(temp_chat_item) == 1:
188 | if row['content'] == 'cut':
189 | qa_list.append('cut')
190 | temp_chat_item.clear()
191 | elif row['is_sender'] == 0:
192 | # 遇到0 清空队列再入队
193 | temp_chat_item.clear()
194 | temp_chat_item.append(row['content'])
195 | else:
196 | temp_chat_item.append(row['content'])
197 | temp_output_list = temp_chat_item[1].split(',')
198 | output = max(temp_output_list, key=len)
199 | if output[-1] == '。':
200 | output = output[:-1]
201 | qa_list.append({'instruction': temp_chat_item[0], 'output': output})
202 | # csv_res.append({'instruction': temp_chat_item[0], 'output': output, 'system': '请你扮演一名人类,不要说自己是人工智能'})
203 | temp_chat_item.clear()
204 |
205 | csv_res = []
206 | system_prompt = '请你扮演一名人类,不要说自己是人工智能或者某个模型'
207 |
208 | last_res = {'instruction': '', 'output': '',
209 | 'system': system_prompt,
210 | 'history': []}
211 | for i, qa in enumerate(qa_list):
212 | if qa == 'cut':
213 | if len(last_res['history']) == 0:
214 | continue
215 | else:
216 | if len(last_res['history']) == 1:
217 | last_res = {'instruction': last_res['history'][0][0], 'output': last_res['history'][0][1],
218 | 'system': system_prompt,'history': []}
219 | else:
220 | last_res = {'instruction': last_res['history'][-1][0], 'output': last_res['history'][-1][1],
221 | 'system': system_prompt,
222 | 'history': last_res['history'][:-1]}
223 | csv_res.append(last_res)
224 | last_res = {'instruction': '', 'output': '',
225 | 'system': system_prompt,
226 | 'history': []}
227 | else:
228 | last_res['history'].append([qa['instruction'], qa['output']])
229 |
230 |
231 | csv_res_df = pd.DataFrame(csv_res)
232 | print(f'数据量:{csv_res_df.shape[0]}')
233 | csv_res_df.to_json('./data/res_csv/sft/sft-my.json', orient='records', force_ascii=False)
234 |
235 |
236 | if __name__ == '__main__':
237 | # make_pt_dataset()
238 | make_sft_dataset()
239 |
--------------------------------------------------------------------------------
/make_dataset/csv_to_json.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import json
3 | import os
4 |
5 | import pandas as pd
6 | from collections import deque
7 |
8 | csv_folder = './data/csv'
9 | # csv_folder = './data/test'
10 | print(f'当前处理目录{csv_folder}')
11 |
12 |
13 | def handle_pt_csv(csvfile):
14 | chat_df = pd.read_csv(csvfile)
15 | # 选择type_name为文本的行、is_sender为1的行
16 | chat_df = chat_df[chat_df['type_name'] == '文本']
17 | chat_df = chat_df[chat_df['is_sender'] == 1]
18 | # 对每一行的content进行处理 转为dict 再取'msg'字段
19 | chat_df['content'] = chat_df['content'].apply(lambda x: json.loads(x)['msg'])
20 | # 如果content 包含 手机号、身份证号、邮箱、网址则删除这行
21 | chat_df = chat_df[~chat_df['content'].str.contains('1\d{10}')]
22 | chat_df = chat_df[~chat_df['content'].str.contains('\d{18}')]
23 | chat_df = chat_df[~chat_df['content'].str.contains('\w+@\w+')]
24 | chat_df = chat_df[~chat_df['content'].str.contains('http')]
25 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\xa0')]
26 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\u')]
27 |
28 | # 纯content
29 | chat_df = chat_df['content']
30 | chat_df = chat_df.dropna()
31 |
32 | return chat_df
33 |
34 |
35 | def make_pt_dataset():
36 | csv_res = []
37 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件
38 | for chat_obj_folder in os.listdir(csv_folder):
39 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
40 | for csvfile in os.listdir(chat_obj_folder_path):
41 | if not csvfile.endswith('.csv'):
42 | continue
43 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
44 | chat_df = handle_pt_csv(csvfile_path)
45 | csv_res.append(chat_df)
46 |
47 | csv_res = pd.concat(csv_res)
48 | csv_res = csv_res.apply(lambda x: {'c': x}) # 设置数据集prompt键为c
49 |
50 | csv_res.to_json('./data/res_csv/pt-my.json', orient='records', force_ascii=False)
51 |
52 |
53 | def handle_sft_csv(csvfile):
54 | chat_df = pd.read_csv(csvfile)
55 | blocked_words = json.load(open('./make_dataset/blocked_words.json', encoding='utf-8'))['blocked_words']
56 | # 选择type_name为文本的行、is_sender为1的行
57 | # 需要保留的type_name字段名
58 | type_list = ['文本', '图片', '卡片式链接', '合并转发的聊天记录', '视频', '语言', '未知', '分享的小程序']
59 | chat_df = chat_df[chat_df['type_name'].isin(values=type_list)]
60 |
61 | # chat_df['content'] = chat_df['content'].apply(func=lambda x: json.loads(x)['msg'])
62 | chat_df['content'] = chat_df['msg']
63 |
64 | # 如果type_name为文本 并且content 包含 手机号、身份证号、邮箱、网址则删除这行
65 | for i in chat_df.index:
66 | if chat_df.loc[i, 'type_name'] == '文本':
67 | if ('1\d{10}' in chat_df.loc[i, 'content'] or
68 | '\d{18}' in chat_df.loc[i, 'content'] or
69 | '\w+@\w+' in chat_df.loc[i, 'content'] or
70 | 'http' in chat_df.loc[i, 'content'] or
71 | r'\\xa0' in chat_df.loc[i, 'content'] or
72 | r'\\u' in chat_df.loc[i, 'content']):
73 | chat_df = chat_df.drop(index=i)
74 | continue
75 | for blocked_word in blocked_words:
76 | if blocked_word in chat_df.loc[i, 'content']:
77 | chat_df = chat_df.drop(index=i)
78 | break
79 | else:
80 | chat_df.loc[i, 'content'] = ''
81 |
82 | chat_df = chat_df[['is_sender', 'type_name', 'content', 'CreateTime']]
83 | chat_df = chat_df.dropna()
84 |
85 | # 时间格式 2021-07-07 10:27:23
86 | # 遍历行 相同is_sender的行合并content()遇到不同is_sender就重新开始
87 | # CreateTime字段保留最后的CreateTime
88 | chat_df['CreateTime'] = pd.to_datetime(chat_df['CreateTime'])
89 | type_list.remove('文本')
90 | skip_list = type_list
91 | res_df = []
92 | last_is_sender = chat_df.iloc[0]['is_sender']
93 | last_content: str = chat_df.iloc[0]['content']
94 | last_CreateTime = chat_df.iloc[0]['CreateTime']
95 | # 超时处理 半天没说话就重新开始
96 | # 注意这里只是处理了组装成一个句子 最后封装对话、配对在make_sft_dataset
97 | # 遇到图片 连接 直接封装成一个句子
98 | for i, row in chat_df.iterrows():
99 | if row['type_name'] in skip_list:
100 | if last_content != '':
101 | if last_content[-1] == ',':
102 | last_content = last_content[:-1] + '。'
103 | elif last_content[-1] not in ['。', '!', '?', '…', '.']:
104 | last_content += '。'
105 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
106 | last_CreateTime = row['CreateTime']
107 | last_content = ''
108 | # cut表示被skip字段截断
109 | res_df.append({'is_sender': row['is_sender'], 'content': 'cut', 'CreateTime': row['CreateTime']})
110 | continue
111 | if last_content == '': # 重新开始
112 | last_content = row['content']
113 | last_is_sender = row['is_sender']
114 | last_CreateTime = row['CreateTime']
115 | continue
116 | if row['is_sender'] == last_is_sender:
117 | if row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'):
118 | # 如果超时 前面的添加到res_df 并重新开始
119 | if last_content[-1] == ',':
120 | last_content = last_content[:-1] + '。'
121 | elif last_content[-1] not in ['。', '!', '?', '…', '.']:
122 | last_content += '。'
123 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
124 | last_content = row['content']
125 | last_CreateTime = row['CreateTime']
126 | continue
127 | # 如果content的结尾没有标点符号则添加逗号,最后结尾是句号
128 | if last_content[-1] not in ['。', '!', '?', '…', ',']:
129 | last_content += ','
130 | last_content = last_content + row['content']
131 | last_CreateTime = row['CreateTime']
132 | else:
133 | if last_content[-1] == ',':
134 | last_content = last_content[:-1] + '。'
135 | elif last_content[-1] not in ['。', '!', '?', '…', '.']:
136 | last_content += '。'
137 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime})
138 | last_is_sender = row['is_sender']
139 | last_content = row['content']
140 | last_CreateTime = row['CreateTime']
141 | res_df = pd.DataFrame(res_df)
142 | return res_df
143 |
144 |
145 | def make_sft_dataset():
146 |
147 | # [
148 | # {
149 | # "instruction": "用户指令(必填)",
150 | # "input": "用户输入(选填)",
151 | # "output": "模型回答(必填)",
152 | # "system": "系统提示词(选填)",
153 | # "history": [
154 | # ["第一轮指令(选填)", "第一轮回答(选填)"],
155 | # ["第二轮指令(选填)", "第二轮回答(选填)"]
156 | # ]
157 | # }
158 | # ]
159 |
160 | csv_concat = []
161 | csv_res = []
162 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件
163 | for chat_obj_folder in os.listdir(csv_folder):
164 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
165 | for csvfile in os.listdir(chat_obj_folder_path):
166 | if not csvfile.endswith('.csv'):
167 | continue
168 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
169 | chat_df = handle_sft_csv(csvfile_path)
170 | csv_concat.append(chat_df)
171 |
172 | csv_concat = pd.concat(csv_concat)
173 | # csv_res里is_sender必须是01 01 01 的顺序 csv_concat里不一定是01 01
174 | # 相差超过1小时的时间戳分为不同的对话
175 | # temp_res为一个长度为2的队列
176 | temp_res = deque(maxlen=2)
177 | # 6种情况
178 | # temp_res 为空 遇到 0入队 遇到1不处理 遇到cut不处理
179 | # temp_res 有0 遇到0清空队列再入队 遇到1相差超过1小时清空队列 没有相差一小时入队再全部出队 遇到cut清空队列
180 |
181 | for i, row in csv_concat.iterrows():
182 | if len(temp_res) == 0:
183 | if row['content'] == 'cut':
184 | continue
185 | if row['is_sender'] == 0:
186 | temp_res.append(row['content'])
187 | last_CreateTime = row['CreateTime']
188 | else:
189 | continue
190 | elif len(temp_res) == 1:
191 | if row['content'] == 'cut':
192 | temp_res.clear()
193 | last_CreateTime = row['CreateTime']
194 | elif row['is_sender'] == 0:
195 | # 遇到0 清空队列再入队
196 | temp_res.clear()
197 | temp_res.append(row['content'])
198 | last_CreateTime = row['CreateTime']
199 | else:
200 | if row['CreateTime'] - last_CreateTime > pd.Timedelta('1h'):
201 | # 相差超过1小时清空队列
202 | temp_res.clear()
203 | last_CreateTime = row['CreateTime']
204 | else:
205 | # 没有相差一小时入队再全部出队
206 | temp_res.append(row['content'])
207 | csv_res.append({'instruction': temp_res[0], 'output': temp_res[1]})
208 | temp_res.clear()
209 | last_CreateTime = row['CreateTime']
210 |
211 |
212 | csv_res_df = pd.DataFrame(csv_res)
213 | print(f'处理后数据量:{csv_res_df.shape[0]}')
214 | csv_res_df.to_json('./data/res_csv/sft/sft-my.json', orient='records', force_ascii=False)
215 |
216 |
217 | if __name__ == '__main__':
218 | # make_pt_dataset()
219 | make_sft_dataset()
220 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "WeClone"
3 | version = "0.1.2"
4 | description = ""
5 | authors = [{ name = "xming521" }]
6 | readme = "README.md"
7 | requires-python = ">=3.9,<3.10"
8 | dependencies = ["pandas", "pydantic==2.10.6"]
9 |
10 | [dependency-groups]
11 | xcodec = ["xcodec2==0.1.3"]
12 | wx = ["pywxdump"]
13 | sparktts = [
14 | "einops>=0.8.1",
15 | "einx>=0.3.0",
16 | "numpy==1.26.4",
17 | "omegaconf>=2.3.0",
18 | "packaging>=24.2",
19 | "safetensors>=0.5.2",
20 | "soundfile>=0.12.1",
21 | "soxr>=0.5.0.post1",
22 | "torch>=2.5.1",
23 | "torchaudio>=2.5.1",
24 | "tqdm>=4.66.5",
25 | "transformers==4.45.2",
26 | ]
27 | main = ["transformers==4.45.2", "llamafactory>=0.9.2", "openai==0.28.0"]
28 |
29 | [tool.uv]
30 | conflicts = [
31 | [
32 | { group = "wx" },
33 | { group = "sparktts" },
34 | ],
35 | [
36 | { group = "wx" },
37 | { group = "main" },
38 | ],
39 | [
40 | { group = "wx" },
41 | { group = "xcodec" },
42 | ],
43 | ]
44 |
45 | [[tool.uv.index]]
46 | url = "https://pypi.tuna.tsinghua.edu.cn/simple/"
47 | default = true
48 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # LLaMA-Factory
2 | llmtuner
3 | # wechat
4 | itchat-uos==1.5.0.dev0
5 | # others
6 | pandas
7 | # chromadb
8 | # langchain
9 | openai==0.28
10 |
11 |
12 |
--------------------------------------------------------------------------------
/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_pt_args": {
3 | "stage": "pt",
4 | "dataset": "wechat-pt",
5 | "dataset_dir": "./data/res_csv/pt",
6 | "lora_target": "query_key_value",
7 | "lora_rank": 2,
8 | "lora_dropout": 0.1,
9 | "output_dir": "model_output",
10 | "overwrite_cache": true,
11 | "per_device_train_batch_size": 1,
12 | "gradient_accumulation_steps": 1,
13 | "lr_scheduler_type": "cosine",
14 | "logging_steps": 10,
15 | "save_steps": 1000,
16 | "learning_rate": 0.001,
17 | "num_train_epochs": 30,
18 | "plot_loss": true,
19 | "fp16": true
20 | },
21 | "train_sft_args": {
22 | "stage": "sft",
23 | "dataset": "wechat-sft",
24 | "dataset_dir": "./data/res_csv/sft",
25 | "lora_target": "query_key_value",
26 | "lora_rank": 4,
27 | "lora_dropout": 0.5,
28 | "overwrite_cache": true,
29 | "per_device_train_batch_size": 4,
30 | "gradient_accumulation_steps": 8,
31 | "lr_scheduler_type": "cosine",
32 | "logging_steps": 10,
33 | "save_steps": 150,
34 | "learning_rate": 0.0001,
35 | "num_train_epochs": 3,
36 | "plot_loss": true,
37 | "fp16": true
38 | },
39 | "infer_args": {
40 | "repetition_penalty": 1.2,
41 | "temperature": 0.5,
42 | "max_length": 50,
43 | "top_p": 0.65
44 | },
45 | "common_args": {
46 | "model_name_or_path": "./chatglm3-6b",
47 | "adapter_name_or_path": "./model_output",
48 | "template": "chatglm3-weclone",
49 | "finetuning_type": "lora",
50 | "trust_remote_code": true
51 | },
52 | "_comment": "adapter_name_or_path同时做为train_sft_args的output_dir "
53 | }
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/src/__init__.py
--------------------------------------------------------------------------------
/src/api_service.py:
--------------------------------------------------------------------------------
1 | import os
2 | import uvicorn
3 | from llamafactory.chat import ChatModel
4 | from llamafactory.api.app import create_app
5 | from template import template_register
6 | from utils.config import load_config
7 |
8 | config = load_config('api_service')
9 |
10 |
11 | template_register()
12 |
13 |
14 | def main():
15 | chat_model = ChatModel(config)
16 | app = create_app(chat_model)
17 | print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8005)))
18 | uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8005)), workers=1)
19 |
20 |
21 | if __name__ == "__main__":
22 | main()
23 |
--------------------------------------------------------------------------------
/src/cli_demo.py:
--------------------------------------------------------------------------------
1 | from llamafactory.chat import ChatModel
2 | from llamafactory.extras.misc import torch_gc
3 |
4 |
5 | try:
6 | import platform
7 |
8 | if platform.system() != "Windows":
9 | import readline # noqa: F401
10 | except ImportError:
11 | print("Install `readline` for a better experience.")
12 |
13 |
14 | def main():
15 | chat_model = ChatModel()
16 | messages = []
17 | print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
18 |
19 | while True:
20 | try:
21 | query = input("\nUser: ")
22 | except UnicodeDecodeError:
23 | print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.")
24 | continue
25 | except Exception:
26 | raise
27 |
28 | if query.strip() == "exit":
29 | break
30 |
31 | if query.strip() == "clear":
32 | messages = []
33 | torch_gc()
34 | print("History has been removed.")
35 | continue
36 |
37 | messages.append({"role": "user", "content": query})
38 | print("Assistant: ", end="", flush=True)
39 |
40 | response = ""
41 | for new_text in chat_model.stream_chat(messages):
42 | print(new_text, end="", flush=True)
43 | response += new_text
44 | print()
45 | messages.append({"role": "assistant", "content": response})
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/src/evaluate.py:
--------------------------------------------------------------------------------
1 | from llamafactory.eval.evaluator import Evaluator
2 |
3 |
4 | def main():
5 | evaluator = Evaluator()
6 | evaluator.eval()
7 |
8 |
9 | if __name__ == "__main__":
10 | main()
11 |
--------------------------------------------------------------------------------
/src/export_model.py:
--------------------------------------------------------------------------------
1 | from llamafactory.train.tuner import export_model
2 |
3 |
4 | def main():
5 | export_model()
6 |
7 |
8 | if __name__ == "__main__":
9 | main()
10 |
--------------------------------------------------------------------------------
/src/template.py:
--------------------------------------------------------------------------------
1 | from llamafactory.data.formatter import FunctionFormatter, StringFormatter, ToolFormatter, EmptyFormatter
2 | from llamafactory.data.template import register_template
3 |
4 | default_prompt = "请你扮演一名人类,不要说自己是人工智能"
5 |
6 |
7 | def template_register():
8 | register_template(
9 | name="chatglm3-weclone",
10 | default_system=(
11 | default_prompt
12 | ),
13 | format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
14 | format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
15 | format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
16 | format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"),
17 | format_observation=StringFormatter(
18 | slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
19 | ),
20 | format_tools=ToolFormatter(tool_format="glm4"),
21 | format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
22 | stop_words=["<|user|>", "<|observation|>"],
23 | efficient_eos=True,
24 | )
25 |
--------------------------------------------------------------------------------
/src/test_model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import sys
5 | import openai
6 | sys.path.append(os.getcwd())
7 | import tqdm
8 | from src.template import default_prompt
9 | from tqdm import tqdm
10 |
11 |
12 | config = {
13 | 'default_prompt': default_prompt,
14 | 'model': 'gpt-3.5-turbo',
15 | 'history_len': 15,
16 | }
17 |
18 | config = type('Config', (object,), config)()
19 |
20 | openai.api_key = '''sk-test'''
21 | openai.api_base = "http://127.0.0.1:8000/v1"
22 |
23 |
24 | def handler_text(content: str, history: [], config):
25 |
26 | messages = [{"role": "system", "content": f'{config.default_prompt}'}]
27 | for item in history:
28 | messages.append(item)
29 | messages.append({"role": "user", "content": content})
30 | history.append({"role": "user", "content": content})
31 | try:
32 | response = openai.ChatCompletion.create(model=config.model,
33 | messages=messages,
34 | max_tokens=50)
35 | except openai.APIError as e:
36 | history.pop()
37 | return 'AI接口出错,请重试\n' + str(e)
38 |
39 | resp = str(response.choices[0].message.content)
40 | resp = resp.replace('\n ', '')
41 | history.append({"role": "assistant", "content": resp})
42 | return resp
43 |
44 |
45 | def main():
46 | test_list = json.loads(open('data/test_data.json').read())['questions']
47 | res = []
48 | for questions in tqdm(test_list, desc=' Testing...'):
49 | history = []
50 | for q in questions:
51 | answer = handler_text(q, history=history, config=config)
52 | res.append(history)
53 |
54 | res_file = open('test_result-my.txt', 'w')
55 | for r in res:
56 | for i in r:
57 | res_file.write(i['content'] + '\n')
58 | res_file.write('\n')
59 |
60 |
61 | if __name__ == '__main__':
62 | main()
63 |
--------------------------------------------------------------------------------
/src/train_pt.py:
--------------------------------------------------------------------------------
1 | from llamafactory.train.tuner import run_exp
2 | from utils.config import load_config
3 |
4 | config = load_config('train_pt')
5 | run_exp(config)
6 |
--------------------------------------------------------------------------------
/src/train_sft.py:
--------------------------------------------------------------------------------
1 | from llamafactory.train.tuner import run_exp
2 | from template import template_register
3 | from utils.config import load_config
4 |
5 | config = load_config(arg_type='train_sft')
6 |
7 | template_register()
8 |
9 | run_exp(config)
10 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/config.py:
--------------------------------------------------------------------------------
1 | import json
2 | from .utils import dict_to_argv
3 | import sys
4 |
5 |
6 | def load_config(arg_type: str):
7 | with open('./settings.json', 'r') as f:
8 | config: dict = json.load(f)
9 | if arg_type == 'web_demo' or arg_type == 'api_service':
10 | # infer_args和common_args求并集
11 | config = {**config['infer_args'], **config['common_args']}
12 | elif arg_type == 'train_pt':
13 | config = {**config['train_pt_args'], **config['common_args']}
14 | elif arg_type == 'train_sft':
15 | config = {**config['train_sft_args'], **config['common_args']}
16 | else:
17 | raise ValueError('暂不支持的类型')
18 |
19 | if 'train' in arg_type:
20 | config['output_dir'] = config['adapter_name_or_path']
21 | config.pop('adapter_name_or_path')
22 | config['do_train'] = True
23 |
24 | sys.argv += dict_to_argv(config)
25 |
26 | return config
27 |
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | def dict_to_argv(d):
2 | argv = []
3 | for k, v in d.items():
4 | argv.append('--' + k)
5 | if v is not None:
6 | argv.append(str(v))
7 | return argv
8 |
9 |
10 |
--------------------------------------------------------------------------------
/src/web_demo.py:
--------------------------------------------------------------------------------
1 | from llamafactory.webui.interface import create_web_demo
2 | from template import template_register
3 | from utils.config import load_config
4 |
5 | config = load_config('web_demo')
6 |
7 | template_register()
8 |
9 | def main():
10 | demo = create_web_demo()
11 | demo.queue()
12 | demo.launch(server_name="0.0.0.0", share=True, inbrowser=True)
13 |
14 |
15 | if __name__ == "__main__":
16 | main()
17 |
--------------------------------------------------------------------------------
/src/wechat_bot/handler/text.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import openai
4 |
5 |
6 | log = logging.getLogger('text')
7 |
8 |
9 | def handler_text(content: str, history: [], config):
10 | # todo 收到/clear清理历史记录
11 | # try:
12 | # history.clear()
13 | # return '清理完毕!'
14 | # except KeyError:
15 | # return '不存在消息记录,无需清理'
16 |
17 | messages = [{"role": "system", "content": f'{config.default_prompt}'}]
18 | for item in history:
19 | messages.append(item)
20 | messages.append({"role": "user", "content": content})
21 | history.append({"role": "user", "content": content})
22 | try:
23 | response = openai.ChatCompletion.create(model=config.model,
24 | messages=messages,
25 | max_tokens=50)
26 | except openai.APIError as e:
27 | log.error(e)
28 | history.pop()
29 | return 'AI接口出错,请重试\n' + str(e)
30 |
31 | resp = str(response.choices[0].message.content)
32 | resp = resp.replace('\n ', '')
33 | history.append({"role": "assistant", "content": resp})
34 | return resp
35 |
--------------------------------------------------------------------------------
/src/wechat_bot/main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import signal
4 | import sys
5 | import time
6 | import xml.etree.ElementTree as ET
7 | sys.path.append(os.getcwd())
8 | import openai
9 | import requests
10 |
11 | import itchat
12 | from handler.text import handler_text
13 | from itchat import utils
14 | from itchat.content import *
15 | from src.template import default_prompt
16 | import logging
17 |
18 | # logging.basicConfig(level=logging.INFO)
19 | log = logging.getLogger('main')
20 |
21 | config = {
22 | 'default_prompt': default_prompt,
23 | 'model': 'gpt-3.5-turbo',
24 | 'history_len': 15,
25 | }
26 |
27 |
28 | config = type('Config', (object,), config)()
29 |
30 |
31 | def stop_program(signal, frame):
32 | log.info('WeChatbot Closing Save some data')
33 | itchat.dump_login_status()
34 | sys.exit(0)
35 |
36 |
37 | signal.signal(signal.SIGTERM, stop_program)
38 |
39 |
40 | class WeChatGPT:
41 |
42 | def __init__(self):
43 | itchat.auto_login(enableCmdQR=2, hotReload=True, statusStorageDir='./cookie.bin')
44 |
45 | self.history = {}
46 | self.prompts = {}
47 | openai.api_key = '''sk-test'''
48 | openai.api_base = "http://127.0.0.1:8000/v1"
49 |
50 | log.info("init successful!")
51 |
52 | def handler_history(self, msg):
53 | self.history.setdefault(msg.user.userName, [])
54 | history = self.history[msg.user.userName]
55 | need_remove_len = len(history) - config.history_len
56 | if need_remove_len > 0:
57 | for i in range(need_remove_len):
58 | # 必须出一对
59 | history.pop(0)
60 | history.pop(0)
61 | return history
62 |
63 | def reply(self, msg):
64 | if time.time() - msg.CreateTime > 5:
65 | return None
66 | res = handler_text(content=msg.text, history=self.handler_history(msg), config=config)
67 | res = res.split(',')
68 | res[-1] = res[-1].replace('。', '')
69 | if res[0] == '':
70 | res[0] = '机器人他无语了'
71 | for r in res:
72 | msg.user.send(r)
73 | time.sleep(2.2)
74 |
75 | def run(self):
76 | @itchat.msg_register(FRIENDS)
77 | def add_friend(msg):
78 | """自动同意好友"""
79 | root = ET.fromstring(msg.content)
80 | ticket = root.get('ticket')
81 | # itchat.accept_friend(msg.user.userName, ticket)
82 |
83 | @itchat.msg_register(TEXT)
84 | def friend(msg):
85 | """处理私聊消息"""
86 | log.info(f"{msg.user.NickName}: {msg.text}")
87 | self.reply(msg)
88 |
89 | @itchat.msg_register(TEXT, isGroupChat=True)
90 | def groups(msg):
91 | """处理群聊消息"""
92 | if msg.isAt:
93 | self.reply(msg)
94 |
95 | itchat.run(debug=True)
96 |
97 |
98 | if __name__ == "__main__":
99 |
100 | try:
101 | weChatGPT = WeChatGPT()
102 | weChatGPT.run()
103 | except KeyboardInterrupt:
104 | log.info("bye!")
105 |
--------------------------------------------------------------------------------