├── .gitignore ├── README.md ├── WeClone-audio ├── README.md └── src │ ├── Llasa │ ├── infer.py │ └── text_to_speech.py │ ├── SparkTTS.py │ ├── __init__.py │ ├── get_sample_audio.py │ ├── infer.py │ ├── sample.wav │ └── server未完工 │ ├── .env.example │ ├── handle_text.py │ ├── requirements.txt │ ├── server.py │ ├── tts_handler.py │ └── utils.py ├── data ├── example_chat.csv ├── res_csv │ ├── pt │ │ └── dataset_info.json │ └── sft │ │ ├── dataset_info-with-his.json │ │ └── dataset_info.json └── test_data.json ├── ds_config.json ├── img ├── 1.png ├── 2.png ├── 3.png ├── 4.jpg └── 5.png ├── make_dataset ├── blocked_words.json ├── csv_to_json-单句回答.py ├── csv_to_json-单句多轮.py └── csv_to_json.py ├── pyproject.toml ├── requirements.txt ├── settings.json └── src ├── __init__.py ├── api_service.py ├── cli_demo.py ├── evaluate.py ├── export_model.py ├── template.py ├── test_model.py ├── train_pt.py ├── train_sft.py ├── utils ├── __init__.py ├── config.py └── utils.py ├── web_demo.py └── wechat_bot ├── handler └── text.py └── main.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | weclone_archive-my/ 3 | **/pycache/ 4 | events.out.tfevents.* 5 | 归档/ 6 | *.pt 7 | *.npz 8 | *nohup.out 9 | *log.txt 10 | *cookie.bin 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | 143 | data/csv.zip 144 | LLaMA-Factory 145 | chatglm3-6b 146 | cache 147 | archive 148 | model_output* 149 | data/test 150 | .vscode 151 | *-my.* 152 | *.csv 153 | test.* 154 | *users.json 155 | WeClone-audio/src/output*.wav 156 | WeClone-audio/uv.lock 157 | Spark-TTS-0.5B/ 158 | uv.lock 159 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![download](https://github.com/user-attachments/assets/5842e84e-004f-4afd-9373-af64e9575b78) 2 | 3 | ## 核心功能✨ 4 | - 💬 使用微信聊天记录微调LLM 5 | - 🎙️ 使用微信语音消息➕0.5B大模型实现高质量声音克隆 👉[WeClone-audio](https://github.com/xming521/WeClone/tree/master/WeClone-audio) 6 | - 🔗 绑定到微信、QQ、Telegram、企微、飞书机器人,实现自己的数字分身 7 | 8 | ## 特性与说明📋 9 | 10 | > [!TIP] 11 | > 新特性:[WeClone-audio](https://github.com/xming521/WeClone/tree/master/WeClone-audio) 模块,支持对微信语音进行克隆。 12 | 13 | 14 | > [!IMPORTANT] 15 | > 微调LLM最终效果很大程度取决于聊天数据的数量和质量 16 | 17 | ### 硬件要求 18 | 19 | 目前项目默认使用chatglm3-6b模型,LoRA方法对sft阶段微调,大约需要16GB显存。也可以使用[LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md#%E6%A8%A1%E5%9E%8B)支持的其他模型和方法,占用显存更少,需要自行修改模板的system提示词等相关配置。 20 | 21 | 需要显存的估算值: 22 | | 方法 | 精度 | 7B | 14B | 30B | 70B | `x`B | 23 | | ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- | 24 | | Full (`bf16` or `fp16`) | 32 | 
120GB | 240GB | 600GB | 1200GB | `18x`GB | 25 | | Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB | 26 | | Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB | 27 | | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB | 28 | | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB | 29 | | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB | 30 | 31 | 32 | ### 环境搭建 33 | 建议使用 [uv](https://docs.astral.sh/uv/),这是一个非常快速的 Python 环境管理器。安装uv后,您可以使用以下命令创建一个新的Python环境并安装依赖项,注意这不包含xcodec(音频克隆)功能的依赖: 34 | ```bash 35 | git clone https://github.com/xming521/WeClone.git 36 | cd WeClone 37 | uv venv .venv --python=3.9 38 | source .venv/bin/activate 39 | uv pip install --group main -e . 40 | ``` 41 | 42 | > [!NOTE] 43 | > 训练以及推理相关配置统一在文件[settings.json](settings.json) 44 | 45 | 46 | ### 数据准备 47 | 48 | 请使用[PyWxDump](https://github.com/xaoyaoo/PyWxDump)提取微信聊天记录。下载软件并解密数据库后,点击聊天备份,导出类型为CSV,可以导出多个联系人或群聊,然后将导出的位于`wxdump_tmp/export` 的 `csv` 文件夹放在`./data`目录即可,也就是不同人聊天记录的文件夹一起放在 `./data/csv`。 示例数据位于[data/example_chat.csv](data/example_chat.csv)。 49 | 50 | ### 数据预处理 51 | 52 | 项目默认去除了数据中的手机号、身份证号、邮箱、网址。还提供了一个禁用词词库[blocked_words](make_dataset/blocked_words.json),可以自行添加需要过滤的词句(会默认去掉包括禁用词的整句)。 53 | 执行 `./make_dataset/csv_to_json.py` 脚本对数据进行处理。 54 | 55 | 在同一人连续回答多句的情况下,有三种处理方式: 56 | | 文件 | 处理方式 | 57 | | --- | --- | 58 | | csv_to_json.py | 用逗号连接 | 59 | | csv_to_json-单句回答.py(已废弃) | 只选择最长的回答作为最终数据 | 60 | | csv_to_json-单句多轮.py | 放在了提示词的'history'中 | 61 | 62 | ### 模型下载 63 | 64 | 首选在Hugging Face下载[ChatGLM3](https://huggingface.co/THUDM/chatglm3-6b) 模型。如果您在 Hugging Face 模型的下载中遇到了问题,可以通过下述方法使用魔搭社区,后续训练推理都需要先执行`export USE_MODELSCOPE_HUB=1`来使用魔搭社区的模型。 65 | 由于模型较大,下载过程比较漫长请耐心等待。 66 | 67 | ```bash 68 | export USE_MODELSCOPE_HUB=1 # Windows 使用 `set USE_MODELSCOPE_HUB=1` 69 | git lfs install 70 | git clone https://www.modelscope.cn/ZhipuAI/chatglm3-6b.git 71 | ``` 72 | 魔搭社区的`modeling_chatglm.py`文件需要更换为Hugging Face的 73 | 74 | ### 配置参数并微调模型 75 | 76 | - (可选)修改 
[settings.json](settings.json)选择本地下载好的其他模型。 77 | 78 | - 修改`per_device_train_batch_size`以及`gradient_accumulation_steps`来调整显存占用。 79 | - 可以根据自己数据集的数量和质量修改`num_train_epochs`、`lora_rank`、`lora_dropout`等参数。 80 | 81 | #### 单卡训练 82 | 83 | 运行 `src/train_sft.py` 进行sft阶段微调,本人loss只降到了3.5左右,降低过多可能会过拟合,我使用了大概2万条整合后的有效数据。 84 | 85 | ```bash 86 | python src/train_sft.py 87 | ``` 88 | 89 | #### 多卡训练 90 | 91 | ```bash 92 | uv pip install deepspeed 93 | deepspeed --num_gpus=使用显卡数量 src/train_sft.py 94 | ``` 95 | 96 | 97 | ### 使用浏览器demo简单推理 98 | 99 | ```bash 100 | python ./src/web_demo.py 101 | ``` 102 | 103 | ### 使用接口进行推理 104 | 105 | ```bash 106 | python ./src/api_service.py 107 | ``` 108 | 109 | ### 使用常见聊天问题测试 110 | 111 | ```bash 112 | python ./src/api_service.py 113 | python ./src/test_model.py 114 | ``` 115 | 116 | ### 部署到聊天机器人 117 | 118 | #### AstrBot方案 119 | [AstrBot](https://github.com/AstrBotDevs/AstrBot) 是易上手的多平台 LLM 聊天机器人及开发框架 ✨ 平台支持 QQ、QQ频道、Telegram、微信、企微、飞书。 120 | 121 | 使用步骤: 122 | 1. 部署 AstrBot 123 | 2. 在 AstrBot 中部署消息平台 124 | 3. 执行 `python ./src/api_service.py ` 启动api服务 125 | 4. 在 AstrBot 中新增服务提供商,类型选择OpenAI,API Base URL 根据AstrBot部署方式填写(例如docker部署可能为http://172.17.0.1:8005/v1) ,模型填写gpt-3.5-turbo 126 | 5. 微调后不支持工具调用,请先关掉默认的工具,消息平台发送指令: `/tool off reminder`,否则会没有微调后的效果。 127 | 6. 根据微调时使用的default_system,在 AstrBot 中设置系统提示词。 128 | ![alt text](img/5.png) 129 | 130 | 131 | 132 | 133 |
134 | itchat方案(已弃用) 135 | 136 | > [!IMPORTANT] 137 | > 微信有封号风险,建议使用小号,并且必须绑定银行卡才能使用 138 | 139 | ```bash 140 | python ./src/api_service.py # 先启动api服务 141 | python ./src/wechat_bot/main.py 142 | ``` 143 | 144 | 默认在终端显示二维码,扫码登录即可。可以私聊或者在群聊中@机器人使用。 145 |
146 | 147 | ### 截图 148 | 149 | ![alt text](img/4.jpg) 150 | ![alt text](img/1.png) 151 | ![alt text](img/2.png) 152 | ![alt text](img/3.png) 153 | 154 | 155 | 156 | # 免责声明 157 | > [!CAUTION] 158 | > 请勿用于非法用途,否则后果自负。 159 |
160 | 1. 使用目的 161 | 162 | * 本项目仅供学习交流使用,**请勿用于非法用途**,**请勿用于非法用途**,**请勿用于非法用途**,否则后果自负。 163 | * 用户理解并同意,任何违反法律法规、侵犯他人合法权益的行为,均与本项目及其开发者无关,后果由用户自行承担。 164 | 165 | 2. 使用期限 166 | 167 | * 您应该在下载保存使用本项目的24小时内,删除本项目的源代码和程序;超出此期限的任何使用行为,一概与本项目及其开发者无关。 168 | 169 | 3. 操作规范 170 | 171 | * 本项目仅允许在授权情况下使用数据训练,严禁用于非法目的,否则自行承担所有相关责任;用户如因违反此规定而引发的任何法律责任,将由用户自行承担,与本项目及其开发者无关。 172 | * 严禁用于窃取他人隐私,严禁用于窃取他人隐私,严禁用于窃取他人隐私,否则自行承担所有相关责任。 173 | 174 | 4. 免责声明接受 175 | 176 | * 下载、保存、进一步浏览源代码或者下载安装、编译使用本程序,表示你同意本警告,并承诺遵守它; 177 | 178 | 5. 禁止用于非法测试或渗透 179 | 180 | * 禁止利用本项目的相关技术从事非法测试或渗透,禁止利用本项目的相关代码或相关技术从事任何非法工作,如因此产生的一切不良后果与本项目及其开发者无关。 181 | * 任何因此产生的不良后果,包括但不限于数据泄露、系统瘫痪、侵犯隐私等,均与本项目及其开发者无关,责任由用户自行承担。 182 | 183 | 6. 免责声明修改 184 | 185 | * 本免责声明可能根据项目运行情况和法律法规的变化进行修改和调整。用户应定期查阅本页面以获取最新版本的免责声明,使用本项目时应遵守最新版本的免责声明。 186 | 187 | 7. 其他 188 | 189 | * 除本免责声明规定外,用户在使用本项目过程中应遵守相关的法律法规和道德规范。对于因用户违反相关规定而引发的任何纠纷或损失,本项目及其开发者不承担任何责任。 190 | 191 | * 请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。 192 | 193 |
194 | 请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。 195 | 196 |
197 |
198 |
199 | 200 | ## ⭐ Star History 201 | > [!TIP] 202 | > 如果本项目对您有帮助,或者您关注本项目的未来发展,请给项目 Star,谢谢 203 | 204 |
205 | 206 | [![Star History Chart](https://api.star-history.com/svg?repos=xming521/WeClone&type=Date)](https://www.star-history.com/#xming521/WeClone&Date) 207 | 208 |
209 | 210 | 211 |
克隆我们,保留那灵魂的芬芳
212 | -------------------------------------------------------------------------------- /WeClone-audio/README.md: -------------------------------------------------------------------------------- 1 | # WeClone-audio 模块 2 | 3 | WeClone-audio 是一个使用微信语音消息克隆声音的模块,使用模型实现高质量语音合成。 4 | ### 显存需求 5 | **Spark-TTS** 推荐 6 | - **0.5B 模型**: 约 4GB 显存 7 | 8 | **Llasa** 9 | - **3B 模型**: 约 16GB 显存 10 | - **1B 模型**: 约 9GB 显存 11 | 12 | 13 | 14 | 15 | ## 1. 导出微信语音数据 16 | 17 | ### 1.1 准备工作 18 | - 使用 [PyWxDump](https://github.com/xaoyaoo/PyWxDump) 提取微信聊天记录 19 | - 下载软件并解密数据库 20 | - 点击聊天备份,导出类型选择"解密文件" 21 | 22 | ### 1.2 环境配置 23 | 语音导出仅支持Windows环境 24 | WeClone Audio使用uv作为包管理器。 25 | ```bash 26 | # 为 PyWxDump 创建 Python 环境和安装依赖 27 | # 28 | uv venv .venv-wx --python=3.9 29 | source .venv-wx/bin/activate 30 | # 安装 wx 依赖组 31 | uv pip install --group wx -e . 32 | ``` 33 | 34 | ### 1.3 导出语音文件 35 | ```bash 36 | python ./WeClone-audio/get_sample_audio.py --db-path "导出数据库路径" --MsgSvrID "导出聊天记录的MsgSvrID字段" 37 | ``` 38 | 39 | ## 2. 语音合成推理 40 | ### Spark-TTS模型 41 | 42 | **环境安装** 43 | 可不创建新环境,直接安装依赖组到WeClone共主环境 44 | 45 | ```bash 46 | uv venv .venv-sparktts --python=3.9 47 | source .venv-sparktts/bin/activate 48 | uv pip install --group sparktts -e . 
49 | 50 | cd WeClone-audio/src 51 | git clone https://github.com/SparkAudio/Spark-TTS.git 52 | ``` 53 | 54 | **模型下载** 55 | 56 | 通过python下载: 57 | ```python 58 | from huggingface_hub import snapshot_download 59 | 60 | snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B") 61 | ``` 62 | 63 | 或通过git下载: 64 | ```sh 65 | cd WeClone-audio 66 | mkdir -p pretrained_models 67 | 68 | # Make sure you have git-lfs installed (https://git-lfs.com) 69 | git lfs install 70 | git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B 71 | ``` 72 | 使用代码推理 73 | ```python 74 | import os 75 | import SparkTTS 76 | import soundfile as sf 77 | import torch 78 | 79 | from SparkTTS import SparkTTS 80 | 81 | model = SparkTTS("WeClone-audio/pretrained_models/Spark-TTS-0.5B", "cuda") 82 | 83 | 84 | with torch.no_grad(): 85 | wav = model.inference( 86 | text="晚上好啊,小可爱们,该睡觉了哦", 87 | prompt_speech_path=os.path.join(os.path.dirname(__file__), "sample.wav"), 88 | prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。", 89 | ) 90 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), wav, samplerate=16000) 91 | ``` 92 | ### Llasa模型 93 | ### 2.1 环境配置 94 | ```bash 95 | # 创建并配置推理环境 96 | ## 可不创建新环境,与LLaMA-Factory环境共用 97 | uv venv .venv-xcodec --python=3.9 98 | source .venv-xcodec/bin/activate 99 | uv pip install --group xcodec -e . 
class TextToSpeech:
    """Voice cloning TTS: Llasa-3B speech-token language model + XCodec2 codec.

    The reference (prompt) audio is encoded once in ``__init__``; ``infer`` then
    asks the LM to continue the prompt's speech tokens so the synthesized text
    comes out in the same voice.
    """

    def __init__(self, sample_audio_path, sample_audio_text):
        """
        Args:
            sample_audio_path: path to the reference audio file (WAV/MP3 recommended).
            sample_audio_text: transcript of the reference audio; prepended to the
                target text at inference time.
        """
        self.sample_audio_text = sample_audio_text
        # Initialize models (Chinese comment in original: 初始化模型)
        llasa_3b = "HKUSTAudio/Llasa-3B"
        xcodec2 = "HKUSTAudio/xcodec2"

        self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b)
        self.llasa_3b_model = AutoModelForCausalLM.from_pretrained(
            llasa_3b,
            trust_remote_code=True,
            device_map="auto",
        )
        self.llasa_3b_model.eval()

        self.xcodec_model = XCodec2Model.from_pretrained(xcodec2)
        self.xcodec_model.eval().cuda()

        # Load and pre-process the prompt audio (Chinese comment in original: 处理音频)
        waveform, sample_rate = torchaudio.load(sample_audio_path)
        # Keep at most the first 15 s of the prompt audio
        if len(waveform[0]) / sample_rate > 15:
            print("已将音频裁剪至前15秒。")  # "audio truncated to first 15 seconds"
            waveform = waveform[:, : sample_rate * 15]

        # Down-mix stereo to mono (Chinese comment in original: 检查音频是否为立体声)
        if waveform.size(0) > 1:
            waveform_mono = torch.mean(waveform, dim=0, keepdim=True)
        else:
            waveform_mono = waveform

        # Resample to 16 kHz, the rate the codec works at (see the 16000 return below)
        self.prompt_wav = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=16000
        )(waveform_mono)

        # Encode the prompt wav into speech-token strings used as the generation prefix
        vq_code_prompt = self.xcodec_model.encode_code(input_waveform=self.prompt_wav)
        vq_code_prompt = vq_code_prompt[0, 0, :]
        self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt)
        # Token id that terminates speech generation (used as eos below)
        self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>")

    def ids_to_speech_tokens(self, speech_ids):
        """Format numeric codec ids as ``<|s_N|>`` token strings."""
        speech_tokens_str = []
        for speech_id in speech_ids:
            speech_tokens_str.append(f"<|s_{speech_id}|>")
        return speech_tokens_str

    def extract_speech_ids(self, speech_tokens_str):
        """Parse ``<|s_N|>`` token strings back to ints; log and skip anything else."""
        speech_ids = []
        for token_str in speech_tokens_str:
            if token_str.startswith("<|s_") and token_str.endswith("|>"):
                num_str = token_str[4:-2]
                num = int(num_str)
                speech_ids.append(num)
            else:
                print(f"Unexpected token: {token_str}")
        return speech_ids

    @torch.inference_mode()
    def infer(self, target_text):
        """Synthesize ``target_text`` in the prompt speaker's voice.

        Args:
            target_text: text to speak; silently truncated to 300 characters.

        Returns:
            ``(16000, numpy waveform)`` tuple, or ``None`` for empty input.
            NOTE(review): callers index the result immediately (``result[1]``),
            so an empty ``target_text`` would raise there — confirm intended.
        """
        if len(target_text) == 0:
            return None
        elif len(target_text) > 300:
            print("文本过长,请保持在300字符以内。")  # "text too long, keep within 300 chars"
            target_text = target_text[:300]

        # Prompt transcript + target text form the text-understanding input
        input_text = self.sample_audio_text + " " + target_text

        formatted_text = (
            f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>"
        )

        chat = [
            {
                "role": "user",
                "content": "Convert the text to speech:" + formatted_text,
            },
            {
                # Assistant turn is pre-filled with the prompt's speech tokens so
                # the model continues them (continue_final_message=True below).
                "role": "assistant",
                "content": "<|SPEECH_GENERATION_START|>"
                + "".join(self.speech_ids_prefix),
            },
        ]

        input_ids = self.tokenizer.apply_chat_template(
            chat, tokenize=True, return_tensors="pt", continue_final_message=True
        )
        input_ids = input_ids.to("cuda")

        outputs = self.llasa_3b_model.generate(
            input_ids,
            max_length=2048,
            eos_token_id=self.speech_end_id,  # stop at <|SPEECH_GENERATION_END|>
            do_sample=True,
            top_p=1,
            temperature=0.8,
        )
        # Slice keeps prompt speech tokens + newly generated ones, drops trailing EOS
        generated_ids = outputs[0][input_ids.shape[1] - len(self.speech_ids_prefix): -1]

        speech_tokens = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )

        speech_tokens = self.extract_speech_ids(speech_tokens)
        speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0)

        # Decode codec ids to waveform, then cut off the prompt-audio portion
        gen_wav = self.xcodec_model.decode_code(speech_tokens)
        gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:]

        return (16000, gen_wav[0, 0, :].cpu().numpy())
class SparkTTS:
    """
    Spark-TTS for text-to-speech generation.

    Wraps a causal LM (speech-token generation) and a BiCodec audio tokenizer
    (audio <-> discrete token conversion) to support both voice cloning from a
    prompt recording and controllable voice creation (gender/pitch/speed).
    """

    def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")):
        """
        Initializes the SparkTTS model with the provided configurations and device.

        Args:
            model_dir (Path): Directory containing the model and config files.
            device (torch.device): The device (CPU/GPU) to run the model on.
        """
        self.device = device
        self.model_dir = model_dir
        # Model config carries the output sample rate used by the codec
        self.configs = load_config(f"{model_dir}/config.yaml")
        self.sample_rate = self.configs["sample_rate"]
        self._initialize_inference()

    def _initialize_inference(self):
        """Initializes the tokenizer, model, and audio tokenizer for inference."""
        # The causal LM lives in the "LLM" subdirectory of the model dir
        self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM")
        self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM")
        self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device)
        self.model.to(self.device)

    def process_prompt(
        self,
        text: str,
        prompt_speech_path: Path,
        prompt_text: str = None,
    ) -> Tuple[str, torch.Tensor]:
        """
        Process input for voice cloning.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path): Path to the audio file used as a prompt.
            prompt_text (str, optional): Transcript of the prompt audio.

        Return:
            Tuple[str, torch.Tensor]: Input prompt; global tokens
        """

        # Tokenize the prompt audio into global (speaker) and semantic (content) ids
        global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize(
            prompt_speech_path
        )
        global_tokens = "".join(
            [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()]
        )

        # Prepare the input tokens for the model
        if prompt_text is not None:
            # With a transcript, also feed the prompt's semantic tokens so the LM
            # continues them (stronger cloning signal).
            semantic_tokens = "".join(
                [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()]
            )
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                prompt_text,
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
                "<|start_semantic_token|>",
                semantic_tokens,
            ]
        else:
            inputs = [
                TASK_TOKEN_MAP["tts"],
                "<|start_content|>",
                text,
                "<|end_content|>",
                "<|start_global_token|>",
                global_tokens,
                "<|end_global_token|>",
            ]

        inputs = "".join(inputs)

        return inputs, global_token_ids

    def process_prompt_control(
        self,
        gender: str,
        pitch: str,
        speed: str,
        text: str,
    ):
        """
        Process input for voice creation.

        Args:
            gender (str): female | male.
            pitch (str): very_low | low | moderate | high | very_high
            speed (str): very_low | low | moderate | high | very_high
            text (str): The text input to be converted to speech.

        Return:
            str: Input prompt
        """
        assert gender in GENDER_MAP.keys()
        assert pitch in LEVELS_MAP.keys()
        assert speed in LEVELS_MAP.keys()

        gender_id = GENDER_MAP[gender]
        pitch_level_id = LEVELS_MAP[pitch]
        speed_level_id = LEVELS_MAP[speed]

        pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
        speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
        gender_tokens = f"<|gender_{gender_id}|>"

        # NOTE: "attribte" (sic) — local name only, preserved as-is
        attribte_tokens = "".join(
            [gender_tokens, pitch_label_tokens, speed_label_tokens]
        )

        control_tts_inputs = [
            TASK_TOKEN_MAP["controllable_tts"],
            "<|start_content|>",
            text,
            "<|end_content|>",
            "<|start_style_label|>",
            attribte_tokens,
            "<|end_style_label|>",
        ]

        return "".join(control_tts_inputs)

    @torch.no_grad()
    def inference(
        self,
        text: str,
        prompt_speech_path: Path = None,
        prompt_text: str = None,
        gender: str = None,
        pitch: str = None,
        speed: str = None,
        temperature: float = 0.8,
        top_k: float = 50,
        top_p: float = 0.95,
    ) -> torch.Tensor:
        """
        Performs inference to generate speech from text, incorporating prompt audio and/or text.

        Args:
            text (str): The text input to be converted to speech.
            prompt_speech_path (Path): Path to the audio file used as a prompt.
                NOTE(review): required when ``gender`` is None — confirm callers.
            prompt_text (str, optional): Transcript of the prompt audio.
            gender (str): female | male.
            pitch (str): very_low | low | moderate | high | very_high
            speed (str): very_low | low | moderate | high | very_high
            temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8.
            top_k (float, optional): Top-k sampling parameter. Default is 50.
            top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95.

        Returns:
            torch.Tensor: Generated waveform as a tensor.
        """
        # Controllable-TTS path when a gender is given; cloning path otherwise
        if gender is not None:
            prompt = self.process_prompt_control(gender, pitch, speed, text)

        else:
            prompt, global_token_ids = self.process_prompt(
                text, prompt_speech_path, prompt_text
            )
        model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device)

        # Generate speech using the model
        generated_ids = self.model.generate(
            **model_inputs,
            max_new_tokens=3000,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
        )

        # Trim the output tokens to remove the input tokens
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        # Decode the generated tokens into text
        predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Extract semantic token IDs from the generated text
        pred_semantic_ids = (
            torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)])
            .long()
            .unsqueeze(0)
        )

        if gender is not None:
            # Controllable path: the model invents the speaker, so global tokens
            # come from the generated text instead of the prompt audio.
            global_token_ids = (
                torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)])
                .long()
                .unsqueeze(0)
                .unsqueeze(0)
            )

        # Convert semantic tokens back to waveform
        wav = self.audio_tokenizer.detokenize(
            global_token_ids.to(self.device).squeeze(0),
            pred_semantic_ids.to(self.device),
        )

        return wav
def main():
    """CLI entry point: extract one WeChat voice message as a WAV sample file.

    Reads a decrypted WeChat media database and saves the message identified by
    ``--MsgSvrID`` to ``--save-path`` (defaults to ``sample.wav`` next to this
    script) at the requested sample rate.

    Raises:
        SystemExit: if required CLI arguments are missing (argparse behavior).
    """
    parser = argparse.ArgumentParser(description='Extract audio from WeChat database')
    parser.add_argument('--db-path', type=str, required=True,
                        help='Path to WeChat database file')
    parser.add_argument('--MsgSvrID', type=str, required=True,
                        help='Message server ID of the audio')
    parser.add_argument('--save-path', type=str,
                        default=os.path.join(os.path.dirname(__file__), 'sample.wav'),
                        help='Path to save the audio file (default: sample.wav in script directory)')
    parser.add_argument('--rate', type=int, default=24000,
                        help='Sample rate for audio conversion (default: 24000)')

    args = parser.parse_args()

    # pywxdump MediaHandler connection config for the decrypted sqlite media DB
    config = {
        "key": "test1",
        "type": "sqlite",
        "path": args.db_path,
    }

    handler = MediaHandler(config)
    handler.get_audio(
        # BUG FIX: argparse stores "--MsgSvrID" as args.MsgSvrID (only dashes
        # become underscores); the previous args.msg_id raised AttributeError.
        MsgSvrID=args.MsgSvrID,
        is_play=True,
        is_wave=True,
        save_path=args.save_path,
        rate=args.rate,
    )
def prepare_tts_input_with_context(text: str) -> str:
    """
    Prepares text for a TTS API by cleaning Markdown and adding minimal contextual hints
    for certain Markdown elements like headers. Preserves paragraph separation.

    Args:
        text (str): The raw text containing Markdown or other formatting.

    Returns:
        str: Cleaned text with contextual hints suitable for TTS input.
    """

    # Remove emojis (third-party `emoji` package, imported at module level)
    text = emoji.replace_emoji(text, replace='')

    # Add context for headers: "# X" -> "Title — X", "## X" -> "Section — X", etc.
    def header_replacer(match):
        level = len(match.group(1))  # Number of '#' symbols
        header_text = match.group(2).strip()
        if level == 1:
            return f"Title — {header_text}\n"
        elif level == 2:
            return f"Section — {header_text}\n"
        else:
            return f"Subsection — {header_text}\n"

    text = re.sub(r"^(#{1,6})\s+(.*)", header_replacer, text, flags=re.MULTILINE)

    # Announce links (currently commented out for potential future use)
    # text = re.sub(r"\[([^\]]+)\]\((https?:\/\/[^\)]+)\)", r"\1 (link: \2)", text)

    # Remove links while keeping the link text
    text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text)

    # Describe inline code
    text = re.sub(r"`([^`]+)`", r"code snippet: \1", text)

    # Remove bold/italic symbols but keep the content
    text = re.sub(r"(\*\*|__|\*|_)", '', text)

    # Remove code blocks (multi-line) with a description
    text = re.sub(r"```([\s\S]+?)```", r"(code block omitted)", text)

    # Remove image syntax but add alt text if available
    text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", r"Image: \1", text)

    # Remove HTML tags.
    # BUG FIX: the pattern previously read r"]+(>|$)", which deletes from any
    # run of ']' characters to end of line and never matches an actual tag.
    # Reconstructed as a standard tag matcher (open bracket, non-'>' body,
    # closing '>' or end of input for a truncated tag) — confirm against the
    # upstream edge-tts server this file was adapted from.
    text = re.sub(r"<[^>]+(>|$)", '', text)

    # Normalize line breaks
    text = re.sub(r"\n{2,}", '\n\n', text)  # Ensure consistent paragraph separation

    # Replace multiple spaces within lines
    text = re.sub(r" {2,}", ' ', text)

    # Trim leading and trailing whitespace from the whole text
    text = text.strip()

    return text
# server.py
#
# OpenAI-compatible TTS HTTP server (work in progress — directory name
# "server未完工" means "server unfinished"). Exposes /v1/audio/speech plus
# model/voice listing endpoints; speech is produced by tts_handler.generate_speech.

from flask import Flask, request, send_file, jsonify
from gevent.pywsgi import WSGIServer
from dotenv import load_dotenv
import os

from handle_text import prepare_tts_input_with_context
from tts_handler import generate_speech, get_models, get_voices
from utils import getenv_bool, require_api_key, AUDIO_FORMAT_MIME_TYPES

app = Flask(__name__)
load_dotenv()

# Configuration via environment (.env) with defaults; see .env.example
API_KEY = os.getenv('API_KEY', 'your_api_key_here')
PORT = int(os.getenv('PORT', 5050))

DEFAULT_VOICE = os.getenv('DEFAULT_VOICE', 'en-US-AvaNeural')
DEFAULT_RESPONSE_FORMAT = os.getenv('DEFAULT_RESPONSE_FORMAT', 'mp3')
DEFAULT_SPEED = float(os.getenv('DEFAULT_SPEED', 1.0))

# REMOVE_FILTER=True skips Markdown/emoji cleanup of the input text
REMOVE_FILTER = getenv_bool('REMOVE_FILTER', False)
# EXPAND_API=True enables the beta ElevenLabs/Azure-compatible endpoints
EXPAND_API = getenv_bool('EXPAND_API', True)

# DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'tts-1')

@app.route('/v1/audio/speech', methods=['POST'])
@app.route('/audio/speech', methods=['POST'])  # Add this line for the alias
@require_api_key
def text_to_speech():
    """OpenAI-style speech endpoint: JSON body with 'input' (required),
    optional 'voice', 'response_format', 'speed'; returns the audio file."""
    data = request.json
    if not data or 'input' not in data:
        return jsonify({"error": "Missing 'input' in request body"}), 400

    text = data.get('input')

    # Clean Markdown/emoji unless the operator disabled filtering
    if not REMOVE_FILTER:
        text = prepare_tts_input_with_context(text)

    # model = data.get('model', DEFAULT_MODEL)
    voice = data.get('voice', DEFAULT_VOICE)

    response_format = data.get('response_format', DEFAULT_RESPONSE_FORMAT)
    speed = float(data.get('speed', DEFAULT_SPEED))

    # Fall back to audio/mpeg for unknown formats
    mime_type = AUDIO_FORMAT_MIME_TYPES.get(response_format, "audio/mpeg")

    # Generate the audio file in the specified format with speed adjustment
    output_file_path = generate_speech(text, voice, response_format, speed)

    # Return the file with the correct MIME type
    return send_file(output_file_path, mimetype=mime_type, as_attachment=True, download_name=f"speech.{response_format}")

@app.route('/v1/models', methods=['GET', 'POST'])
@app.route('/models', methods=['GET', 'POST'])
@require_api_key
def list_models():
    """List available TTS models (delegates to tts_handler.get_models)."""
    return jsonify({"data": get_models()})

@app.route('/v1/voices', methods=['GET', 'POST'])
@app.route('/voices', methods=['GET', 'POST'])
@require_api_key
def list_voices():
    """List voices, optionally filtered by 'language'/'locale' from query string
    (GET) or JSON body (POST)."""
    specific_language = None

    data = request.args if request.method == 'GET' else request.json
    if data and ('language' in data or 'locale' in data):
        specific_language = data.get('language') if 'language' in data else data.get('locale')

    return jsonify({"voices": get_voices(specific_language)})

@app.route('/v1/voices/all', methods=['GET', 'POST'])
@app.route('/voices/all', methods=['GET', 'POST'])
@require_api_key
def list_all_voices():
    """List every voice regardless of language."""
    return jsonify({"voices": get_voices('all')})
voice_id # ElevenLabs uses the voice_id in the URL 105 | 106 | # Use default settings for edge-tts 107 | response_format = 'mp3' 108 | speed = DEFAULT_SPEED # Optional customization via payload.get('speed', DEFAULT_SPEED) 109 | 110 | # Generate speech using edge-tts 111 | try: 112 | output_file_path = generate_speech(text, voice, response_format, speed) 113 | except Exception as e: 114 | return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500 115 | 116 | # Return the generated audio file 117 | return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3") 118 | 119 | # tts.speech.microsoft.com/cognitiveservices/v1 120 | # https://{region}.tts.speech.microsoft.com/cognitiveservices/v1 121 | # http://localhost:5050/azure/cognitiveservices/v1 122 | @app.route('/azure/cognitiveservices/v1', methods=['POST']) 123 | @require_api_key 124 | def azure_tts(): 125 | if not EXPAND_API: 126 | return jsonify({"error": f"Endpoint not allowed"}), 500 127 | 128 | # Parse the SSML payload 129 | try: 130 | ssml_data = request.data.decode('utf-8') 131 | if not ssml_data: 132 | return jsonify({"error": "Missing SSML payload"}), 400 133 | 134 | # Extract the text and voice from SSML 135 | from xml.etree import ElementTree as ET 136 | root = ET.fromstring(ssml_data) 137 | text = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').text 138 | voice = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').get('name') 139 | except Exception as e: 140 | return jsonify({"error": f"Invalid SSML payload: {str(e)}"}), 400 141 | 142 | # Use default settings for edge-tts 143 | response_format = 'mp3' 144 | speed = DEFAULT_SPEED 145 | 146 | if not REMOVE_FILTER: 147 | text = prepare_tts_input_with_context(text) 148 | 149 | # Generate speech using edge-tts 150 | try: 151 | output_file_path = generate_speech(text, voice, response_format, speed) 152 | except Exception as e: 153 | return jsonify({"error": f"TTS generation failed: 
{str(e)}"}), 500 154 | 155 | # Return the generated audio file 156 | return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3") 157 | 158 | print(f" Edge TTS (Free Azure TTS) Replacement for OpenAI's TTS API") 159 | print(f" ") 160 | print(f" * Serving OpenAI Edge TTS") 161 | print(f" * Server running on http://localhost:{PORT}") 162 | print(f" * TTS Endpoint: http://localhost:{PORT}/v1/audio/speech") 163 | print(f" ") 164 | 165 | if __name__ == '__main__': 166 | http_server = WSGIServer(('0.0.0.0', PORT), app) 167 | http_server.serve_forever() 168 | -------------------------------------------------------------------------------- /WeClone-audio/src/server未完工/tts_handler.py: -------------------------------------------------------------------------------- 1 | import edge_tts 2 | import asyncio 3 | import tempfile 4 | import subprocess 5 | import os 6 | from pathlib import Path 7 | 8 | # Language default (environment variable) 9 | DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'en-US') 10 | 11 | # OpenAI voice names mapped to edge-tts equivalents 12 | voice_mapping = { 13 | 'alloy': 'en-US-AvaNeural', 14 | 'echo': 'en-US-AndrewNeural', 15 | 'fable': 'en-GB-SoniaNeural', 16 | 'onyx': 'en-US-EricNeural', 17 | 'nova': 'en-US-SteffanNeural', 18 | 'shimmer': 'en-US-EmmaNeural' 19 | } 20 | 21 | def is_ffmpeg_installed(): 22 | """Check if FFmpeg is installed and accessible.""" 23 | try: 24 | subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 25 | return True 26 | except (subprocess.CalledProcessError, FileNotFoundError): 27 | return False 28 | 29 | async def _generate_audio(text, voice, response_format, speed): 30 | """Generate TTS audio and optionally convert to a different format.""" 31 | # Determine if the voice is an OpenAI-compatible voice or a direct edge-tts voice 32 | edge_tts_voice = voice_mapping.get(voice, voice) # Use mapping if in OpenAI names, otherwise use as-is 33 
async def _generate_audio(text, voice, response_format, speed):
    """Generate TTS audio with edge-tts, converting via FFmpeg if needed.

    Args:
        text: plain text to synthesize.
        voice: OpenAI-style alias (mapped via voice_mapping) or an edge-tts voice name.
        response_format: target container ('mp3', 'aac', 'wav', 'opus', 'flac').
        speed: multiplicative speed in [0, 2].

    Returns:
        Path (str) of the generated temporary audio file.

    Raises:
        RuntimeError: if FFmpeg fails during conversion.
    """
    # Use mapping if the name is an OpenAI alias, otherwise pass through as-is.
    edge_tts_voice = voice_mapping.get(voice, voice)

    # Always synthesize mp3 first; convert afterwards if another format is asked for.
    temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
    # Close the handle right away so edge-tts can open the path itself
    # (keeping it open breaks on Windows file locking).
    temp_output_file.close()

    # Convert speed to the SSML rate format ("+X%" / "-X%").
    try:
        speed_rate = speed_to_rate(speed)
    except Exception as e:
        print(f"Error converting speed: {e}. Defaulting to +0%.")
        speed_rate = "+0%"

    # Generate the MP3 file.
    communicator = edge_tts.Communicate(text=text, voice=edge_tts_voice, rate=speed_rate)
    await communicator.save(temp_output_file.name)

    # If the requested format is mp3, return the generated file directly.
    if response_format == "mp3":
        return temp_output_file.name

    if not is_ffmpeg_installed():
        print("FFmpeg is not available. Returning unmodified mp3 file.")
        return temp_output_file.name

    # New temporary file for the converted output.
    converted_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{response_format}")
    converted_output_file.close()

    codec = {
        "aac": "aac",
        "mp3": "libmp3lame",
        "wav": "pcm_s16le",
        "opus": "libopus",
        "flac": "flac"
    }.get(response_format, "aac")  # default to AAC if unknown
    container = {
        "aac": "mp4",  # AAC in MP4 container
        "mp3": "mp3",
        "wav": "wav",
        "opus": "ogg",
        "flac": "flac"
    }.get(response_format, response_format)

    ffmpeg_command = ["ffmpeg", "-i", temp_output_file.name, "-c:a", codec]
    if response_format != "wav":
        # Bitrate is meaningless for uncompressed PCM.  BUG FIX: the original
        # inserted None into argv for wav, making subprocess.run raise TypeError.
        ffmpeg_command += ["-b:a", "192k"]
    ffmpeg_command += ["-f", container, "-y", converted_output_file.name]

    try:
        subprocess.run(ffmpeg_command, check=True,
                       stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"FFmpeg error during audio conversion: {e}") from e

    # Clean up the intermediate mp3.
    Path(temp_output_file.name).unlink(missing_ok=True)

    return converted_output_file.name


def generate_speech(text, voice, response_format, speed=1.0):
    """Synchronous wrapper around _generate_audio."""
    return asyncio.run(_generate_audio(text, voice, response_format, speed))


def get_models():
    """Advertised OpenAI-compatible model ids."""
    return [
        {"id": "tts-1", "name": "Text-to-speech v1"},
        {"id": "tts-1-hd", "name": "Text-to-speech v1 HD"}
    ]


async def _get_voices(language=None):
    """List edge-tts voices, filtered to one locale unless language == 'all'."""
    all_voices = await edge_tts.list_voices()
    language = language or DEFAULT_LANGUAGE  # use default if no language specified
    return [
        {"name": v['ShortName'], "gender": v['Gender'], "language": v['Locale']}
        for v in all_voices
        if language == 'all' or language is None or v['Locale'] == language
    ]


def get_voices(language=None):
    """Synchronous wrapper around _get_voices."""
    return asyncio.run(_get_voices(language))


def speed_to_rate(speed: float) -> str:
    """
    Converts a multiplicative speed value to the edge-tts "rate" format.

    Args:
        speed (float): The multiplicative speed value (e.g., 1.5 for +50%, 0.5 for -50%).

    Returns:
        str: The formatted "rate" string (e.g., "+50%" or "-50%").

    Raises:
        ValueError: if speed is outside [0, 2].
    """
    if speed < 0 or speed > 2:
        raise ValueError("Speed must be between 0 and 2 (inclusive).")

    percentage_change = (speed - 1) * 100
    return f"{percentage_change:+.0f}%"
# utils.py
#
# Shared helpers for the TTS server: env parsing, API-key enforcement,
# and audio-format MIME type mapping.

from flask import request, jsonify
from functools import wraps
import hmac
import os
from dotenv import load_dotenv

load_dotenv()


def getenv_bool(name: str, default: bool = False) -> bool:
    """Read an environment variable as a boolean ('yes'/'y'/'true'/'1'/'t' => True)."""
    return os.getenv(name, str(default)).lower() in ("yes", "y", "true", "1", "t")


API_KEY = os.getenv('API_KEY', 'your_api_key_here')
REQUIRE_API_KEY = getenv_bool('REQUIRE_API_KEY', True)


def require_api_key(f):
    """Decorator enforcing 'Authorization: Bearer <API_KEY>' unless disabled.

    Returns a 401 JSON error when the header is missing, malformed, or the
    token does not match API_KEY.
    """
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if not REQUIRE_API_KEY:
            return f(*args, **kwargs)
        auth_header = request.headers.get('Authorization')
        if not auth_header or not auth_header.startswith('Bearer '):
            return jsonify({"error": "Missing or invalid API key"}), 401
        # Slice off the prefix instead of split('Bearer ')[1], which breaks
        # if the token itself contains the substring 'Bearer '.
        token = auth_header[len('Bearer '):]
        # Constant-time comparison: '!=' on an untrusted token leaks key
        # bytes through timing.
        if not hmac.compare_digest(token, API_KEY):
            return jsonify({"error": "Invalid API key"}), 401
        return f(*args, **kwargs)
    return decorated_function


# Mapping of audio format to MIME type.
AUDIO_FORMAT_MIME_TYPES = {
    "mp3": "audio/mpeg",
    "opus": "audio/ogg",
    "aac": "audio/aac",
    "flac": "audio/flac",
    "wav": "audio/wav",
    "pcm": "audio/L16"
}
1,953020244103908134,系统通知,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""你已添加了白桃乌龙,现在可以开始聊天了。""}",2023-08-12 23:22:41 3 | 2,7594861486645126963,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哦吼""}",2023-08-12 23:22:54 4 | 3,5795621731176683438,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""咱们老家""}",2023-08-12 23:23:02 5 | 4,3470072877112832166,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""离得近嘞""}",2023-08-12 23:23:05 6 | 5,4123958315588848926,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://wxapp.tc.qq.com/262/20304/stodownload?m=8f29c306e125f7514246f4f82887cc29&filekey=30350201010421301f020201060402534804108f29c306e125f7514246f4f82887cc290203008884040d00000004627466730000000131&hy=SH&storeid=32303232303531373136313532383030306137303837333237346335383237653661386430393030303030313036&bizid=1023"", ""msg"": ""表情""}",2023-08-12 23:23:26 7 | 6,1893107158254293225,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""是耶""}",2023-08-12 23:23:31 8 | 7,5967747793341844139,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""挺好""}",2023-08-12 23:23:37 9 | 8,365944262262687504,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哈哈哈""}",2023-08-12 23:23:41 10 | 9,4210929420150253626,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""困嘞困嘞""}",2023-08-12 23:24:26 11 | 10,3243868033522068109,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""俺先睡了""}",2023-08-12 23:24:30 12 | 11,7049787380472735099,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""明天再聊""}",2023-08-12 23:24:33 13 | 12,3092986777523361373,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""好der""}",2023-08-12 23:24:47 14 | 13,8642915817867986857,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://mmbiz.qpic.cn/mmemoticon/ajNVdqHZLLACNZkcWvHBNURJvCPe6PQ3BDf8deqLjHGricqHbCbV0iagpHmCat5Mo3/0"", ""msg"": ""表情""}",2023-08-12 23:24:50 15 | 14,422457693648526604,动画表情,1,我,xinglaifaxianzhe,"{""src"": 
""http://wxapp.tc.qq.com/262/20304/stodownload?m=65a36156989a0bd09bd3d8b8fee3cb2f&filekey=30350201010421301f020201060402535a041065a36156989a0bd09bd3d8b8fee3cb2f0203008069040d00000004627466730000000132&hy=SZ&storeid=26438315e000ec0f96c8595840000010600004f50535a0b492a91e683f6a79&bizid=1023"", ""msg"": ""表情""}",2023-08-12 23:25:04 16 | 15,2311761750664470975,图片,1,我,xinglaifaxianzhe,"{""src"": ""FileStorage\\MsgAttach\\ec71e3b5cc65b1713665080d0f1939e5\\Image\\2023-08\\82332bbba85f6bb3cf85823037f3a99e.dat"", ""msg"": ""图片""}",2023-08-13 14:01:41 17 | 16,4293887349744541268,图片,1,我,xinglaifaxianzhe,"{""src"": ""FileStorage\\MsgAttach\\ec71e3b5cc65b1713665080d0f1939e5\\Image\\2023-08\\9c2e52835ddf872089b6fff3497d6f44.dat"", ""msg"": ""图片""}",2023-08-13 14:01:43 18 | 17,5053822429327172633,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""今天天气不错嘞""}",2023-08-13 14:01:46 19 | 18,2705637163783232848,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""就是有点hot""}",2023-08-13 14:02:07 20 | 19,9106315879092141637,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://wxapp.tc.qq.com/275/20304/stodownload?m=39ec76de39cf2f41e0f19fae9752f2e1&filekey=30340201010420301e0202011304025348041039ec76de39cf2f41e0f19fae9752f2e102024a31040d00000004627466730000000132&hy=SH&storeid=264795d250002b74e000000000000011300004f50534801fe5b01e648f3633&bizid=1023"", ""msg"": ""表情""}",2023-08-13 14:05:59 21 | 20,4919372022511043672,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哪天不hot""}",2023-08-13 14:06:04 22 | 21,7801804847925160451,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""下雨不hot""}",2023-08-13 14:10:41 23 | 22,4186719283759563990,图片,1,我,xinglaifaxianzhe,"{""src"": ""FileStorage\\MsgAttach\\ec71e3b5cc65b1713665080d0f1939e5\\Image\\2023-08\\08b1106d78b4dfba0755ff2aa09d2438.dat"", ""msg"": ""图片""}",2023-08-13 14:11:06 24 | 23,3521100079205966470,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""好像空格键""}",2023-08-13 14:11:07 25 | 
24,1617187845718151808,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""很想按下去""}",2023-08-13 14:11:18 26 | 25,8345209503359387241,动画表情,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": ""http://wxapp.tc.qq.com/262/20304/stodownload?m=004f3c0cf5259ab214d9dabacc337c60&filekey=30350201010421301f02020106040253480410004f3c0cf5259ab214d9dabacc337c600203059dc6040d00000004627466730000000131&hy=SH&storeid=32303232303831393136353732363030306462663664303030303030303030376464613030623030303030313036&bizid=1023"", ""msg"": ""表情""}",2023-08-13 14:13:06 27 | 26,8573385017731442447,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""不行你干脆回家学习""}",2023-08-13 14:13:12 28 | 27,7076070838462900480,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哈哈哈""}",2023-08-13 14:13:13 29 | 28,5591518902965619475,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""这是要去哪玩""}",2023-08-13 14:13:34 30 | 29,8086920793435315850,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""来看个电影""}",2023-08-13 14:13:51 31 | 30,1814934139736232801,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""看什么!""}",2023-08-13 14:14:49 32 | 31,3677613846620059746,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""emmm""}",2023-08-13 14:14:56 33 | 32,3409906469387839600,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""我也在纠结要不要出门看个电影""}",2023-08-13 14:15:03 34 | 33,5709723437718492173,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""好像是孤注一掷""}",2023-08-13 14:15:07 35 | 34,7775547367225903484,文本,1,我,xinglaifaxianzhe,"{""src"": """", ""msg"": ""但是感觉像反诈电影""}",2023-08-13 14:15:13 36 | 35,6258879459298954092,文本,0,xinglaifaxianzhe,xinglaifaxianzhe,"{""src"": """", ""msg"": ""哈哈哈好像是""}",2023-08-13 14:15:14 -------------------------------------------------------------------------------- /data/res_csv/pt/dataset_info.json: -------------------------------------------------------------------------------- 1 | {"wechat-pt":{ 2 | "file_name": 
"./pt-my.json", 3 | "columns": { 4 | "prompt": "c" 5 | } 6 | }} -------------------------------------------------------------------------------- /data/res_csv/sft/dataset_info-with-his.json: -------------------------------------------------------------------------------- 1 | { 2 | "wechat-sft": { 3 | "file_name": "./sft-my.json", 4 | "columns": { 5 | "prompt": "instruction", 6 | "response": "output", 7 | "history": "history" 8 | } 9 | } 10 | } -------------------------------------------------------------------------------- /data/res_csv/sft/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "wechat-sft": { 3 | "file_name": "./sft-my.json", 4 | "columns": { 5 | "prompt": "instruction", 6 | "response": "output" 7 | } 8 | } 9 | } 10 | 11 | 12 | -------------------------------------------------------------------------------- /data/test_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "questions": [ 3 | [ 4 | "吃了吗?", 5 | "吃的什么啊", 6 | "好吃吗", 7 | "多少钱啊", 8 | "可以请我吃吗" 9 | ], 10 | [ 11 | "你多大了?" 12 | ], 13 | [ 14 | "你有什么爱好吗?" 15 | ], 16 | [ 17 | "你的理想是什么?", 18 | "你觉得你离你的理想还有多远?" 19 | ], 20 | [ 21 | "你最近在忙什么?", 22 | "工作/学习顺利吗?", 23 | "有什么有趣的事情发生吗?" 24 | ], 25 | [ 26 | "你喜欢看什么类型的电影?", 27 | "最近看过什么好看的电影吗?", 28 | "你最喜欢的电影是什么?" 29 | ], 30 | [ 31 | "你平时喜欢听什么音乐?", 32 | "有推荐的歌手或乐队吗?", 33 | "最近有喜欢的歌曲吗?" 34 | ], 35 | [ 36 | "你喜欢旅游吗?", 37 | "去过哪些地方?", 38 | "最喜欢的旅游地是哪里?" 39 | ], 40 | [ 41 | "你喜欢读书吗?", 42 | "最近在读什么书?", 43 | "最喜欢的书是哪本?" 44 | ], 45 | [ 46 | "你平时喜欢运动吗?", 47 | "喜欢做哪些运动?", 48 | "有固定去锻炼吗?" 49 | ], 50 | [ 51 | "周末一般都做些什么?", 52 | "有没有什么特别的计划?", 53 | "周末喜欢宅在家还是出去玩?" 54 | ], 55 | [ 56 | "你喜欢宠物吗?", 57 | "有养宠物吗?", 58 | "最喜欢什么动物?" 59 | ], 60 | [ 61 | "你喜欢吃什么类型的食物?", 62 | "有推荐的餐厅吗?", 63 | "最喜欢的菜是什么?" 64 | ], 65 | [ 66 | "你喜欢什么样的天气?", 67 | "最喜欢的季节是哪一个?", 68 | "你觉得今天的天气怎么样?" 69 | ], 70 | [ 71 | "你有看电视剧的习惯吗?", 72 | "最近在追哪部剧?", 73 | "最喜欢的电视剧是哪部?" 
74 | ], 75 | [ 76 | "你喜欢玩游戏吗?", 77 | "最近在玩什么游戏?", 78 | "有推荐的好玩的游戏吗?" 79 | ], 80 | [ 81 | "你会做饭吗?", 82 | "平时喜欢做哪些菜?", 83 | "有没有特别拿手的菜?" 84 | ], 85 | [ 86 | "你喜欢购物吗?", 87 | "最近买了什么新东西?", 88 | "有推荐的购物网站或店铺吗?" 89 | ], 90 | [ 91 | "你平时怎么放松自己?", 92 | "有特别的解压方式吗?", 93 | "最喜欢的放松活动是什么?" 94 | ], 95 | [ 96 | "你喜欢和朋友出去玩吗?", 97 | "平时会和朋友去哪玩?", 98 | "最近有没有和朋友聚会的计划?" 99 | ], 100 | [ 101 | "你喜欢喝咖啡还是茶?", 102 | "有没有特别喜欢的咖啡馆或茶馆?", 103 | "最喜欢的饮品是什么?" 104 | ], 105 | [ 106 | "你有兄弟姐妹吗?", 107 | "和他们关系怎么样?", 108 | "经常联系吗?" 109 | ], 110 | [ 111 | "你喜欢读什么类型的杂志?", 112 | "最近有看什么有趣的文章吗?", 113 | "有订阅的杂志吗?" 114 | ], 115 | [ 116 | "你喜欢看体育比赛吗?", 117 | "最喜欢的运动项目是什么?", 118 | "有没有特别支持的球队或运动员?" 119 | ], 120 | [ 121 | "你会说其他语言吗?", 122 | "最想学的语言是什么?", 123 | "学习语言有什么技巧吗?" 124 | ], 125 | [ 126 | "你对科技产品感兴趣吗?", 127 | "最近有没有关注什么新科技?", 128 | "最喜欢的电子产品是什么?" 129 | ], 130 | [ 131 | "你喜欢喝什么样的饮料?", 132 | "有没有自己调饮料的习惯?", 133 | "最喜欢的饮品品牌是什么?" 134 | ], 135 | [ 136 | "你平时用社交媒体吗?", 137 | "常用哪些平台?", 138 | "在社交媒体上做什么?" 139 | ], 140 | [ 141 | "你对艺术感兴趣吗?", 142 | "最喜欢的艺术家是谁?", 143 | "有去过哪些艺术展览?" 144 | ], 145 | [ 146 | "你喜欢DIY吗?", 147 | "平时做些什么手工?", 148 | "有没有完成的作品可以分享?" 149 | ], 150 | [ 151 | "你喜欢种植植物吗?", 152 | "有养什么植物?", 153 | "最喜欢的植物是什么?" 154 | ], 155 | [ 156 | "你喜欢拍照吗?", 157 | "喜欢拍什么样的照片?", 158 | "有没有用什么特别的摄影设备?" 159 | ], 160 | [ 161 | "你喜欢听播客吗?", 162 | "常听哪些主题的播客?", 163 | "有没有推荐的播客?" 164 | ], 165 | [ 166 | "你对历史感兴趣吗?", 167 | "最喜欢哪个历史时期?", 168 | "有没有特别喜欢的历史人物?" 169 | ], 170 | [ 171 | "你喜欢画画吗?", 172 | "平时画什么类型的画?", 173 | "有参加过画展吗?" 174 | ], 175 | [ 176 | "你喜欢写作吗?", 177 | "平时写什么类型的文章?", 178 | "有没有发表过作品?" 179 | ], 180 | [ 181 | "你喜欢钓鱼吗?", 182 | "平时去哪里钓鱼?", 183 | "有没有钓到过什么大鱼?" 184 | ], 185 | [ 186 | "你喜欢露营吗?", 187 | "平时会去哪里露营?", 188 | "有没有什么难忘的露营经历?" 189 | ], 190 | [ 191 | "你喜欢摄影吗?", 192 | "最喜欢拍什么题材?", 193 | "有没有特别喜欢的摄影师?" 194 | ], 195 | [ 196 | "你喜欢喝酒吗?", 197 | "喜欢什么类型的酒?", 198 | "有没有推荐的酒吧或品牌?" 199 | ], 200 | [ 201 | "你喜欢滑雪吗?", 202 | "平时去哪里滑雪?", 203 | "有没有什么滑雪技巧分享?" 
204 | ], 205 | [ 206 | "你喜欢海边还是山里?", 207 | "最喜欢去哪个地方度假?", 208 | "有没有什么特别推荐的景点?" 209 | ], 210 | [ 211 | "你喜欢参加音乐节吗?", 212 | "参加过哪些音乐节?", 213 | "最喜欢的音乐节是哪一个?" 214 | ], 215 | [ 216 | "你喜欢跑步吗?", 217 | "平时跑多长距离?", 218 | "有没有参加过马拉松?" 219 | ], 220 | [ 221 | "你喜欢参加聚会吗?", 222 | "平时和朋友聚会做什么?", 223 | "有没有什么有趣的聚会游戏?" 224 | ], 225 | [ 226 | "你喜欢收集东西吗?", 227 | "收集什么类型的物品?", 228 | "有没有什么特别的收藏?" 229 | ] 230 | ] 231 | } -------------------------------------------------------------------------------- /ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "allgather_partitions": true, 16 | "allgather_bucket_size": 5e8, 17 | "overlap_comm": true, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | }, 22 | "gradient_accumulation_steps": "auto", 23 | "gradient_clipping": "auto", 24 | "steps_per_print": 2000, 25 | "train_batch_size": "auto", 26 | "train_micro_batch_size_per_gpu": "auto", 27 | "wall_clock_breakdown": false 28 | } -------------------------------------------------------------------------------- /img/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/img/1.png -------------------------------------------------------------------------------- /img/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/img/2.png -------------------------------------------------------------------------------- /img/3.png: 
import csv
import json
import os
import re

import pandas as pd
from collections import deque

csv_folder = './data/csv'
# csv_folder = './data/test'
print(f'当前处理目录{csv_folder}')

# Regexes for private data that must never end up in the dataset:
# mainland-China mobile numbers, 18-digit ID numbers, e-mail addresses.
PRIVACY_PATTERNS = [
    re.compile(r'1\d{10}'),
    re.compile(r'\d{18}'),
    re.compile(r'\w+@\w+'),
]
# Literal substrings that mark unusable content (links, escaped byte junk).
BLOCKED_SUBSTRINGS = ['http', '\\xa0', '\\u']


def is_blocked_text(content: str) -> bool:
    """True if a text message contains private data or unusable content."""
    if any(p.search(content) for p in PRIVACY_PATTERNS):
        return True
    return any(s in content for s in BLOCKED_SUBSTRINGS)


def close_sentence(sentence: str) -> str:
    """Ensure an assembled sentence ends with terminal punctuation (trailing ',' -> '。')."""
    if sentence.endswith(','):
        return sentence[:-1] + '。'
    if sentence and sentence[-1] not in ['。', '!', '?', '…', '.']:
        return sentence + '。'
    return sentence


def handle_pt_csv(csvfile):
    """Load one chat CSV and return the cleaned texts we sent (pre-training data)."""
    chat_df = pd.read_csv(csvfile)
    # Keep only plain-text messages sent by ourselves.
    chat_df = chat_df[chat_df['type_name'] == '文本']
    chat_df = chat_df[chat_df['is_sender'] == 1]
    # content is a JSON blob; the actual text lives under 'msg'.
    chat_df['content'] = chat_df['content'].apply(lambda x: json.loads(x)['msg'])
    # Drop rows containing phone numbers, ID numbers, e-mails or links
    # (raw strings: the originals relied on invalid '\d' escapes).
    chat_df = chat_df[~chat_df['content'].str.contains(r'1\d{10}')]
    chat_df = chat_df[~chat_df['content'].str.contains(r'\d{18}')]
    chat_df = chat_df[~chat_df['content'].str.contains(r'\w+@\w+')]
    chat_df = chat_df[~chat_df['content'].str.contains('http')]
    chat_df = chat_df[~chat_df['content'].str.contains(r'\\xa0')]
    chat_df = chat_df[~chat_df['content'].str.contains(r'\\u')]

    # Pure content column only.
    chat_df = chat_df['content']
    chat_df = chat_df.dropna()

    return chat_df


def make_pt_dataset():
    """Walk csv_folder (one subfolder per chat partner) and dump a pre-training JSON."""
    csv_res = []
    for chat_obj_folder in os.listdir(csv_folder):
        chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
        for csvfile in os.listdir(chat_obj_folder_path):
            csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
            csv_res.append(handle_pt_csv(csvfile_path))

    csv_res = pd.concat(csv_res)
    csv_res = csv_res.apply(lambda x: {'c': x})  # dataset prompt key is 'c'

    csv_res.to_json('./data/res_csv/pt-my.json', orient='records', force_ascii=False)


def handle_sft_csv(csvfile):
    """Turn one chat CSV into merged (is_sender, content, CreateTime) sentences for SFT."""
    chat_df = pd.read_csv(csvfile)
    # 'with' closes the handle; the original json.load(open(...)) leaked it.
    with open('./make_dataset/blocked_words.json', encoding='utf-8') as f:
        blocked_words = json.load(f)['blocked_words']
    # Message types to keep; non-text types act as conversation separators.
    type_list = ['文本', '图片', '卡片式链接', '合并转发的聊天记录', '视频', '语言', '未知', '分享的小程序']
    chat_df = chat_df[chat_df['type_name'].isin(values=type_list)]

    # content is a JSON blob; the actual text lives under 'msg'.
    chat_df['content'] = chat_df['content'].apply(func=lambda x: json.loads(x)['msg'])

    # Drop text rows containing private data or user-configured blocked words.
    # BUG FIX: the original tested "'1\\d{10}' in content" — a literal substring
    # check of the pattern text itself, which never matched real phone numbers,
    # ID numbers or e-mails.  Replaced with a proper regex search.
    for i in chat_df.index:
        if chat_df.loc[i, 'type_name'] == '文本':
            content = chat_df.loc[i, 'content']
            if is_blocked_text(content):
                chat_df = chat_df.drop(index=i)
                continue
            for blocked_word in blocked_words:
                if blocked_word in content:
                    chat_df = chat_df.drop(index=i)
                    break
        else:
            # Non-text rows keep their type but carry no text.
            chat_df.loc[i, 'content'] = ''

    chat_df = chat_df[['is_sender', 'type_name', 'content', 'CreateTime']]
    chat_df = chat_df.dropna()

    # Time format: 2021-07-07 10:27:23
    chat_df['CreateTime'] = pd.to_datetime(chat_df['CreateTime'])
    skip_list = [t for t in type_list if t != '文本']
    res_df = []
    last_is_sender = chat_df.iloc[0]['is_sender']
    last_content: str = chat_df.iloc[0]['content']
    last_CreateTime = chat_df.iloc[0]['CreateTime']

    # Merge consecutive messages from the same sender into one sentence;
    # a non-text row emits a 'cut' marker and a >1h gap starts a new sentence.
    # Pairing into dialogues happens later in make_sft_dataset.
    for i, row in chat_df.iterrows():
        if row['type_name'] in skip_list:
            if last_content != '':
                res_df.append({'is_sender': last_is_sender,
                               'content': close_sentence(last_content),
                               'CreateTime': last_CreateTime})
                last_CreateTime = row['CreateTime']
                last_content = ''
            # 'cut' marks a conversation truncated by a skipped message type.
            res_df.append({'is_sender': row['is_sender'], 'content': 'cut',
                           'CreateTime': row['CreateTime']})
            continue
        if last_content == '':  # start a new sentence
            last_content = row['content']
            last_is_sender = row['is_sender']
            last_CreateTime = row['CreateTime']
            continue
        if row['is_sender'] == last_is_sender:
            if row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'):
                # Timed out: flush the current sentence and restart.
                res_df.append({'is_sender': last_is_sender,
                               'content': close_sentence(last_content),
                               'CreateTime': last_CreateTime})
                last_content = row['content']
                last_CreateTime = row['CreateTime']
                continue
            # Join continuations with a comma when punctuation is missing.
            if last_content[-1] not in ['。', '!', '?', '…', ',']:
                last_content += ','
            last_content = last_content + row['content']
            last_CreateTime = row['CreateTime']
        else:
            res_df.append({'is_sender': last_is_sender,
                           'content': close_sentence(last_content),
                           'CreateTime': last_CreateTime})
            last_is_sender = row['is_sender']
            last_content = row['content']
            last_CreateTime = row['CreateTime']
    return pd.DataFrame(res_df)


def make_sft_dataset():
    """Pair question (is_sender==0) / answer (is_sender==1) sentences into
    {'instruction': ..., 'output': ...} records and dump them as JSON.

    Target record shape (LLaMA-Factory alpaca format):
    [
      {
        "instruction": "用户指令(必填)",
        "input": "用户输入(选填)",
        "output": "模型回答(必填)",
        "system": "系统提示词(选填)",
        "history": [["第一轮指令", "第一轮回答"], ...]
      }
    ]
    """
    csv_concat = []
    csv_res = []
    for chat_obj_folder in os.listdir(csv_folder):
        chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder)
        for csvfile in os.listdir(chat_obj_folder_path):
            csvfile_path = os.path.join(chat_obj_folder_path, csvfile)
            csv_concat.append(handle_sft_csv(csvfile_path))

    csv_concat = pd.concat(csv_concat)

    # temp_res holds at most [question, answer]; 'cut' markers and >1h gaps
    # reset it so only adjacent 0/1 pairs become training records.
    temp_res = deque(maxlen=2)
    last_CreateTime = None

    for i, row in csv_concat.iterrows():
        if len(temp_res) == 0:
            if row['content'] == 'cut':
                continue
            if row['is_sender'] == 0:
                temp_res.append(row['content'])
                last_CreateTime = row['CreateTime']
        elif len(temp_res) == 1:
            if row['content'] == 'cut':
                temp_res.clear()
                last_CreateTime = row['CreateTime']
            elif row['is_sender'] == 0:
                # Another question: keep only the newest one.
                temp_res.clear()
                temp_res.append(row['content'])
                last_CreateTime = row['CreateTime']
            else:
                if row['CreateTime'] - last_CreateTime > pd.Timedelta('1h'):
                    # Answer came too late — discard the pair.
                    temp_res.clear()
                    last_CreateTime = row['CreateTime']
                else:
                    temp_res.append(row['content'])
                    # Keep only the longest comma-separated clause as the answer.
                    output = max(temp_res[1].split(','), key=len)
                    # endswith + emptiness guard: the original output[-1]
                    # raised IndexError on an empty answer.
                    if output.endswith('。'):
                        output = output[:-1]
                    if output:
                        csv_res.append({'instruction': temp_res[0], 'output': output})
                    temp_res.clear()
                    last_CreateTime = row['CreateTime']

    csv_res_df = pd.DataFrame(csv_res)
    print(f'数据量:{csv_res_df.shape[0]}')
    csv_res_df.to_json('./data/res_csv/sft/sft-my.json', orient='records', force_ascii=False)


if __name__ == '__main__':
    # make_pt_dataset()
    make_sft_dataset()
1] 18 | # 对每一行的content进行处理 转为dict 再取'msg'字段 19 | chat_df['content'] = chat_df['content'].apply(lambda x: json.loads(x)['msg']) 20 | # 如果content 包含 手机号、身份证号、邮箱、网址则删除这行 21 | chat_df = chat_df[~chat_df['content'].str.contains('1\d{10}')] 22 | chat_df = chat_df[~chat_df['content'].str.contains('\d{18}')] 23 | chat_df = chat_df[~chat_df['content'].str.contains('\w+@\w+')] 24 | chat_df = chat_df[~chat_df['content'].str.contains('http')] 25 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\xa0')] 26 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\u')] 27 | 28 | # 纯content 29 | chat_df = chat_df['content'] 30 | chat_df = chat_df.dropna() 31 | 32 | return chat_df 33 | 34 | 35 | def make_pt_dataset(): 36 | csv_res = [] 37 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件 38 | for chat_obj_folder in os.listdir(csv_folder): 39 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder) 40 | for csvfile in os.listdir(chat_obj_folder_path): 41 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile) 42 | chat_df = handle_pt_csv(csvfile_path) 43 | csv_res.append(chat_df) 44 | 45 | csv_res = pd.concat(csv_res) 46 | csv_res = csv_res.apply(lambda x: {'c': x}) # 设置数据集prompt键为c 47 | 48 | csv_res.to_json('./data/res_csv/pt.json', orient='records', force_ascii=False) 49 | 50 | 51 | def handle_sft_csv(csvfile): 52 | chat_df = pd.read_csv(csvfile) 53 | blocked_words = json.load(open('./make_dataset/blocked_words.json'))['blocked_words'] 54 | # 选择type_name为文本的行、is_sender为1的行 55 | # 需要保留的type_name字段名 56 | type_list = ['文本', '图片', '卡片式链接', '合并转发的聊天记录', '视频', '语言', '未知', '分享的小程序'] 57 | chat_df = chat_df[chat_df['type_name'].isin(values=type_list)] 58 | 59 | # 对每一行的content进行处理 转为dict 再取'msg'字段 60 | chat_df['content'] = chat_df['content'].apply(func=lambda x: json.loads(x)['msg']) 61 | # 如果type_name为文本 并且content 包含 手机号、身份证号、邮箱、网址则删除这行 用循环删除 62 | for i in chat_df.index: 63 | if chat_df.loc[i, 'type_name'] == '文本': 64 | if ('1\d{10}' in chat_df.loc[i, 
'content'] or 65 | '\d{18}' in chat_df.loc[i, 'content'] or 66 | '\w+@\w+' in chat_df.loc[i, 'content'] or 67 | 'http' in chat_df.loc[i, 'content'] or 68 | r'\\xa0' in chat_df.loc[i, 'content'] or 69 | r'\\u' in chat_df.loc[i, 'content']): 70 | chat_df = chat_df.drop(index=i) 71 | continue 72 | for blocked_word in blocked_words: 73 | if blocked_word in chat_df.loc[i, 'content']: 74 | chat_df = chat_df.drop(index=i) 75 | break 76 | else: 77 | # content赋值为空 78 | chat_df.loc[i, 'content'] = '' 79 | 80 | chat_df = chat_df[['is_sender', 'type_name', 'content', 'CreateTime']] 81 | chat_df = chat_df.dropna() 82 | 83 | # 时间格式 2021-07-07 10:27:23 84 | # 遍历行 相同is_sender的行合并content()遇到不同is_sender就重新开始 85 | # CreateTime字段保留最后的CreateTime 86 | chat_df['CreateTime'] = pd.to_datetime(chat_df['CreateTime']) 87 | type_list.remove('文本') 88 | skip_list = type_list 89 | res_df = [] 90 | last_is_sender = chat_df.iloc[0]['is_sender'] 91 | last_content: str = chat_df.iloc[0]['content'] 92 | last_CreateTime = chat_df.iloc[0]['CreateTime'] 93 | # 超时处理 半天没说话就重新开始 94 | # 注意这里只是处理了组装成一个句子 最后封装对话、配对在make_sft_dataset 95 | 96 | # 遇到图片 连接 直接封装成一个句子 97 | for i, row in chat_df.iterrows(): 98 | if '跟最终成绩差1分' in row['content']: 99 | pass 100 | if row['type_name'] in skip_list: 101 | if last_content != '': 102 | if last_content[-1] == ',': 103 | last_content = last_content[:-1] + '。' 104 | elif last_content[-1] not in ['。', '!', '?', '…', '.']: 105 | last_content += '。' 106 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime}) 107 | last_CreateTime = row['CreateTime'] 108 | last_content = '' 109 | # cut表示被skip字段截断 110 | res_df.append({'is_sender': row['is_sender'], 'content': 'cut', 'CreateTime': row['CreateTime']}) 111 | continue 112 | if last_content == '': # 重新开始 113 | last_content = row['content'] 114 | last_is_sender = row['is_sender'] 115 | last_CreateTime = row['CreateTime'] 116 | continue 117 | if row['is_sender'] == last_is_sender: 118 | if 
row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'): 119 | # 如果超时 前面的添加到res_df 并重新开始 120 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime}) 121 | res_df.append({'is_sender': last_is_sender, 'content': 'cut', 'CreateTime': last_CreateTime}) 122 | last_content = row['content'] 123 | last_CreateTime = row['CreateTime'] 124 | continue 125 | # 如果content的结尾没有标点符号则添加逗号,最后结尾是句号 126 | if last_content[-1] not in ['。', '!', '?', '…', ',']: 127 | last_content += ',' 128 | last_content = last_content + row['content'] 129 | last_CreateTime = row['CreateTime'] 130 | else: 131 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime}) 132 | if row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'): 133 | res_df.append({'is_sender': last_is_sender, 'content': 'cut', 'CreateTime': last_CreateTime}) 134 | last_is_sender = row['is_sender'] 135 | last_content = row['content'] 136 | last_CreateTime = row['CreateTime'] 137 | res_df = pd.DataFrame(res_df) 138 | return res_df 139 | 140 | 141 | def make_sft_dataset(): 142 | 143 | # [ 144 | # { 145 | # "instruction": "用户指令(必填)", 146 | # "input": "用户输入(选填)", 147 | # "output": "模型回答(必填)", 148 | # "system": "系统提示词(选填)", 149 | # "history": [ 150 | # ["第一轮指令(选填)", "第一轮回答(选填)"], 151 | # ["第二轮指令(选填)", "第二轮回答(选填)"] 152 | # ] 153 | # } 154 | # ] 155 | 156 | csv_concat = [] 157 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件 158 | for chat_obj_folder in os.listdir(csv_folder): 159 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder) 160 | for csvfile in os.listdir(chat_obj_folder_path): 161 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile) 162 | chat_df = handle_sft_csv(csvfile_path) 163 | csv_concat.append(chat_df) 164 | 165 | csv_concat = pd.concat(csv_concat) 166 | # csv_res里is_sender必须是01 01 01 的顺序 csv_concat里不一定是01 01 167 | # 相差超过1小时的时间戳分为不同的对话 168 | # temp_res为一个长度为2的队列 169 | temp_chat_item = 
deque(maxlen=2) 170 | qa_list = [] 171 | 172 | # 6种情况 173 | # temp_res 为空 遇到 0入队 遇到1不处理 遇到cut不处理 174 | # temp_res 有0 遇到0清空队列再入队 遇到1入队再全部出队 遇到cut清空队列 175 | 176 | # 选最长的做为问题的结果? 177 | 178 | for i, row in csv_concat.iterrows(): 179 | if len(temp_chat_item) == 0: 180 | if row['content'] == 'cut': 181 | qa_list.append('cut') 182 | continue 183 | if row['is_sender'] == 0: 184 | temp_chat_item.append(row['content']) 185 | else: 186 | continue 187 | elif len(temp_chat_item) == 1: 188 | if row['content'] == 'cut': 189 | qa_list.append('cut') 190 | temp_chat_item.clear() 191 | elif row['is_sender'] == 0: 192 | # 遇到0 清空队列再入队 193 | temp_chat_item.clear() 194 | temp_chat_item.append(row['content']) 195 | else: 196 | temp_chat_item.append(row['content']) 197 | temp_output_list = temp_chat_item[1].split(',') 198 | output = max(temp_output_list, key=len) 199 | if output[-1] == '。': 200 | output = output[:-1] 201 | qa_list.append({'instruction': temp_chat_item[0], 'output': output}) 202 | # csv_res.append({'instruction': temp_chat_item[0], 'output': output, 'system': '请你扮演一名人类,不要说自己是人工智能'}) 203 | temp_chat_item.clear() 204 | 205 | csv_res = [] 206 | system_prompt = '请你扮演一名人类,不要说自己是人工智能或者某个模型' 207 | 208 | last_res = {'instruction': '', 'output': '', 209 | 'system': system_prompt, 210 | 'history': []} 211 | for i, qa in enumerate(qa_list): 212 | if qa == 'cut': 213 | if len(last_res['history']) == 0: 214 | continue 215 | else: 216 | if len(last_res['history']) == 1: 217 | last_res = {'instruction': last_res['history'][0][0], 'output': last_res['history'][0][1], 218 | 'system': system_prompt,'history': []} 219 | else: 220 | last_res = {'instruction': last_res['history'][-1][0], 'output': last_res['history'][-1][1], 221 | 'system': system_prompt, 222 | 'history': last_res['history'][:-1]} 223 | csv_res.append(last_res) 224 | last_res = {'instruction': '', 'output': '', 225 | 'system': system_prompt, 226 | 'history': []} 227 | else: 228 | last_res['history'].append([qa['instruction'], 
qa['output']]) 229 | 230 | 231 | csv_res_df = pd.DataFrame(csv_res) 232 | print(f'数据量:{csv_res_df.shape[0]}') 233 | csv_res_df.to_json('./data/res_csv/sft/sft-my.json', orient='records', force_ascii=False) 234 | 235 | 236 | if __name__ == '__main__': 237 | # make_pt_dataset() 238 | make_sft_dataset() 239 | -------------------------------------------------------------------------------- /make_dataset/csv_to_json.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import json 3 | import os 4 | 5 | import pandas as pd 6 | from collections import deque 7 | 8 | csv_folder = './data/csv' 9 | # csv_folder = './data/test' 10 | print(f'当前处理目录{csv_folder}') 11 | 12 | 13 | def handle_pt_csv(csvfile): 14 | chat_df = pd.read_csv(csvfile) 15 | # 选择type_name为文本的行、is_sender为1的行 16 | chat_df = chat_df[chat_df['type_name'] == '文本'] 17 | chat_df = chat_df[chat_df['is_sender'] == 1] 18 | # 对每一行的content进行处理 转为dict 再取'msg'字段 19 | chat_df['content'] = chat_df['content'].apply(lambda x: json.loads(x)['msg']) 20 | # 如果content 包含 手机号、身份证号、邮箱、网址则删除这行 21 | chat_df = chat_df[~chat_df['content'].str.contains('1\d{10}')] 22 | chat_df = chat_df[~chat_df['content'].str.contains('\d{18}')] 23 | chat_df = chat_df[~chat_df['content'].str.contains('\w+@\w+')] 24 | chat_df = chat_df[~chat_df['content'].str.contains('http')] 25 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\xa0')] 26 | chat_df = chat_df[~chat_df['content'].str.contains(r'\\u')] 27 | 28 | # 纯content 29 | chat_df = chat_df['content'] 30 | chat_df = chat_df.dropna() 31 | 32 | return chat_df 33 | 34 | 35 | def make_pt_dataset(): 36 | csv_res = [] 37 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件 38 | for chat_obj_folder in os.listdir(csv_folder): 39 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder) 40 | for csvfile in os.listdir(chat_obj_folder_path): 41 | if not csvfile.endswith('.csv'): 42 | continue 43 | csvfile_path = 
os.path.join(chat_obj_folder_path, csvfile) 44 | chat_df = handle_pt_csv(csvfile_path) 45 | csv_res.append(chat_df) 46 | 47 | csv_res = pd.concat(csv_res) 48 | csv_res = csv_res.apply(lambda x: {'c': x}) # 设置数据集prompt键为c 49 | 50 | csv_res.to_json('./data/res_csv/pt-my.json', orient='records', force_ascii=False) 51 | 52 | 53 | def handle_sft_csv(csvfile): 54 | chat_df = pd.read_csv(csvfile) 55 | blocked_words = json.load(open('./make_dataset/blocked_words.json', encoding='utf-8'))['blocked_words'] 56 | # 选择type_name为文本的行、is_sender为1的行 57 | # 需要保留的type_name字段名 58 | type_list = ['文本', '图片', '卡片式链接', '合并转发的聊天记录', '视频', '语言', '未知', '分享的小程序'] 59 | chat_df = chat_df[chat_df['type_name'].isin(values=type_list)] 60 | 61 | # chat_df['content'] = chat_df['content'].apply(func=lambda x: json.loads(x)['msg']) 62 | chat_df['content'] = chat_df['msg'] 63 | 64 | # 如果type_name为文本 并且content 包含 手机号、身份证号、邮箱、网址则删除这行 65 | for i in chat_df.index: 66 | if chat_df.loc[i, 'type_name'] == '文本': 67 | if ('1\d{10}' in chat_df.loc[i, 'content'] or 68 | '\d{18}' in chat_df.loc[i, 'content'] or 69 | '\w+@\w+' in chat_df.loc[i, 'content'] or 70 | 'http' in chat_df.loc[i, 'content'] or 71 | r'\\xa0' in chat_df.loc[i, 'content'] or 72 | r'\\u' in chat_df.loc[i, 'content']): 73 | chat_df = chat_df.drop(index=i) 74 | continue 75 | for blocked_word in blocked_words: 76 | if blocked_word in chat_df.loc[i, 'content']: 77 | chat_df = chat_df.drop(index=i) 78 | break 79 | else: 80 | chat_df.loc[i, 'content'] = '' 81 | 82 | chat_df = chat_df[['is_sender', 'type_name', 'content', 'CreateTime']] 83 | chat_df = chat_df.dropna() 84 | 85 | # 时间格式 2021-07-07 10:27:23 86 | # 遍历行 相同is_sender的行合并content()遇到不同is_sender就重新开始 87 | # CreateTime字段保留最后的CreateTime 88 | chat_df['CreateTime'] = pd.to_datetime(chat_df['CreateTime']) 89 | type_list.remove('文本') 90 | skip_list = type_list 91 | res_df = [] 92 | last_is_sender = chat_df.iloc[0]['is_sender'] 93 | last_content: str = chat_df.iloc[0]['content'] 94 | last_CreateTime 
= chat_df.iloc[0]['CreateTime'] 95 | # 超时处理 半天没说话就重新开始 96 | # 注意这里只是处理了组装成一个句子 最后封装对话、配对在make_sft_dataset 97 | # 遇到图片 连接 直接封装成一个句子 98 | for i, row in chat_df.iterrows(): 99 | if row['type_name'] in skip_list: 100 | if last_content != '': 101 | if last_content[-1] == ',': 102 | last_content = last_content[:-1] + '。' 103 | elif last_content[-1] not in ['。', '!', '?', '…', '.']: 104 | last_content += '。' 105 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime}) 106 | last_CreateTime = row['CreateTime'] 107 | last_content = '' 108 | # cut表示被skip字段截断 109 | res_df.append({'is_sender': row['is_sender'], 'content': 'cut', 'CreateTime': row['CreateTime']}) 110 | continue 111 | if last_content == '': # 重新开始 112 | last_content = row['content'] 113 | last_is_sender = row['is_sender'] 114 | last_CreateTime = row['CreateTime'] 115 | continue 116 | if row['is_sender'] == last_is_sender: 117 | if row['CreateTime'] - last_CreateTime > pd.Timedelta(value='1h'): 118 | # 如果超时 前面的添加到res_df 并重新开始 119 | if last_content[-1] == ',': 120 | last_content = last_content[:-1] + '。' 121 | elif last_content[-1] not in ['。', '!', '?', '…', '.']: 122 | last_content += '。' 123 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime}) 124 | last_content = row['content'] 125 | last_CreateTime = row['CreateTime'] 126 | continue 127 | # 如果content的结尾没有标点符号则添加逗号,最后结尾是句号 128 | if last_content[-1] not in ['。', '!', '?', '…', ',']: 129 | last_content += ',' 130 | last_content = last_content + row['content'] 131 | last_CreateTime = row['CreateTime'] 132 | else: 133 | if last_content[-1] == ',': 134 | last_content = last_content[:-1] + '。' 135 | elif last_content[-1] not in ['。', '!', '?', '…', '.']: 136 | last_content += '。' 137 | res_df.append({'is_sender': last_is_sender, 'content': last_content, 'CreateTime': last_CreateTime}) 138 | last_is_sender = row['is_sender'] 139 | last_content = row['content'] 140 | 
last_CreateTime = row['CreateTime'] 141 | res_df = pd.DataFrame(res_df) 142 | return res_df 143 | 144 | 145 | def make_sft_dataset(): 146 | 147 | # [ 148 | # { 149 | # "instruction": "用户指令(必填)", 150 | # "input": "用户输入(选填)", 151 | # "output": "模型回答(必填)", 152 | # "system": "系统提示词(选填)", 153 | # "history": [ 154 | # ["第一轮指令(选填)", "第一轮回答(选填)"], 155 | # ["第二轮指令(选填)", "第二轮回答(选填)"] 156 | # ] 157 | # } 158 | # ] 159 | 160 | csv_concat = [] 161 | csv_res = [] 162 | # csv文件夹里全是不同聊天对象文件夹 每个文件夹里是csv文件 先遍历不同聊天对象文件夹 再遍历聊天对象的csv文件 163 | for chat_obj_folder in os.listdir(csv_folder): 164 | chat_obj_folder_path = os.path.join(csv_folder, chat_obj_folder) 165 | for csvfile in os.listdir(chat_obj_folder_path): 166 | if not csvfile.endswith('.csv'): 167 | continue 168 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile) 169 | chat_df = handle_sft_csv(csvfile_path) 170 | csv_concat.append(chat_df) 171 | 172 | csv_concat = pd.concat(csv_concat) 173 | # csv_res里is_sender必须是01 01 01 的顺序 csv_concat里不一定是01 01 174 | # 相差超过1小时的时间戳分为不同的对话 175 | # temp_res为一个长度为2的队列 176 | temp_res = deque(maxlen=2) 177 | # 6种情况 178 | # temp_res 为空 遇到 0入队 遇到1不处理 遇到cut不处理 179 | # temp_res 有0 遇到0清空队列再入队 遇到1相差超过1小时清空队列 没有相差一小时入队再全部出队 遇到cut清空队列 180 | 181 | for i, row in csv_concat.iterrows(): 182 | if len(temp_res) == 0: 183 | if row['content'] == 'cut': 184 | continue 185 | if row['is_sender'] == 0: 186 | temp_res.append(row['content']) 187 | last_CreateTime = row['CreateTime'] 188 | else: 189 | continue 190 | elif len(temp_res) == 1: 191 | if row['content'] == 'cut': 192 | temp_res.clear() 193 | last_CreateTime = row['CreateTime'] 194 | elif row['is_sender'] == 0: 195 | # 遇到0 清空队列再入队 196 | temp_res.clear() 197 | temp_res.append(row['content']) 198 | last_CreateTime = row['CreateTime'] 199 | else: 200 | if row['CreateTime'] - last_CreateTime > pd.Timedelta('1h'): 201 | # 相差超过1小时清空队列 202 | temp_res.clear() 203 | last_CreateTime = row['CreateTime'] 204 | else: 205 | # 没有相差一小时入队再全部出队 206 | 
temp_res.append(row['content']) 207 | csv_res.append({'instruction': temp_res[0], 'output': temp_res[1]}) 208 | temp_res.clear() 209 | last_CreateTime = row['CreateTime'] 210 | 211 | 212 | csv_res_df = pd.DataFrame(csv_res) 213 | print(f'处理后数据量:{csv_res_df.shape[0]}') 214 | csv_res_df.to_json('./data/res_csv/sft/sft-my.json', orient='records', force_ascii=False) 215 | 216 | 217 | if __name__ == '__main__': 218 | # make_pt_dataset() 219 | make_sft_dataset() 220 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "WeClone" 3 | version = "0.1.2" 4 | description = "" 5 | authors = [{ name = "xming521" }] 6 | readme = "README.md" 7 | requires-python = ">=3.9,<3.10" 8 | dependencies = ["pandas", "pydantic==2.10.6"] 9 | 10 | [dependency-groups] 11 | xcodec = ["xcodec2==0.1.3"] 12 | wx = ["pywxdump"] 13 | sparktts = [ 14 | "einops>=0.8.1", 15 | "einx>=0.3.0", 16 | "numpy==1.26.4", 17 | "omegaconf>=2.3.0", 18 | "packaging>=24.2", 19 | "safetensors>=0.5.2", 20 | "soundfile>=0.12.1", 21 | "soxr>=0.5.0.post1", 22 | "torch>=2.5.1", 23 | "torchaudio>=2.5.1", 24 | "tqdm>=4.66.5", 25 | "transformers==4.45.2", 26 | ] 27 | main = ["transformers==4.45.2", "llamafactory>=0.9.2", "openai==0.28.0"] 28 | 29 | [tool.uv] 30 | conflicts = [ 31 | [ 32 | { group = "wx" }, 33 | { group = "sparktts" }, 34 | ], 35 | [ 36 | { group = "wx" }, 37 | { group = "main" }, 38 | ], 39 | [ 40 | { group = "wx" }, 41 | { group = "xcodec" }, 42 | ], 43 | ] 44 | 45 | [[tool.uv.index]] 46 | url = "https://pypi.tuna.tsinghua.edu.cn/simple/" 47 | default = true 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # LLaMA-Factory 2 | llmtuner 3 | # wechat 4 | itchat-uos==1.5.0.dev0 5 | # others 6 | pandas 7 | # chromadb 8 | # 
langchain 9 | openai==0.28 10 | 11 | 12 | -------------------------------------------------------------------------------- /settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_pt_args": { 3 | "stage": "pt", 4 | "dataset": "wechat-pt", 5 | "dataset_dir": "./data/res_csv/pt", 6 | "lora_target": "query_key_value", 7 | "lora_rank": 2, 8 | "lora_dropout": 0.1, 9 | "output_dir": "model_output", 10 | "overwrite_cache": true, 11 | "per_device_train_batch_size": 1, 12 | "gradient_accumulation_steps": 1, 13 | "lr_scheduler_type": "cosine", 14 | "logging_steps": 10, 15 | "save_steps": 1000, 16 | "learning_rate": 0.001, 17 | "num_train_epochs": 30, 18 | "plot_loss": true, 19 | "fp16": true 20 | }, 21 | "train_sft_args": { 22 | "stage": "sft", 23 | "dataset": "wechat-sft", 24 | "dataset_dir": "./data/res_csv/sft", 25 | "lora_target": "query_key_value", 26 | "lora_rank": 4, 27 | "lora_dropout": 0.5, 28 | "overwrite_cache": true, 29 | "per_device_train_batch_size": 4, 30 | "gradient_accumulation_steps": 8, 31 | "lr_scheduler_type": "cosine", 32 | "logging_steps": 10, 33 | "save_steps": 150, 34 | "learning_rate": 0.0001, 35 | "num_train_epochs": 3, 36 | "plot_loss": true, 37 | "fp16": true 38 | }, 39 | "infer_args": { 40 | "repetition_penalty": 1.2, 41 | "temperature": 0.5, 42 | "max_length": 50, 43 | "top_p": 0.65 44 | }, 45 | "common_args": { 46 | "model_name_or_path": "./chatglm3-6b", 47 | "adapter_name_or_path": "./model_output", 48 | "template": "chatglm3-weclone", 49 | "finetuning_type": "lora", 50 | "trust_remote_code": true 51 | }, 52 | "_comment": "adapter_name_or_path同时做为train_sft_args的output_dir " 53 | } -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/src/__init__.py 
-------------------------------------------------------------------------------- /src/api_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uvicorn 3 | from llamafactory.chat import ChatModel 4 | from llamafactory.api.app import create_app 5 | from template import template_register 6 | from utils.config import load_config 7 | 8 | config = load_config('api_service') 9 | 10 | 11 | template_register() 12 | 13 | 14 | def main(): 15 | chat_model = ChatModel(config) 16 | app = create_app(chat_model) 17 | print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8005))) 18 | uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8005)), workers=1) 19 | 20 | 21 | if __name__ == "__main__": 22 | main() 23 | -------------------------------------------------------------------------------- /src/cli_demo.py: -------------------------------------------------------------------------------- 1 | from llamafactory.chat import ChatModel 2 | from llamafactory.extras.misc import torch_gc 3 | 4 | 5 | try: 6 | import platform 7 | 8 | if platform.system() != "Windows": 9 | import readline # noqa: F401 10 | except ImportError: 11 | print("Install `readline` for a better experience.") 12 | 13 | 14 | def main(): 15 | chat_model = ChatModel() 16 | messages = [] 17 | print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") 18 | 19 | while True: 20 | try: 21 | query = input("\nUser: ") 22 | except UnicodeDecodeError: 23 | print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") 24 | continue 25 | except Exception: 26 | raise 27 | 28 | if query.strip() == "exit": 29 | break 30 | 31 | if query.strip() == "clear": 32 | messages = [] 33 | torch_gc() 34 | print("History has been removed.") 35 | continue 36 | 37 | messages.append({"role": "user", "content": query}) 38 | print("Assistant: ", end="", 
flush=True) 39 | 40 | response = "" 41 | for new_text in chat_model.stream_chat(messages): 42 | print(new_text, end="", flush=True) 43 | response += new_text 44 | print() 45 | messages.append({"role": "assistant", "content": response}) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | from llamafactory.eval.evaluator import Evaluator 2 | 3 | 4 | def main(): 5 | evaluator = Evaluator() 6 | evaluator.eval() 7 | 8 | 9 | if __name__ == "__main__": 10 | main() 11 | -------------------------------------------------------------------------------- /src/export_model.py: -------------------------------------------------------------------------------- 1 | from llamafactory.train.tuner import export_model 2 | 3 | 4 | def main(): 5 | export_model() 6 | 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /src/template.py: -------------------------------------------------------------------------------- 1 | from llamafactory.data.formatter import FunctionFormatter, StringFormatter, ToolFormatter, EmptyFormatter 2 | from llamafactory.data.template import register_template 3 | 4 | default_prompt = "请你扮演一名人类,不要说自己是人工智能" 5 | 6 | 7 | def template_register(): 8 | register_template( 9 | name="chatglm3-weclone", 10 | default_system=( 11 | default_prompt 12 | ), 13 | format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), 14 | format_assistant=StringFormatter(slots=["\n", "{{content}}"]), 15 | format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), 16 | format_function=FunctionFormatter(slots=["{{content}}"], tool_format="glm4"), 17 | format_observation=StringFormatter( 18 | slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": 
"<|assistant|>"}] 19 | ), 20 | format_tools=ToolFormatter(tool_format="glm4"), 21 | format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), 22 | stop_words=["<|user|>", "<|observation|>"], 23 | efficient_eos=True, 24 | ) 25 | -------------------------------------------------------------------------------- /src/test_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | import sys 5 | import openai 6 | sys.path.append(os.getcwd()) 7 | import tqdm 8 | from src.template import default_prompt 9 | from tqdm import tqdm 10 | 11 | 12 | config = { 13 | 'default_prompt': default_prompt, 14 | 'model': 'gpt-3.5-turbo', 15 | 'history_len': 15, 16 | } 17 | 18 | config = type('Config', (object,), config)() 19 | 20 | openai.api_key = '''sk-test''' 21 | openai.api_base = "http://127.0.0.1:8000/v1" 22 | 23 | 24 | def handler_text(content: str, history: [], config): 25 | 26 | messages = [{"role": "system", "content": f'{config.default_prompt}'}] 27 | for item in history: 28 | messages.append(item) 29 | messages.append({"role": "user", "content": content}) 30 | history.append({"role": "user", "content": content}) 31 | try: 32 | response = openai.ChatCompletion.create(model=config.model, 33 | messages=messages, 34 | max_tokens=50) 35 | except openai.APIError as e: 36 | history.pop() 37 | return 'AI接口出错,请重试\n' + str(e) 38 | 39 | resp = str(response.choices[0].message.content) 40 | resp = resp.replace('\n ', '') 41 | history.append({"role": "assistant", "content": resp}) 42 | return resp 43 | 44 | 45 | def main(): 46 | test_list = json.loads(open('data/test_data.json').read())['questions'] 47 | res = [] 48 | for questions in tqdm(test_list, desc=' Testing...'): 49 | history = [] 50 | for q in questions: 51 | answer = handler_text(q, history=history, config=config) 52 | res.append(history) 53 | 54 | res_file = open('test_result-my.txt', 'w') 55 | for r in res: 56 | for i in r: 
57 | res_file.write(i['content'] + '\n') 58 | res_file.write('\n') 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /src/train_pt.py: -------------------------------------------------------------------------------- 1 | from llamafactory.train.tuner import run_exp 2 | from utils.config import load_config 3 | 4 | config = load_config('train_pt') 5 | run_exp(config) 6 | -------------------------------------------------------------------------------- /src/train_sft.py: -------------------------------------------------------------------------------- 1 | from llamafactory.train.tuner import run_exp 2 | from template import template_register 3 | from utils.config import load_config 4 | 5 | config = load_config(arg_type='train_sft') 6 | 7 | template_register() 8 | 9 | run_exp(config) 10 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mtcto/weclone/4ec5cf8d2270738101394f993409e0f14c8ea8a3/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from .utils import dict_to_argv 3 | import sys 4 | 5 | 6 | def load_config(arg_type: str): 7 | with open('./settings.json', 'r') as f: 8 | config: dict = json.load(f) 9 | if arg_type == 'web_demo' or arg_type == 'api_service': 10 | # infer_args和common_args求并集 11 | config = {**config['infer_args'], **config['common_args']} 12 | elif arg_type == 'train_pt': 13 | config = {**config['train_pt_args'], **config['common_args']} 14 | elif arg_type == 'train_sft': 15 | config = {**config['train_sft_args'], **config['common_args']} 16 | else: 17 | raise ValueError('暂不支持的类型') 18 | 19 | if 'train' in arg_type: 20 | 
config['output_dir'] = config['adapter_name_or_path'] 21 | config.pop('adapter_name_or_path') 22 | config['do_train'] = True 23 | 24 | sys.argv += dict_to_argv(config) 25 | 26 | return config 27 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | def dict_to_argv(d): 2 | argv = [] 3 | for k, v in d.items(): 4 | argv.append('--' + k) 5 | if v is not None: 6 | argv.append(str(v)) 7 | return argv 8 | 9 | 10 | -------------------------------------------------------------------------------- /src/web_demo.py: -------------------------------------------------------------------------------- 1 | from llamafactory.webui.interface import create_web_demo 2 | from template import template_register 3 | from utils.config import load_config 4 | 5 | config = load_config('web_demo') 6 | 7 | template_register() 8 | 9 | def main(): 10 | demo = create_web_demo() 11 | demo.queue() 12 | demo.launch(server_name="0.0.0.0", share=True, inbrowser=True) 13 | 14 | 15 | if __name__ == "__main__": 16 | main() 17 | -------------------------------------------------------------------------------- /src/wechat_bot/handler/text.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import openai 4 | 5 | 6 | log = logging.getLogger('text') 7 | 8 | 9 | def handler_text(content: str, history: [], config): 10 | # todo 收到/clear清理历史记录 11 | # try: 12 | # history.clear() 13 | # return '清理完毕!' 
14 | # except KeyError: 15 | # return '不存在消息记录,无需清理' 16 | 17 | messages = [{"role": "system", "content": f'{config.default_prompt}'}] 18 | for item in history: 19 | messages.append(item) 20 | messages.append({"role": "user", "content": content}) 21 | history.append({"role": "user", "content": content}) 22 | try: 23 | response = openai.ChatCompletion.create(model=config.model, 24 | messages=messages, 25 | max_tokens=50) 26 | except openai.APIError as e: 27 | log.error(e) 28 | history.pop() 29 | return 'AI接口出错,请重试\n' + str(e) 30 | 31 | resp = str(response.choices[0].message.content) 32 | resp = resp.replace('\n ', '') 33 | history.append({"role": "assistant", "content": resp}) 34 | return resp 35 | -------------------------------------------------------------------------------- /src/wechat_bot/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import signal 4 | import sys 5 | import time 6 | import xml.etree.ElementTree as ET 7 | sys.path.append(os.getcwd()) 8 | import openai 9 | import requests 10 | 11 | import itchat 12 | from handler.text import handler_text 13 | from itchat import utils 14 | from itchat.content import * 15 | from src.template import default_prompt 16 | import logging 17 | 18 | # logging.basicConfig(level=logging.INFO) 19 | log = logging.getLogger('main') 20 | 21 | config = { 22 | 'default_prompt': default_prompt, 23 | 'model': 'gpt-3.5-turbo', 24 | 'history_len': 15, 25 | } 26 | 27 | 28 | config = type('Config', (object,), config)() 29 | 30 | 31 | def stop_program(signal, frame): 32 | log.info('WeChatbot Closing Save some data') 33 | itchat.dump_login_status() 34 | sys.exit(0) 35 | 36 | 37 | signal.signal(signal.SIGTERM, stop_program) 38 | 39 | 40 | class WeChatGPT: 41 | 42 | def __init__(self): 43 | itchat.auto_login(enableCmdQR=2, hotReload=True, statusStorageDir='./cookie.bin') 44 | 45 | self.history = {} 46 | self.prompts = {} 47 | openai.api_key = '''sk-test''' 48 | 
openai.api_base = "http://127.0.0.1:8000/v1" 49 | 50 | log.info("init successful!") 51 | 52 | def handler_history(self, msg): 53 | self.history.setdefault(msg.user.userName, []) 54 | history = self.history[msg.user.userName] 55 | need_remove_len = len(history) - config.history_len 56 | if need_remove_len > 0: 57 | for i in range(need_remove_len): 58 | # 必须出一对 59 | history.pop(0) 60 | history.pop(0) 61 | return history 62 | 63 | def reply(self, msg): 64 | if time.time() - msg.CreateTime > 5: 65 | return None 66 | res = handler_text(content=msg.text, history=self.handler_history(msg), config=config) 67 | res = res.split(',') 68 | res[-1] = res[-1].replace('。', '') 69 | if res[0] == '': 70 | res[0] = '机器人他无语了' 71 | for r in res: 72 | msg.user.send(r) 73 | time.sleep(2.2) 74 | 75 | def run(self): 76 | @itchat.msg_register(FRIENDS) 77 | def add_friend(msg): 78 | """自动同意好友""" 79 | root = ET.fromstring(msg.content) 80 | ticket = root.get('ticket') 81 | # itchat.accept_friend(msg.user.userName, ticket) 82 | 83 | @itchat.msg_register(TEXT) 84 | def friend(msg): 85 | """处理私聊消息""" 86 | log.info(f"{msg.user.NickName}: {msg.text}") 87 | self.reply(msg) 88 | 89 | @itchat.msg_register(TEXT, isGroupChat=True) 90 | def groups(msg): 91 | """处理群聊消息""" 92 | if msg.isAt: 93 | self.reply(msg) 94 | 95 | itchat.run(debug=True) 96 | 97 | 98 | if __name__ == "__main__": 99 | 100 | try: 101 | weChatGPT = WeChatGPT() 102 | weChatGPT.run() 103 | except KeyboardInterrupt: 104 | log.info("bye!") 105 | --------------------------------------------------------------------------------