├── .cursor └── rules │ └── weclone-rules.mdc ├── .github ├── issue-labeler.yml └── workflows │ └── issue-labeler.yml ├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── res_csv │ ├── pt │ │ └── dataset_info.json │ └── sft │ │ └── dataset_info.json ├── test_data-privacy.json └── test_data.json ├── ds_config.json ├── pyproject.toml ├── settings.template.jsonc ├── tests ├── __init__.py ├── full_pipe.jsonc └── test_full_pipe.py ├── weclone-audio ├── README.md └── src │ ├── Llasa │ ├── infer.py │ └── text_to_speech.py │ ├── SparkTTS.py │ ├── __init__.py │ ├── get_sample_audio.py │ ├── infer.py │ ├── sample.wav │ └── server未完工 │ ├── .env.example │ ├── handle_text.py │ ├── requirements.txt │ ├── server.py │ ├── tts_handler.py │ └── utils.py └── weclone ├── __init__.py ├── cli.py ├── core └── inference │ ├── offline_infer.py │ └── online_infer.py ├── data ├── __init__.py ├── chat_parsers │ └── wechat_parser.py ├── clean │ ├── __init__.py │ ├── get_score.py │ ├── strategies.py │ └── strategies_online.py ├── models.py ├── qa_generator.py └── strategies.py ├── eval ├── __init__.py ├── cli_demo.py ├── eval_model.py ├── test_model.py └── web_demo.py ├── prompts ├── __init__.py └── clean_data.py ├── server ├── __init__.py └── api_service.py ├── train ├── __init__.py ├── export_model.py ├── train_pt.py └── train_sft.py └── utils ├── __init__.py ├── config.py ├── length_cdf.py ├── log.py └── tools.py /.cursor/rules/weclone-rules.mdc: -------------------------------------------------------------------------------- 1 | --- 2 | description: 3 | globs: 4 | alwaysApply: true 5 | --- 6 | --- 7 | description: 8 | globs: 9 | alwaysApply: true 10 | --- 11 | 12 | # Your rule content 13 | - You can @ files here 14 | - The project uses uv as the package manager and pyproject.toml as the project configuration file. 15 | - Unless I ask you to, code comments don't need to be excessive. 16 | - Prefer using the encapsulated logger `from weclone.utils.log import logger` for printing. 17 | - When retrieving values from a parameter dictionary read from a configuration file, the `get` method should be preferred whenever possible. 
18 | 19 | 20 | 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /.github/issue-labeler.yml: -------------------------------------------------------------------------------- 1 | # 添加 Discussion 标签 2 | Discussion: 3 | - '(讨论|交流|分享|意见|建议|思考|探讨|交换意见|brainstorm|discussion)' 4 | 5 | # 添加 bug 标签 6 | bug: 7 | - '(bug|错误|问题|失败|崩溃|异常|报错|不工作|无法运行|broken|crash|error|exception|fails)' 8 | 9 | # 添加 chatbot 标签 10 | chatbot: 11 | - '(聊天机器人|chatbot|chat bot|对话机器人|聊天助手|AI助手|机器人对话|bot|assistant)' 12 | 13 | # 添加 documentation 标签 14 | documentation: 15 | - '(文档|说明|使用指南|指导|手册|教程|文档更新|documentation|docs|guide|tutorial|readme)' 16 | 17 | # 添加 duplicate 标签 18 | duplicate: 19 | - '(重复|已有|duplicate|已经存在|已提交过|重复问题|重复报告|dup)' 20 | 21 | # 添加 feature 标签 22 | feature: 23 | - '(功能|特性|新增|增加|添加|实现|feature|enhancement|新功能|功能请求|feature request)' 24 | 25 | # 添加 good first issue 标签 26 | good first issue: 27 | - '(入门|简单|容易|新手|初学者|开始|first|beginner|starter|easy|简单任务|good first issue)' 28 | 29 | # 添加 help wanted 标签 30 | help wanted: 31 | - '(需要帮助|寻求帮助|请求协助|help|求助|协助|帮忙|help wanted|need help|assistance)' 32 | 33 | # 添加 invalid 标签 34 | invalid: 35 | - '(无效|不适用|不相关|无关|错误提交|invalid|not relevant|irrelevant|not applicable)' 36 | 37 | # 添加 Mac 标签 38 | Mac: 39 | - '(Mac|MacOS|macOS|OSX|Mac系统|苹果系统|苹果电脑|MacBook)' 40 | 41 | # 添加 question 标签 42 | question: 43 | - '(问题|疑问|如何|怎么|请问|是否|能否|可以吗|question|how to|what is|why)' 44 | 45 | # 添加 Windows 标签 46 | Windows: 47 | - '(Windows|微软|Win10|Win11|Windows系统|微软系统|win)' 48 | -------------------------------------------------------------------------------- /.github/workflows/issue-labeler.yml: -------------------------------------------------------------------------------- 1 | name: add labels to Issues 2 | 3 | on: 4 | issues: 5 | types: [opened, edited] 6 | 7 | 8 | jobs: 9 | label_issues: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | issues: write 13 | contents: read 14 | steps: 15 | - name: get_last_run_time 16 | id: last_run 17 | run: | 18 | # 获取当前日期减去 1 天作为默认值(处理最近一天的 issues) 19 | echo "date=$(date -d '1 day ago' -u +"%Y-%m-%dT%H:%M:%SZ")" >> $GITHUB_OUTPUT 20 | 21 | - name: RegEx Issue Labeler 22 | uses: github/issue-labeler@v3.4 23 | with: 24 | include-title: 1 25 | repo-token: "${{ secrets.GITHUB_TOKEN }}" 26 | configuration-path: .github/issue-labeler.yml 27 | enable-versioned-regex: 0 28 | not-before: ${{ steps.last_run.outputs.date }} 29 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | weclone_archive-my/ 3 | **/pycache/ 4 | events.out.tfevents.* 5 | 归档/ 6 | *.pt 7 | *.npz 8 | *nohup.out 9 | *log.txt 10 | *cookie.bin 11 | *.gradio/ 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | 144 | *.zip 145 | LLaMA-Factory 146 | chatglm3-6b 147 | cache 148 | archive 149 | model_output* 150 | data/test 151 | .vscode 152 | *-my*.* 153 | *.csv 154 | *test.* 155 | *users.json 156 | Spark-TTS-0.5B/ 157 | uv.lock 158 | output* 159 | *.out 160 | 161 | Qwen*/ 162 | settings.jsonc 163 | settings.json 164 | dataset/blocked_words.json 165 | dataset/wechat/* 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU AFFERO GENERAL PUBLIC LICENSE 2 | Version 3, 19 November 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU Affero General Public License is a free, copyleft license for 11 | software and other kinds of works, specifically designed to ensure 12 | cooperation with the community in the case of network server software. 13 | 14 | The licenses for most software and other practical works are designed 15 | to take away your freedom to share and change the works. By contrast, 16 | our General Public Licenses are intended to guarantee your freedom to 17 | share and change all versions of a program--to make sure it remains free 18 | software for all its users. 19 | 20 | When we speak of free software, we are referring to freedom, not 21 | price. 
Our General Public Licenses are designed to make sure that you 22 | have the freedom to distribute copies of free software (and charge for 23 | them if you wish), that you receive source code or can get it if you 24 | want it, that you can change the software or use pieces of it in new 25 | free programs, and that you know you can do these things. 26 | 27 | Developers that use our General Public Licenses protect your rights 28 | with two steps: (1) assert copyright on the software, and (2) offer 29 | you this License which gives you legal permission to copy, distribute 30 | and/or modify the software. 31 | 32 | A secondary benefit of defending all users' freedom is that 33 | improvements made in alternate versions of the program, if they 34 | receive widespread use, become available for other developers to 35 | incorporate. Many developers of free software are heartened and 36 | encouraged by the resulting cooperation. However, in the case of 37 | software used on network servers, this result may fail to come about. 38 | The GNU General Public License permits making a modified version and 39 | letting the public access it on a server without ever releasing its 40 | source code to the public. 41 | 42 | The GNU Affero General Public License is designed specifically to 43 | ensure that, in such cases, the modified source code becomes available 44 | to the community. It requires the operator of a network server to 45 | provide the source code of the modified version running there to the 46 | users of that server. Therefore, public use of a modified version, on 47 | a publicly accessible server, gives the public access to the source 48 | code of the modified version. 49 | 50 | An older license, called the Affero General Public License and 51 | published by Affero, was designed to accomplish similar goals. This is 52 | a different license, not a version of the Affero GPL, but Affero has 53 | released a new version of the Affero GPL which permits relicensing under 54 | this license. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | TERMS AND CONDITIONS 60 | 61 | 0. Definitions. 62 | 63 | "This License" refers to version 3 of the GNU Affero General Public License. 64 | 65 | "Copyright" also means copyright-like laws that apply to other kinds of 66 | works, such as semiconductor masks. 67 | 68 | "The Program" refers to any copyrightable work licensed under this 69 | License. Each licensee is addressed as "you". "Licensees" and 70 | "recipients" may be individuals or organizations. 71 | 72 | To "modify" a work means to copy from or adapt all or part of the work 73 | in a fashion requiring copyright permission, other than the making of an 74 | exact copy. The resulting work is called a "modified version" of the 75 | earlier work or a work "based on" the earlier work. 76 | 77 | A "covered work" means either the unmodified Program or a work based 78 | on the Program. 79 | 80 | To "propagate" a work means to do anything with it that, without 81 | permission, would make you directly or secondarily liable for 82 | infringement under applicable copyright law, except executing it on a 83 | computer or modifying a private copy. Propagation includes copying, 84 | distribution (with or without modification), making available to the 85 | public, and in some countries other activities as well. 86 | 87 | To "convey" a work means any kind of propagation that enables other 88 | parties to make or receive copies. 
Mere interaction with a user through 89 | a computer network, with no transfer of a copy, is not conveying. 90 | 91 | An interactive user interface displays "Appropriate Legal Notices" 92 | to the extent that it includes a convenient and prominently visible 93 | feature that (1) displays an appropriate copyright notice, and (2) 94 | tells the user that there is no warranty for the work (except to the 95 | extent that warranties are provided), that licensees may convey the 96 | work under this License, and how to view a copy of this License. If 97 | the interface presents a list of user commands or options, such as a 98 | menu, a prominent item in the list meets this criterion. 99 | 100 | 1. Source Code. 101 | 102 | The "source code" for a work means the preferred form of the work 103 | for making modifications to it. "Object code" means any non-source 104 | form of a work. 105 | 106 | A "Standard Interface" means an interface that either is an official 107 | standard defined by a recognized standards body, or, in the case of 108 | interfaces specified for a particular programming language, one that 109 | is widely used among developers working in that language. 110 | 111 | The "System Libraries" of an executable work include anything, other 112 | than the work as a whole, that (a) is included in the normal form of 113 | packaging a Major Component, but which is not part of that Major 114 | Component, and (b) serves only to enable use of the work with that 115 | Major Component, or to implement a Standard Interface for which an 116 | implementation is available to the public in source code form. A 117 | "Major Component", in this context, means a major essential component 118 | (kernel, window system, and so on) of the specific operating system 119 | (if any) on which the executable work runs, or a compiler used to 120 | produce the work, or an object code interpreter used to run it. 121 | 122 | The "Corresponding Source" for a work in object code form means all 123 | the source code needed to generate, install, and (for an executable 124 | work) run the object code and to modify the work, including scripts to 125 | control those activities. However, it does not include the work's 126 | System Libraries, or general-purpose tools or generally available free 127 | programs which are used unmodified in performing those activities but 128 | which are not part of the work. For example, Corresponding Source 129 | includes interface definition files associated with source files for 130 | the work, and the source code for shared libraries and dynamically 131 | linked subprograms that the work is specifically designed to require, 132 | such as by intimate data communication or control flow between those 133 | subprograms and other parts of the work. 134 | 135 | The Corresponding Source need not include anything that users 136 | can regenerate automatically from other parts of the Corresponding 137 | Source. 138 | 139 | The Corresponding Source for a work in source code form is that 140 | same work. 141 | 142 | 2. Basic Permissions. 143 | 144 | All rights granted under this License are granted for the term of 145 | copyright on the Program, and are irrevocable provided the stated 146 | conditions are met. This License explicitly affirms your unlimited 147 | permission to run the unmodified Program. The output from running a 148 | covered work is covered by this License only if the output, given its 149 | content, constitutes a covered work. 
This License acknowledges your 150 | rights of fair use or other equivalent, as provided by copyright law. 151 | 152 | You may make, run and propagate covered works that you do not 153 | convey, without conditions so long as your license otherwise remains 154 | in force. You may convey covered works to others for the sole purpose 155 | of having them make modifications exclusively for you, or provide you 156 | with facilities for running those works, provided that you comply with 157 | the terms of this License in conveying all material for which you do 158 | not control copyright. Those thus making or running the covered works 159 | for you must do so exclusively on your behalf, under your direction 160 | and control, on terms that prohibit them from making any copies of 161 | your copyrighted material outside their relationship with you. 162 | 163 | Conveying under any other circumstances is permitted solely under 164 | the conditions stated below. Sublicensing is not allowed; section 10 165 | makes it unnecessary. 166 | 167 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 168 | 169 | No covered work shall be deemed part of an effective technological 170 | measure under any applicable law fulfilling obligations under article 171 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 172 | similar laws prohibiting or restricting circumvention of such 173 | measures. 174 | 175 | When you convey a covered work, you waive any legal power to forbid 176 | circumvention of technological measures to the extent such circumvention 177 | is effected by exercising rights under this License with respect to 178 | the covered work, and you disclaim any intention to limit operation or 179 | modification of the work as a means of enforcing, against the work's 180 | users, your or third parties' legal rights to forbid circumvention of 181 | technological measures. 182 | 183 | 4. Conveying Verbatim Copies. 184 | 185 | You may convey verbatim copies of the Program's source code as you 186 | receive it, in any medium, provided that you conspicuously and 187 | appropriately publish on each copy an appropriate copyright notice; 188 | keep intact all notices stating that this License and any 189 | non-permissive terms added in accord with section 7 apply to the code; 190 | keep intact all notices of the absence of any warranty; and give all 191 | recipients a copy of this License along with the Program. 192 | 193 | You may charge any price or no price for each copy that you convey, 194 | and you may offer support or warranty protection for a fee. 195 | 196 | 5. Conveying Modified Source Versions. 197 | 198 | You may convey a work based on the Program, or the modifications to 199 | produce it from the Program, in the form of source code under the 200 | terms of section 4, provided that you also meet all of these conditions: 201 | 202 | a) The work must carry prominent notices stating that you modified 203 | it, and giving a relevant date. 204 | 205 | b) The work must carry prominent notices stating that it is 206 | released under this License and any conditions added under section 207 | 7. This requirement modifies the requirement in section 4 to 208 | "keep intact all notices". 209 | 210 | c) You must license the entire work, as a whole, under this 211 | License to anyone who comes into possession of a copy. 
This 212 | License will therefore apply, along with any applicable section 7 213 | additional terms, to the whole of the work, and all its parts, 214 | regardless of how they are packaged. This License gives no 215 | permission to license the work in any other way, but it does not 216 | invalidate such permission if you have separately received it. 217 | 218 | d) If the work has interactive user interfaces, each must display 219 | Appropriate Legal Notices; however, if the Program has interactive 220 | interfaces that do not display Appropriate Legal Notices, your 221 | work need not make them do so. 222 | 223 | A compilation of a covered work with other separate and independent 224 | works, which are not by their nature extensions of the covered work, 225 | and which are not combined with it such as to form a larger program, 226 | in or on a volume of a storage or distribution medium, is called an 227 | "aggregate" if the compilation and its resulting copyright are not 228 | used to limit the access or legal rights of the compilation's users 229 | beyond what the individual works permit. Inclusion of a covered work 230 | in an aggregate does not cause this License to apply to the other 231 | parts of the aggregate. 232 | 233 | 6. Conveying Non-Source Forms. 234 | 235 | You may convey a covered work in object code form under the terms 236 | of sections 4 and 5, provided that you also convey the 237 | machine-readable Corresponding Source under the terms of this License, 238 | in one of these ways: 239 | 240 | a) Convey the object code in, or embodied in, a physical product 241 | (including a physical distribution medium), accompanied by the 242 | Corresponding Source fixed on a durable physical medium 243 | customarily used for software interchange. 244 | 245 | b) Convey the object code in, or embodied in, a physical product 246 | (including a physical distribution medium), accompanied by a 247 | written offer, valid for at least three years and valid for as 248 | long as you offer spare parts or customer support for that product 249 | model, to give anyone who possesses the object code either (1) a 250 | copy of the Corresponding Source for all the software in the 251 | product that is covered by this License, on a durable physical 252 | medium customarily used for software interchange, for a price no 253 | more than your reasonable cost of physically performing this 254 | conveying of source, or (2) access to copy the 255 | Corresponding Source from a network server at no charge. 256 | 257 | c) Convey individual copies of the object code with a copy of the 258 | written offer to provide the Corresponding Source. This 259 | alternative is allowed only occasionally and noncommercially, and 260 | only if you received the object code with such an offer, in accord 261 | with subsection 6b. 262 | 263 | d) Convey the object code by offering access from a designated 264 | place (gratis or for a charge), and offer equivalent access to the 265 | Corresponding Source in the same way through the same place at no 266 | further charge. You need not require recipients to copy the 267 | Corresponding Source along with the object code. If the place to 268 | copy the object code is a network server, the Corresponding Source 269 | may be on a different server (operated by you or a third party) 270 | that supports equivalent copying facilities, provided you maintain 271 | clear directions next to the object code saying where to find the 272 | Corresponding Source. 
Regardless of what server hosts the 273 | Corresponding Source, you remain obligated to ensure that it is 274 | available for as long as needed to satisfy these requirements. 275 | 276 | e) Convey the object code using peer-to-peer transmission, provided 277 | you inform other peers where the object code and Corresponding 278 | Source of the work are being offered to the general public at no 279 | charge under subsection 6d. 280 | 281 | A separable portion of the object code, whose source code is excluded 282 | from the Corresponding Source as a System Library, need not be 283 | included in conveying the object code work. 284 | 285 | A "User Product" is either (1) a "consumer product", which means any 286 | tangible personal property which is normally used for personal, family, 287 | or household purposes, or (2) anything designed or sold for incorporation 288 | into a dwelling. In determining whether a product is a consumer product, 289 | doubtful cases shall be resolved in favor of coverage. For a particular 290 | product received by a particular user, "normally used" refers to a 291 | typical or common use of that class of product, regardless of the status 292 | of the particular user or of the way in which the particular user 293 | actually uses, or expects or is expected to use, the product. A product 294 | is a consumer product regardless of whether the product has substantial 295 | commercial, industrial or non-consumer uses, unless such uses represent 296 | the only significant mode of use of the product. 297 | 298 | "Installation Information" for a User Product means any methods, 299 | procedures, authorization keys, or other information required to install 300 | and execute modified versions of a covered work in that User Product from 301 | a modified version of its Corresponding Source. The information must 302 | suffice to ensure that the continued functioning of the modified object 303 | code is in no case prevented or interfered with solely because 304 | modification has been made. 305 | 306 | If you convey an object code work under this section in, or with, or 307 | specifically for use in, a User Product, and the conveying occurs as 308 | part of a transaction in which the right of possession and use of the 309 | User Product is transferred to the recipient in perpetuity or for a 310 | fixed term (regardless of how the transaction is characterized), the 311 | Corresponding Source conveyed under this section must be accompanied 312 | by the Installation Information. But this requirement does not apply 313 | if neither you nor any third party retains the ability to install 314 | modified object code on the User Product (for example, the work has 315 | been installed in ROM). 316 | 317 | The requirement to provide Installation Information does not include a 318 | requirement to continue to provide support service, warranty, or updates 319 | for a work that has been modified or installed by the recipient, or for 320 | the User Product in which it has been modified or installed. Access to a 321 | network may be denied when the modification itself materially and 322 | adversely affects the operation of the network or violates the rules and 323 | protocols for communication across the network. 
324 | 325 | Corresponding Source conveyed, and Installation Information provided, 326 | in accord with this section must be in a format that is publicly 327 | documented (and with an implementation available to the public in 328 | source code form), and must require no special password or key for 329 | unpacking, reading or copying. 330 | 331 | 7. Additional Terms. 332 | 333 | "Additional permissions" are terms that supplement the terms of this 334 | License by making exceptions from one or more of its conditions. 335 | Additional permissions that are applicable to the entire Program shall 336 | be treated as though they were included in this License, to the extent 337 | that they are valid under applicable law. If additional permissions 338 | apply only to part of the Program, that part may be used separately 339 | under those permissions, but the entire Program remains governed by 340 | this License without regard to the additional permissions. 341 | 342 | When you convey a copy of a covered work, you may at your option 343 | remove any additional permissions from that copy, or from any part of 344 | it. (Additional permissions may be written to require their own 345 | removal in certain cases when you modify the work.) You may place 346 | additional permissions on material, added by you to a covered work, 347 | for which you have or can give appropriate copyright permission. 348 | 349 | Notwithstanding any other provision of this License, for material you 350 | add to a covered work, you may (if authorized by the copyright holders of 351 | that material) supplement the terms of this License with terms: 352 | 353 | a) Disclaiming warranty or limiting liability differently from the 354 | terms of sections 15 and 16 of this License; or 355 | 356 | b) Requiring preservation of specified reasonable legal notices or 357 | author attributions in that material or in the Appropriate Legal 358 | Notices displayed by works containing it; or 359 | 360 | c) Prohibiting misrepresentation of the origin of that material, or 361 | requiring that modified versions of such material be marked in 362 | reasonable ways as different from the original version; or 363 | 364 | d) Limiting the use for publicity purposes of names of licensors or 365 | authors of the material; or 366 | 367 | e) Declining to grant rights under trademark law for use of some 368 | trade names, trademarks, or service marks; or 369 | 370 | f) Requiring indemnification of licensors and authors of that 371 | material by anyone who conveys the material (or modified versions of 372 | it) with contractual assumptions of liability to the recipient, for 373 | any liability that these contractual assumptions directly impose on 374 | those licensors and authors. 375 | 376 | All other non-permissive additional terms are considered "further 377 | restrictions" within the meaning of section 10. If the Program as you 378 | received it, or any part of it, contains a notice stating that it is 379 | governed by this License along with a term that is a further 380 | restriction, you may remove that term. If a license document contains 381 | a further restriction but permits relicensing or conveying under this 382 | License, you may add to a covered work material governed by the terms 383 | of that license document, provided that the further restriction does 384 | not survive such relicensing or conveying. 
385 | 386 | If you add terms to a covered work in accord with this section, you 387 | must place, in the relevant source files, a statement of the 388 | additional terms that apply to those files, or a notice indicating 389 | where to find the applicable terms. 390 | 391 | Additional terms, permissive or non-permissive, may be stated in the 392 | form of a separately written license, or stated as exceptions; 393 | the above requirements apply either way. 394 | 395 | 8. Termination. 396 | 397 | You may not propagate or modify a covered work except as expressly 398 | provided under this License. Any attempt otherwise to propagate or 399 | modify it is void, and will automatically terminate your rights under 400 | this License (including any patent licenses granted under the third 401 | paragraph of section 11). 402 | 403 | However, if you cease all violation of this License, then your 404 | license from a particular copyright holder is reinstated (a) 405 | provisionally, unless and until the copyright holder explicitly and 406 | finally terminates your license, and (b) permanently, if the copyright 407 | holder fails to notify you of the violation by some reasonable means 408 | prior to 60 days after the cessation. 409 | 410 | Moreover, your license from a particular copyright holder is 411 | reinstated permanently if the copyright holder notifies you of the 412 | violation by some reasonable means, this is the first time you have 413 | received notice of violation of this License (for any work) from that 414 | copyright holder, and you cure the violation prior to 30 days after 415 | your receipt of the notice. 416 | 417 | Termination of your rights under this section does not terminate the 418 | licenses of parties who have received copies or rights from you under 419 | this License. If your rights have been terminated and not permanently 420 | reinstated, you do not qualify to receive new licenses for the same 421 | material under section 10. 422 | 423 | 9. Acceptance Not Required for Having Copies. 424 | 425 | You are not required to accept this License in order to receive or 426 | run a copy of the Program. Ancillary propagation of a covered work 427 | occurring solely as a consequence of using peer-to-peer transmission 428 | to receive a copy likewise does not require acceptance. However, 429 | nothing other than this License grants you permission to propagate or 430 | modify any covered work. These actions infringe copyright if you do 431 | not accept this License. Therefore, by modifying or propagating a 432 | covered work, you indicate your acceptance of this License to do so. 433 | 434 | 10. Automatic Licensing of Downstream Recipients. 435 | 436 | Each time you convey a covered work, the recipient automatically 437 | receives a license from the original licensors, to run, modify and 438 | propagate that work, subject to this License. You are not responsible 439 | for enforcing compliance by third parties with this License. 440 | 441 | An "entity transaction" is a transaction transferring control of an 442 | organization, or substantially all assets of one, or subdividing an 443 | organization, or merging organizations. 
If propagation of a covered 444 | work results from an entity transaction, each party to that 445 | transaction who receives a copy of the work also receives whatever 446 | licenses to the work the party's predecessor in interest had or could 447 | give under the previous paragraph, plus a right to possession of the 448 | Corresponding Source of the work from the predecessor in interest, if 449 | the predecessor has it or can get it with reasonable efforts. 450 | 451 | You may not impose any further restrictions on the exercise of the 452 | rights granted or affirmed under this License. For example, you may 453 | not impose a license fee, royalty, or other charge for exercise of 454 | rights granted under this License, and you may not initiate litigation 455 | (including a cross-claim or counterclaim in a lawsuit) alleging that 456 | any patent claim is infringed by making, using, selling, offering for 457 | sale, or importing the Program or any portion of it. 458 | 459 | 11. Patents. 460 | 461 | A "contributor" is a copyright holder who authorizes use under this 462 | License of the Program or a work on which the Program is based. The 463 | work thus licensed is called the contributor's "contributor version". 464 | 465 | A contributor's "essential patent claims" are all patent claims 466 | owned or controlled by the contributor, whether already acquired or 467 | hereafter acquired, that would be infringed by some manner, permitted 468 | by this License, of making, using, or selling its contributor version, 469 | but do not include claims that would be infringed only as a 470 | consequence of further modification of the contributor version. For 471 | purposes of this definition, "control" includes the right to grant 472 | patent sublicenses in a manner consistent with the requirements of 473 | this License. 474 | 475 | Each contributor grants you a non-exclusive, worldwide, royalty-free 476 | patent license under the contributor's essential patent claims, to 477 | make, use, sell, offer for sale, import and otherwise run, modify and 478 | propagate the contents of its contributor version. 479 | 480 | In the following three paragraphs, a "patent license" is any express 481 | agreement or commitment, however denominated, not to enforce a patent 482 | (such as an express permission to practice a patent or covenant not to 483 | sue for patent infringement). To "grant" such a patent license to a 484 | party means to make such an agreement or commitment not to enforce a 485 | patent against the party. 486 | 487 | If you convey a covered work, knowingly relying on a patent license, 488 | and the Corresponding Source of the work is not available for anyone 489 | to copy, free of charge and under the terms of this License, through a 490 | publicly available network server or other readily accessible means, 491 | then you must either (1) cause the Corresponding Source to be so 492 | available, or (2) arrange to deprive yourself of the benefit of the 493 | patent license for this particular work, or (3) arrange, in a manner 494 | consistent with the requirements of this License, to extend the patent 495 | license to downstream recipients. "Knowingly relying" means you have 496 | actual knowledge that, but for the patent license, your conveying the 497 | covered work in a country, or your recipient's use of the covered work 498 | in a country, would infringe one or more identifiable patents in that 499 | country that you have reason to believe are valid. 
500 | 501 | If, pursuant to or in connection with a single transaction or 502 | arrangement, you convey, or propagate by procuring conveyance of, a 503 | covered work, and grant a patent license to some of the parties 504 | receiving the covered work authorizing them to use, propagate, modify 505 | or convey a specific copy of the covered work, then the patent license 506 | you grant is automatically extended to all recipients of the covered 507 | work and works based on it. 508 | 509 | A patent license is "discriminatory" if it does not include within 510 | the scope of its coverage, prohibits the exercise of, or is 511 | conditioned on the non-exercise of one or more of the rights that are 512 | specifically granted under this License. You may not convey a covered 513 | work if you are a party to an arrangement with a third party that is 514 | in the business of distributing software, under which you make payment 515 | to the third party based on the extent of your activity of conveying 516 | the work, and under which the third party grants, to any of the 517 | parties who would receive the covered work from you, a discriminatory 518 | patent license (a) in connection with copies of the covered work 519 | conveyed by you (or copies made from those copies), or (b) primarily 520 | for and in connection with specific products or compilations that 521 | contain the covered work, unless you entered into that arrangement, 522 | or that patent license was granted, prior to 28 March 2007. 523 | 524 | Nothing in this License shall be construed as excluding or limiting 525 | any implied license or other defenses to infringement that may 526 | otherwise be available to you under applicable patent law. 527 | 528 | 12. No Surrender of Others' Freedom. 529 | 530 | If conditions are imposed on you (whether by court order, agreement or 531 | otherwise) that contradict the conditions of this License, they do not 532 | excuse you from the conditions of this License. If you cannot convey a 533 | covered work so as to satisfy simultaneously your obligations under this 534 | License and any other pertinent obligations, then as a consequence you may 535 | not convey it at all. For example, if you agree to terms that obligate you 536 | to collect a royalty for further conveying from those to whom you convey 537 | the Program, the only way you could satisfy both those terms and this 538 | License would be to refrain entirely from conveying the Program. 539 | 540 | 13. Remote Network Interaction; Use with the GNU General Public License. 541 | 542 | Notwithstanding any other provision of this License, if you modify the 543 | Program, your modified version must prominently offer all users 544 | interacting with it remotely through a computer network (if your version 545 | supports such interaction) an opportunity to receive the Corresponding 546 | Source of your version by providing access to the Corresponding Source 547 | from a network server at no charge, through some standard or customary 548 | means of facilitating copying of software. This Corresponding Source 549 | shall include the Corresponding Source for any work covered by version 3 550 | of the GNU General Public License that is incorporated pursuant to the 551 | following paragraph. 552 | 553 | Notwithstanding any other provision of this License, you have 554 | permission to link or combine any covered work with a work licensed 555 | under version 3 of the GNU General Public License into a single 556 | combined work, and to convey the resulting work. 
The terms of this 557 | License will continue to apply to the part which is the covered work, 558 | but the work with which it is combined will remain governed by version 559 | 3 of the GNU General Public License. 560 | 561 | 14. Revised Versions of this License. 562 | 563 | The Free Software Foundation may publish revised and/or new versions of 564 | the GNU Affero General Public License from time to time. Such new versions 565 | will be similar in spirit to the present version, but may differ in detail to 566 | address new problems or concerns. 567 | 568 | Each version is given a distinguishing version number. If the 569 | Program specifies that a certain numbered version of the GNU Affero General 570 | Public License "or any later version" applies to it, you have the 571 | option of following the terms and conditions either of that numbered 572 | version or of any later version published by the Free Software 573 | Foundation. If the Program does not specify a version number of the 574 | GNU Affero General Public License, you may choose any version ever published 575 | by the Free Software Foundation. 576 | 577 | If the Program specifies that a proxy can decide which future 578 | versions of the GNU Affero General Public License can be used, that proxy's 579 | public statement of acceptance of a version permanently authorizes you 580 | to choose that version for the Program. 581 | 582 | Later license versions may give you additional or different 583 | permissions. However, no additional obligations are imposed on any 584 | author or copyright holder as a result of your choosing to follow a 585 | later version. 586 | 587 | 15. Disclaimer of Warranty. 588 | 589 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 590 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 591 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 592 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 593 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 594 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 595 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 596 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 597 | 598 | 16. Limitation of Liability. 599 | 600 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 601 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 602 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 603 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 604 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 605 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 606 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 607 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 608 | SUCH DAMAGES. 609 | 610 | 17. Interpretation of Sections 15 and 16. 611 | 612 | If the disclaimer of warranty and limitation of liability provided 613 | above cannot be given local legal effect according to their terms, 614 | reviewing courts shall apply local law that most closely approximates 615 | an absolute waiver of all civil liability in connection with the 616 | Program, unless a warranty or assumption of liability accompanies a 617 | copy of the Program in return for a fee. 
618 | 619 | END OF TERMS AND CONDITIONS 620 | 621 | How to Apply These Terms to Your New Programs 622 | 623 | If you develop a new program, and you want it to be of the greatest 624 | possible use to the public, the best way to achieve this is to make it 625 | free software which everyone can redistribute and change under these terms. 626 | 627 | To do so, attach the following notices to the program. It is safest 628 | to attach them to the start of each source file to most effectively 629 | state the exclusion of warranty; and each file should have at least 630 | the "copyright" line and a pointer to where the full notice is found. 631 | 632 | 633 | Copyright (C) 634 | 635 | This program is free software: you can redistribute it and/or modify 636 | it under the terms of the GNU Affero General Public License as published 637 | by the Free Software Foundation, either version 3 of the License, or 638 | (at your option) any later version. 639 | 640 | This program is distributed in the hope that it will be useful, 641 | but WITHOUT ANY WARRANTY; without even the implied warranty of 642 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 643 | GNU Affero General Public License for more details. 644 | 645 | You should have received a copy of the GNU Affero General Public License 646 | along with this program. If not, see . 647 | 648 | Also add information on how to contact you by electronic and paper mail. 649 | 650 | If your software can interact with users remotely through a computer 651 | network, you should also make sure that it provides a way for users to 652 | get its source. For example, if your program is a web application, its 653 | interface could display a "Source" link that leads users to an archive 654 | of the code. There are many ways you could offer source, and different 655 | solutions will be better for different programs; see section 13 for the 656 | specific requirements. 657 | 658 | You should also get your employer (if you work as a programmer) or school, 659 | if any, to sign a "copyright disclaimer" for the program, if necessary. 660 | For more information on this, and how to apply and follow the GNU AGPL, see 661 | . 662 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![download](https://github.com/user-attachments/assets/5842e84e-004f-4afd-9373-af64e9575b78) 2 |

🚀 One-stop solution for creating your digital avatar from chat history 💡

3 |

🚀从聊天记录创造数字分身的一站式解决方案💡

4 | 5 | 6 |
7 | 8 | [![GitHub stars](https://img.shields.io/github/stars/xming521/WeClone?style=for-the-badge&logo=github&label=Stars&logoColor=white&color=ffda65)](https://github.com/xming521/WeClone/stargazers) 9 | [![GitHub release](https://img.shields.io/github/v/release/xming521/WeClone?style=for-the-badge&logo=github&label=Release&logoColor=white&color=06d094)](https://github.com/xming521/WeClone/releases) 10 | 11 | WeClone① 12 | 13 | [![Twitter](https://img.shields.io/badge/Twitter-@weclone567-000000?style=for-the-badge&logo=x&logoColor=white)](https://x.com/weclone567) 14 | [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/+JEdak4m0XEQ3NGNl) 15 | 16 | Featured|HelloGitHub 17 | xming521%2FWeClone | Trendshift 18 | Ask DeepWiki 19 |
20 | 21 |

22 | 项目主页 | 23 | 项目文档 | 24 | Windows部署指南 | 25 | Linux部署指南【保姆级】 26 |

27 | 28 | > [!IMPORTANT] 29 | >

Support for creating digital avatars from WhatsApp and Telegram chat logs is coming soon!

30 | 31 | ## ✨核心功能 32 | - 💫 涵盖打造数字分身的全链路方案,包括聊天数据导出、预处理、模型训练、部署 33 | - 💬 使用微信聊天记录微调LLM,让大模型有"那味儿" 34 | - 🔗 绑定到微信、QQ、Telegram、企微、飞书机器人,实现自己的数字分身 35 | - 🛡️ 隐私信息过滤,本地化微调部署,数据安全可控 36 | 37 | ## 📋特性与说明 38 | 39 | > [!IMPORTANT] 40 | > - WeClone仍在快速迭代期,当前效果不代表最终效果。 41 | > - 微调LLM效果很大程度取决于模型大小、聊天数据的数量和质量,理论上模型越大,数据越多,效果越好。 42 | > - Windows环境未进行严格测试,可以使用WSL作为运行环境。详细教程可点击[Windows部署指南](https://blog.051088.xyz/2025/05/14/WeClone-%E7%94%A8%E5%BE%AE%E4%BF%A1%E8%81%8A%E5%A4%A9%E8%AE%B0%E5%BD%95%E6%89%93%E9%80%A0%E8%87%AA%E5%B7%B1%E7%9A%84AI%E6%95%B0%E5%AD%97%E5%88%86%E8%BA%AB/)查看。 43 | 44 | ### 硬件要求 45 | 46 | 项目默认使用Qwen2.5-7B-Instruct模型,以LoRA方法对SFT阶段进行微调,大约需要16GB显存。也可以使用[LLaMA Factory](https://github.com/hiyouga/LLaMA-Factory/blob/main/README_zh.md#%E6%A8%A1%E5%9E%8B)支持的其他模型和方法。 47 | 48 | 需要显存的估算值: 49 | | 方法 | 精度 | 7B | 14B | 30B | 70B | `x`B | 50 | | ------------------------------- | ---- | ----- | ----- | ----- | ------ | ------- | 51 | | Full (`bf16` or `fp16`) | 32 | 120GB | 240GB | 600GB | 1200GB | `18x`GB | 52 | | Full (`pure_bf16`) | 16 | 60GB | 120GB | 300GB | 600GB | `8x`GB | 53 | | Freeze/LoRA/GaLore/APOLLO/BAdam | 16 | 16GB | 32GB | 64GB | 160GB | `2x`GB | 54 | | QLoRA | 8 | 10GB | 20GB | 40GB | 80GB | `x`GB | 55 | | QLoRA | 4 | 6GB | 12GB | 24GB | 48GB | `x/2`GB | 56 | | QLoRA | 2 | 4GB | 8GB | 16GB | 24GB | `x/4`GB | 57 | 58 | 59 | ## 环境搭建 60 | 1.CUDA安装(已安装可跳过,**要求版本12.4及以上**):[LLaMA Factory](https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/installation.html#cuda) 61 | 62 | 2.建议使用 [uv](https://docs.astral.sh/uv/)安装依赖,这是一个非常快速的 Python 环境管理器。安装uv后,您可以使用以下命令创建一个新的Python环境并安装依赖项,注意这不包含音频克隆功能的依赖: 63 | ```bash 64 | git clone https://github.com/xming521/WeClone.git 65 | cd WeClone 66 | uv venv .venv --python=3.10 67 | source .venv/bin/activate # windows下执行 .venv\Scripts\activate 68 | uv pip install --group main -e . 69 | ``` 70 | > [!TIP] 71 | > 如果要使用最新的模型进行微调,需要手动安装最新版LLaMA Factory:`uv pip install --upgrade git+https://github.com/hiyouga/LLaMA-Factory.git`,同时其他依赖版本也可能需要修改,例如 vllm、pytorch、transformers。 72 | 73 | 3.将配置文件模板复制一份并重命名为`settings.jsonc`,后续配置修改在此文件进行: 74 | ```bash 75 | cp settings.template.jsonc settings.jsonc 76 | ``` 77 | > [!NOTE] 78 | > 训练以及推理相关配置统一在文件`settings.jsonc` 79 | 80 | 4.使用以下命令测试CUDA环境是否正确配置并可被PyTorch识别,Mac不需要: 81 | ```bash 82 | python -c "import torch; print('CUDA是否可用:', torch.cuda.is_available());" 83 | ``` 84 | 85 | 5.(可选)安装FlashAttention,加速训练和推理:`uv pip install flash-attn --no-build-isolation` 86 | 87 | ## 模型下载 88 | ```bash 89 | git lfs install 90 | git clone https://www.modelscope.cn/Qwen/Qwen2.5-7B-Instruct.git 91 | ``` 92 | 下载有问题时,可使用其他方式下载:[模型的下载](https://www.modelscope.cn/docs/models/download) 93 | 94 | 95 | ## 数据准备 96 | 97 | 请使用[PyWxDump](https://github.com/xaoyaoo/PyWxDump)提取微信聊天记录(不支持4.0版本微信)。可以先将手机的聊天记录迁移(备份)到电脑,数据量更多一些。下载软件并解密数据库后,点击聊天备份,导出类型为CSV,可以导出多个联系人(不建议使用群聊记录),然后将导出的位于`wxdump_tmp/export` 的 `csv` 文件夹放在`./dataset`目录即可,也就是不同人聊天记录的文件夹一起放在 `./dataset/csv`。 98 | 99 | ## 数据预处理 100 | 101 | - 项目默认去除了数据中的手机号、身份证号、邮箱、网址。还在`settings.jsonc`中提供了一个禁用词词库`blocked_words`,可以自行添加需要过滤的词句(会默认去掉包括禁用词的整句)。 102 | > [!IMPORTANT] 103 | > 🚨 请一定注意保护个人隐私,不要泄露个人信息!
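下面给出一个仅作示意的 `settings.jsonc` 配置片段,汇总上文提到的 `blocked_words` 以及下文数据处理会用到的几个键名。键名取自本 README,但具体嵌套结构、默认值与单位请以 `settings.template.jsonc` 为准,数值均为假设的示例:

```jsonc
{
  // 示意片段:实际结构与默认值请以 settings.template.jsonc 为准
  "blocked_words": ["某公司内部项目名", "家庭住址"],   // 示例占位词:含禁用词的整句会被去除
  "make_dataset_args": {
    "single_combine_time_window": 2,   // 示例值:单人连续消息合并的时间窗口
    "qa_match_time_window": 5,         // 示例值:问答对匹配的时间窗口
    "clean_dataset": { "enable_clean": false },   // 开启后用 LLM 对数据打分清洗(嵌套位置为假设)
    "online_llm_clear": false          // 改为 true 启用 API 在线推理打分(嵌套位置为假设)
  }
}
```

按自己的数据情况调整后,再执行下面的 `weclone-cli make-dataset` 进行处理。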
104 | 105 | - 执行以下命令对数据进行处理,可以根据自己的聊天风格修改settings.jsonc的`make_dataset_args`。 106 | ```bash 107 | weclone-cli make-dataset 108 | ``` 109 | - 目前仅支持时间窗口策略,根据`single_combine_time_window`将单人连续消息通过逗号连接合并为一句,根据`qa_match_time_window`匹配问答对。 110 | - 可以启用`clean_dataset`中的`enable_clean`选项,对数据进行清洗,以达到更好效果。当前系统支持使用 `llm judge` 对聊天记录进行打分,提供 **vllm 离线推理** 和 **API 在线推理** 两种方式。可通过将 `settings.jsonc` 文件中的 `"online_llm_clear": false` 修改为 `true` 来启用 API 在线推理模式,并配置相应的 `base_url`、`llm_api_key`、`model_name` 等参数。所有兼容 OpenAI 接口的模型均可接入。 111 | - 在获得 LLM 打分的分数分布情况后,可通过设置 `accept_score` 参数筛选可接受的分数区间,同时可适当降低 `train_sft_args` 中的 `lora_dropout` 参数,以提升模型的拟合效果。 112 | 113 | ## 配置参数并微调模型 114 | 115 | - (可选)修改 `settings.jsonc` 的 `model_name_or_path` 和 `template` 选择本地下载好的其他模型。 116 | - 修改`per_device_train_batch_size`以及`gradient_accumulation_steps`来调整显存占用。 117 | - 可以根据自己数据集的数量和质量修改`train_sft_args`的`num_train_epochs`、`lora_rank`、`lora_dropout`等参数。 118 | 119 | ### 单卡训练 120 | ```bash 121 | weclone-cli train-sft 122 | ``` 123 | 多卡环境单卡训练,需要先执行 `export CUDA_VISIBLE_DEVICES=0` 124 | 125 | ### 多卡训练 126 | 取消`settings.jsonc`中`deepspeed`行代码注释,使用以下命令多卡训练: 127 | ```bash 128 | uv pip install deepspeed 129 | deepspeed --num_gpus=使用显卡数量 weclone/train/train_sft.py 130 | ``` 131 | 132 | ### 使用浏览器demo简单推理 133 | 可以在这一步测试出合适的temperature、top_p值,修改settings.jsonc的`infer_args`后,供后续推理时使用。 134 | ```bash 135 | weclone-cli webchat-demo 136 | ``` 137 | 138 | ### 使用接口进行推理 139 | 140 | ```bash 141 | weclone-cli server 142 | ``` 143 | 144 | ### 使用常见聊天问题测试 145 | 测试集不包含询问个人信息的问题,仅有日常聊天。测试结果在test_result-my.txt。 146 | ```bash 147 | weclone-cli server 148 | weclone-cli test-model 149 | ``` 150 | 151 | ## 🖼️ 微调效果 152 | 使用Qwen2.5-14B-Instruct模型,在大概3万条处理后的有效数据上微调,loss降到了3.5左右。 153 |
154 | 截图 155 |
156 | alt text 157 | alt text 158 | alt text 159 | alt text 160 |
161 |
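在接入下一节的聊天机器人平台之前,可以先用任意兼容 OpenAI 接口的客户端直接验证 `weclone-cli server` 启动的服务。下面是一个仅作示意的 curl 请求:端口 8005 取自下文 AstrBot 的示例,模型名沿用 `gpt-3.5-turbo`,API Key 按下文说明可任意填写;实际监听地址与端口请以你的 `settings.jsonc` 配置为准:

```bash
# 示意:向本地 OpenAI 兼容接口发送一条测试消息(地址与端口为假设值)
curl http://127.0.0.1:8005/v1/chat/completions \
  -H "Content-Type: application/json" \
  -H "Authorization: Bearer sk-anykey" \
  -d '{
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "最近在忙什么?"}]
      }'
```

返回内容应与浏览器 demo 中的回复风格一致;如果请求失败,先确认 `weclone-cli server` 是否在运行、端口是否与配置一致。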
162 | 163 | 164 | ## 🤖 部署到聊天机器人 165 | 166 | ### AstrBot 167 | 168 | [AstrBot](https://github.com/AstrBotDevs/AstrBot) 是易上手的多平台 LLM 聊天机器人及开发框架 ✨ 平台支持 QQ、QQ频道、Telegram、微信、企微、飞书。 169 | 170 | 使用步骤: 171 | 1. 部署 AstrBot 172 | 2. 在 AstrBot 中部署消息平台 173 | 3. 执行 `weclone-cli server` 启动api服务 174 | 4. 在 AstrBot 中新增服务提供商,类型选择OpenAI,API Base URL 根据AstrBot部署方式填写(例如docker部署可能为http://172.17.0.1:8005/v1) ,模型填写gpt-3.5-turbo,API Key随意填写一个 175 | 5. 微调后不支持工具调用,请先关掉默认的工具,消息平台发送指令: `/tool off all`,否则会没有微调后的效果。 176 | 6. 根据微调时使用的default_system,在 AstrBot 中设置系统提示词。 177 | ![5](https://github.com/user-attachments/assets/19de7072-076a-4cdf-8ae6-46b9b89f536a) 178 | > [!IMPORTANT] 179 | > 检查api_service的日志,尽量保证大模型服务请求的参数和微调时一致,tool插件能力都关掉。 180 | 7. 调整采样参数,例如temperature、top_p、top_k等 181 | [配置自定义的模型参数](https://astrbot.app/config/model-config.html#%E9%85%8D%E7%BD%AE%E8%87%AA%E5%AE%9A%E4%B9%89%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%8F%82%E6%95%B0) 182 | 183 | ### LangBot 184 | 185 | [LangBot](https://github.com/RockChinQ/LangBot) 是一个开源的接入全球多种即时通信平台的 LLM 机器人平台,适合各种场景使用。 186 | 187 | 1. [部署 LangBot](https://github.com/RockChinQ/LangBot#-%E5%BC%80%E5%A7%8B%E4%BD%BF%E7%94%A8) 188 | 2. 在 LangBot 中添加一个机器人 189 | 4. 在模型页添加新模型,名称`gpt-3.5-turbo`,供应商选择 OpenAI,填写 请求 URL 为 WeClone 的地址,详细连接方式可以参考[文档](https://docs.langbot.app/zh/workshop/network-details.html),API Key 任意填写。 190 | 191 | image 192 | 193 | 6. 在流水线配置中选择刚才添加的模型,或修改提示词配置 194 | 195 | image 196 | 197 | ## 📌 路线图 198 | - [ ] 更丰富的上下文:包括上下文对话、聊天对象信息、时间等 + 思考 199 | - [ ] Memory 支持 200 | - [ ] 支持多模态 201 | - [ ] 数据增强 202 | - [ ] 支持GUI 203 | 204 | ## 问题解决 205 | - 微调问题:[LLaMA-Factory| FAQs | 常见问题](https://github.com/hiyouga/LLaMA-Factory/issues/4614) 或者更方便的 [![更方便的Ask DeepWiki](https://deepwiki.com/badge.svg)](https://deepwiki.com/hiyouga/LLaMA-Factory) 206 | 207 | ## ❤️ 贡献代码 208 | 209 | 欢迎任何 Issues/Pull Requests! 210 | 211 | 你可以通过查看Issues或帮助审核 PR(拉取请求)来贡献。对于新功能的添加,请先通过 Issue 讨论。 212 | 运行`uv pip install --group dev -e .`安装开发依赖。 213 | 项目使用`pytest`测试(测试脚本待完善),`pyright`检查类型,`ruff`检查代码格式。 214 | 215 | 216 | ## ⚠️ 免责声明 217 | > [!CAUTION] 218 | > 请勿用于非法用途,否则后果自负。 219 |
220 | 1. 使用目的 221 | 222 | * 本项目仅供学习交流使用,**请勿用于非法用途**,**请勿用于非法用途**,**请勿用于非法用途**,否则后果自负。 223 | * 用户理解并同意,任何违反法律法规、侵犯他人合法权益的行为,均与本项目及其开发者无关,后果由用户自行承担。 224 | 225 | 2. 使用期限 226 | 227 | * 您应该在下载保存使用本项目的24小时内,删除本项目的源代码和程序;超出此期限的任何使用行为,一概与本项目及其开发者无关。 228 | 229 | 3. 操作规范 230 | 231 | * 本项目仅允许在授权情况下使用数据训练,严禁用于非法目的,否则自行承担所有相关责任;用户如因违反此规定而引发的任何法律责任,将由用户自行承担,与本项目及其开发者无关。 232 | * 严禁用于窃取他人隐私,严禁用于窃取他人隐私,严禁用于窃取他人隐私,否则自行承担所有相关责任。 233 | 234 | 4. 免责声明接受 235 | 236 | * 下载、保存、进一步浏览源代码或者下载安装、编译使用本程序,表示你同意本警告,并承诺遵守它; 237 | 238 | 5. 禁止用于非法测试或渗透 239 | 240 | * 禁止利用本项目的相关技术从事非法测试或渗透,禁止利用本项目的相关代码或相关技术从事任何非法工作,如因此产生的一切不良后果与本项目及其开发者无关。 241 | * 任何因此产生的不良后果,包括但不限于数据泄露、系统瘫痪、侵犯隐私等,均与本项目及其开发者无关,责任由用户自行承担。 242 | 243 | 6. 免责声明修改 244 | 245 | * 本免责声明可能根据项目运行情况和法律法规的变化进行修改和调整。用户应定期查阅本页面以获取最新版本的免责声明,使用本项目时应遵守最新版本的免责声明。 246 | 247 | 7. 其他 248 | 249 | * 除本免责声明规定外,用户在使用本项目过程中应遵守相关的法律法规和道德规范。对于因用户违反相关规定而引发的任何纠纷或损失,本项目及其开发者不承担任何责任。 250 | 251 | * 请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。 252 | 253 |
254 | 请用户慎重阅读并理解本免责声明的所有内容,确保在使用本项目时严格遵守相关规定。 255 | 256 |
257 |
258 |
259 | 260 | ## ⭐ Star History 261 | > [!TIP] 262 | > 如果本项目对您有帮助,或者您关注本项目的未来发展,请给项目 Star,谢谢 263 | 264 |
265 | 266 | [![Star History Chart](https://api.star-history.com/svg?repos=xming521/WeClone&type=Date)](https://www.star-history.com/#xming521/WeClone&Date) 267 | 268 |
269 | 270 | 271 |
克隆我们,保留灵魂的芬芳
272 | -------------------------------------------------------------------------------- /dataset/res_csv/pt/dataset_info.json: -------------------------------------------------------------------------------- 1 | {"wechat-pt":{ 2 | "file_name": "./pt-my.json", 3 | "columns": { 4 | "prompt": "c" 5 | } 6 | }} -------------------------------------------------------------------------------- /dataset/res_csv/sft/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "wechat-sft": { 3 | "file_name": "./sft-my.json", 4 | "columns": { 5 | "prompt": "instruction", 6 | "response": "output", 7 | "system": "system" 8 | } 9 | }, 10 | "wechat-sft-with-history": { 11 | "file_name": "./sft-my.json", 12 | "columns": { 13 | "prompt": "instruction", 14 | "response": "output", 15 | "system": "system", 16 | "history": "history" 17 | } 18 | } 19 | } -------------------------------------------------------------------------------- /dataset/test_data-privacy.json: -------------------------------------------------------------------------------- 1 | { 2 | "questions": [ 3 | [ 4 | "你多大了?" 5 | ], 6 | [ 7 | "你有什么爱好吗?" 8 | ], 9 | [ 10 | "你的理想是什么?", 11 | "你觉得你离你的理想还有多远?" 12 | ], 13 | [ 14 | "你最近在忙什么?", 15 | "工作/学习顺利吗?", 16 | "有什么有趣的事情发生吗?" 17 | ], 18 | [ 19 | "你喜欢看什么类型的电影?", 20 | "最近看过什么好看的电影吗?", 21 | "你最喜欢的电影是什么?" 22 | ], 23 | [ 24 | "你平时喜欢听什么音乐?", 25 | "有推荐的歌手或乐队吗?", 26 | "最近有喜欢的歌曲吗?" 27 | ], 28 | [ 29 | "你喜欢旅游吗?", 30 | "去过哪些地方?", 31 | "最喜欢的旅游地是哪里?" 32 | ], 33 | [ 34 | "你喜欢读书吗?", 35 | "最近在读什么书?", 36 | "最喜欢的书是哪本?" 37 | ], 38 | [ 39 | "你平时喜欢运动吗?", 40 | "喜欢做哪些运动?", 41 | "有固定去锻炼吗?" 42 | ], 43 | [ 44 | "周末一般都做些什么?", 45 | "有没有什么特别的计划?", 46 | "周末喜欢宅在家还是出去玩?" 47 | ], 48 | [ 49 | "你喜欢宠物吗?", 50 | "有养宠物吗?", 51 | "最喜欢什么动物?" 52 | ], 53 | [ 54 | "你喜欢吃什么类型的食物?", 55 | "有推荐的餐厅吗?", 56 | "最喜欢的菜是什么?" 57 | ], 58 | [ 59 | "你喜欢什么样的天气?", 60 | "最喜欢的季节是哪一个?", 61 | "你觉得今天的天气怎么样?" 62 | ], 63 | [ 64 | "你有看电视剧的习惯吗?", 65 | "最近在追哪部剧?", 66 | "最喜欢的电视剧是哪部?" 67 | ], 68 | [ 69 | "你喜欢玩游戏吗?", 70 | "最近在玩什么游戏?", 71 | "有推荐的好玩的游戏吗?" 72 | ], 73 | [ 74 | "你会做饭吗?", 75 | "平时喜欢做哪些菜?", 76 | "有没有特别拿手的菜?" 77 | ], 78 | [ 79 | "你喜欢购物吗?", 80 | "最近买了什么新东西?", 81 | "有推荐的购物网站或店铺吗?" 82 | ], 83 | [ 84 | "你平时怎么放松自己?", 85 | "有特别的解压方式吗?", 86 | "最喜欢的放松活动是什么?" 87 | ], 88 | [ 89 | "你喜欢和朋友出去玩吗?", 90 | "平时会和朋友去哪玩?", 91 | "最近有没有和朋友聚会的计划?" 92 | ], 93 | [ 94 | "你喜欢喝咖啡还是茶?", 95 | "有没有特别喜欢的咖啡馆或茶馆?", 96 | "最喜欢的饮品是什么?" 97 | ], 98 | [ 99 | "你有兄弟姐妹吗?", 100 | "和他们关系怎么样?", 101 | "经常联系吗?" 102 | ], 103 | [ 104 | "你喜欢读什么类型的杂志?", 105 | "最近有看什么有趣的文章吗?", 106 | "有订阅的杂志吗?" 107 | ], 108 | [ 109 | "你喜欢看体育比赛吗?", 110 | "最喜欢的运动项目是什么?", 111 | "有没有特别支持的球队或运动员?" 112 | ], 113 | [ 114 | "你会说其他语言吗?", 115 | "最想学的语言是什么?", 116 | "学习语言有什么技巧吗?" 117 | ], 118 | [ 119 | "你对科技产品感兴趣吗?", 120 | "最近有没有关注什么新科技?", 121 | "最喜欢的电子产品是什么?" 122 | ], 123 | [ 124 | "你喜欢喝什么样的饮料?", 125 | "有没有自己调饮料的习惯?", 126 | "最喜欢的饮品品牌是什么?" 127 | ], 128 | [ 129 | "你平时用社交媒体吗?", 130 | "常用哪些平台?", 131 | "在社交媒体上做什么?" 132 | ], 133 | [ 134 | "你对艺术感兴趣吗?", 135 | "最喜欢的艺术家是谁?", 136 | "有去过哪些艺术展览?" 137 | ], 138 | [ 139 | "你喜欢DIY吗?", 140 | "平时做些什么手工?", 141 | "有没有完成的作品可以分享?" 142 | ], 143 | [ 144 | "你喜欢种植植物吗?", 145 | "有养什么植物?", 146 | "最喜欢的植物是什么?" 147 | ], 148 | [ 149 | "你喜欢拍照吗?", 150 | "喜欢拍什么样的照片?", 151 | "有没有用什么特别的摄影设备?" 152 | ], 153 | [ 154 | "你喜欢听播客吗?", 155 | "常听哪些主题的播客?", 156 | "有没有推荐的播客?" 157 | ], 158 | [ 159 | "你对历史感兴趣吗?", 160 | "最喜欢哪个历史时期?", 161 | "有没有特别喜欢的历史人物?" 162 | ], 163 | [ 164 | "你喜欢画画吗?", 165 | "平时画什么类型的画?", 166 | "有参加过画展吗?" 167 | ], 168 | [ 169 | "你喜欢写作吗?", 170 | "平时写什么类型的文章?", 171 | "有没有发表过作品?" 
172 | ], 173 | [ 174 | "你喜欢钓鱼吗?", 175 | "平时去哪里钓鱼?", 176 | "有没有钓到过什么大鱼?" 177 | ], 178 | [ 179 | "你喜欢露营吗?", 180 | "平时会去哪里露营?", 181 | "有没有什么难忘的露营经历?" 182 | ], 183 | [ 184 | "你喜欢摄影吗?", 185 | "最喜欢拍什么题材?", 186 | "有没有特别喜欢的摄影师?" 187 | ], 188 | [ 189 | "你喜欢喝酒吗?", 190 | "喜欢什么类型的酒?", 191 | "有没有推荐的酒吧或品牌?" 192 | ], 193 | [ 194 | "你喜欢滑雪吗?", 195 | "平时去哪里滑雪?", 196 | "有没有什么滑雪技巧分享?" 197 | ], 198 | [ 199 | "你喜欢海边还是山里?", 200 | "最喜欢去哪个地方度假?", 201 | "有没有什么特别推荐的景点?" 202 | ], 203 | [ 204 | "你喜欢参加音乐节吗?", 205 | "参加过哪些音乐节?", 206 | "最喜欢的音乐节是哪一个?" 207 | ], 208 | [ 209 | "你喜欢跑步吗?", 210 | "平时跑多长距离?", 211 | "有没有参加过马拉松?" 212 | ], 213 | [ 214 | "你喜欢参加聚会吗?", 215 | "平时和朋友聚会做什么?", 216 | "有没有什么有趣的聚会游戏?" 217 | ], 218 | [ 219 | "你喜欢收集东西吗?", 220 | "收集什么类型的物品?", 221 | "有没有什么特别的收藏?" 222 | ] 223 | ] 224 | } -------------------------------------------------------------------------------- /dataset/test_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "questions": [ 3 | [ 4 | "吃了吗?", 5 | "吃的什么啊", 6 | "好吃吗", 7 | "多少钱啊", 8 | "可以请我吃吗" 9 | ], 10 | [ 11 | "干嘛呢?", 12 | "等会准备干什么去" 13 | ], 14 | [ 15 | "在忙什么呢?", 16 | "今天有什么特别的安排吗?", 17 | "感觉怎么样?" 18 | ], 19 | [ 20 | "最近有什么新鲜事发生吗?", 21 | "有没有什么有趣的故事可以分享?" 22 | ], 23 | [ 24 | "周末过得怎么样?", 25 | "做了什么好玩的?" 26 | ], 27 | [ 28 | "最近看了什么好看的电影或电视剧吗?", 29 | "有什么推荐的吗?", 30 | "大概讲了什么内容呀?" 31 | ], 32 | [ 33 | "今天天气怎么样?", 34 | "你那里呢?" 35 | ], 36 | [ 37 | "最近工作/学习顺利吗?", 38 | "有没有遇到什么挑战?" 39 | ], 40 | [ 41 | "嗨,这会儿在忙啥呢?", 42 | "今天有什么特别的安排不?", 43 | "一切都还顺利吧?" 44 | ], 45 | [ 46 | "你那边现在天气咋样啊?", 47 | "是大晴天还是有点阴沉沉的?", 48 | "冷不冷,或者热不热呀?" 49 | ], 50 | [ 51 | "到饭点儿了没呀?", 52 | "今天打算犒劳一下自己,吃点啥好吃的?", 53 | "有没有啥特别想吃的,或者想去哪家馆子尝尝鲜?" 54 | ], 55 | [ 56 | "最近网上有啥好玩儿的新闻或者梗吗?", 57 | "刷到啥有意思的视频或者段子没?分享一下呗!" 58 | ], 59 | [ 60 | "待会儿有啥打算呀?", 61 | "今天剩下的时间准备怎么过呢?" 62 | ], 63 | [ 64 | "今天有没有碰到啥让你眼前一亮的小事儿?", 65 | "随便聊聊呗,有啥轻松点的话题不?" 66 | ], 67 | [ 68 | "今天有啥新发现或者小感悟没?", 69 | "感觉今天过得快不快?节奏怎么样?" 70 | ], 71 | [ 72 | "你现在周围环境咋样,吵不吵?", 73 | "今天出门溜达了没,外面人多不多呀?", 74 | "瞅瞅窗外,有啥特别的景儿不?" 75 | ], 76 | [ 77 | "吃饭了没啊?", 78 | "吃的啥呀?合胃口不?" 79 | ], 80 | [ 81 | "今天怎么样啊?累不累?", 82 | "有啥事儿不?" 83 | ], 84 | [ 85 | "最近身体还好吧?", 86 | "没什么不舒服的地方吧?" 87 | ], 88 | [ 89 | "今天忙不忙啊?", 90 | "都干啥了呀?" 91 | ], 92 | [ 93 | "家里都挺好的吧?", 94 | "有啥需要帮忙的不?" 95 | ], 96 | [ 97 | "今天出门了没?", 98 | "外面冷不冷/热不热啊?多穿点/注意防暑。" 99 | ], 100 | [ 101 | "最近有啥开心的事儿不?说来听听!", 102 | "或者有啥烦心事儿,跟我说说?" 103 | ], 104 | [ 105 | "晚上早点休息啊,别熬太晚。", 106 | "睡得好不好啊最近?" 107 | ], 108 | [ 109 | "缺啥东西不?跟我说。", 110 | "钱够不够花呀?" 111 | ], 112 | [ 113 | "今天看到啥有意思的了没?", 114 | "或者有啥想跟我分享的?" 115 | ], 116 | [ 117 | "周末有啥安排啊?", 118 | "要不要一起吃个饭/出去转转?" 119 | ], 120 | [ 121 | "最近常联系的那些朋友都还好不?", 122 | "有空多聚聚。" 123 | ], 124 | [ 125 | "工作/学习上还顺利吧?", 126 | "别太给自己压力啊。" 127 | ], 128 | [ 129 | "今天做了啥好吃的呀?", 130 | "下次也给我尝尝呗!" 
131 | ], 132 | [ 133 | "有啥新闻没有啊最近?", 134 | "跟我讲讲。" 135 | ], 136 | [ 137 | "那谁谁谁最近怎么样了?", 138 | "好久没听到他/她消息了。" 139 | ], 140 | [ 141 | "今天心情好不好呀?", 142 | "看你气色不错/有点疲惫。" 143 | ], 144 | [ 145 | "有啥想吃的没?下次给你做/带。", 146 | "或者想去哪儿玩,我陪你。" 147 | ], 148 | [ 149 | "最近有没有看啥电视剧/电影啊?", 150 | "有啥好看的推荐给我呗。" 151 | ], 152 | [ 153 | "没事儿就早点回家/休息。", 154 | "注意安全啊。" 155 | ] 156 | ] 157 | } -------------------------------------------------------------------------------- /ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "allgather_partitions": true, 16 | "allgather_bucket_size": 5e8, 17 | "overlap_comm": true, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | }, 22 | "gradient_accumulation_steps": "auto", 23 | "gradient_clipping": "auto", 24 | "steps_per_print": 2000, 25 | "train_batch_size": "auto", 26 | "train_micro_batch_size_per_gpu": "auto", 27 | "wall_clock_breakdown": false 28 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "WeClone" 3 | version = "0.2.21" 4 | description = "从聊天记录创造数字分身的一站式解决方案" 5 | authors = [{ name = "xming521" }] 6 | readme = "README.md" 7 | requires-python = ">=3.10,<3.11" 8 | 9 | dependencies = [ 10 | "pandas", 11 | "commentjson", 12 | "click", 13 | "pydantic==2.10.6", 14 | "setuptools>=78.1.0", 15 | "loguru>=0.7.3", 16 | "torch>=2.6.0", 17 | "transformers==4.49.0", 18 | "tomli; python_version < '3.11'", 19 | "langchain", 20 | ] 21 | 22 | [tool.weclone] 23 | # 配置文件的版本号,当配置文件结构或重要默认值发生变化时,应增加此版本号 24 | config_version = "0.2.21" 25 | 26 | # 配置文件更新日志 27 | config_changelog = """ 28 | [0.2.1] - 2025-04-29 - 初始配置版本。 29 | [0.2.2] - 2025-05-01 - 增加llm清洗数据配置,blocked_words迁移到settings.jsonc统一配置文件。 30 | [0.2.21] - 2025-05-01 - 增加在线llm清洗数据配置,兼容openai风格接口。 31 | """ 32 | 33 | [dependency-groups] 34 | # xcodec = ["xcodec2==0.1.3"] 35 | sparktts = [ 36 | "einops>=0.8.1", 37 | "einx>=0.3.0", 38 | "numpy==1.26.4", 39 | "omegaconf>=2.3.0", 40 | "packaging>=24.2", 41 | "safetensors>=0.5.2", 42 | "soundfile>=0.12.1", 43 | "soxr>=0.5.0.post1", 44 | "torchaudio>=2.6.0", 45 | "tqdm>=4.66.5", 46 | ] 47 | main = [ 48 | "llamafactory>=0.9.2", 49 | "openai==1.76.0", 50 | "vllm==0.8.2; platform_system == 'Linux'", 51 | ] 52 | dev = ["pytest", "pytest-order", "pyright", "ruff"] 53 | 54 | [project.scripts] 55 | weclone-cli = "weclone.cli:cli" 56 | 57 | [tool.uv] 58 | conflicts = [ 59 | # [{ group = "wx" }, { group = "xcodec" }], 60 | ] 61 | 62 | [tool.uv.sources] 63 | torch = [ 64 | { index = "pytorch-cu124", marker = "platform_system == 'Windows'" }, 65 | { index = "pytorch-cu124", marker = "platform_system == 'Linux'" }, 66 | ] 67 | torchaudio = [ 68 | { index = "pytorch-cu124", marker = "platform_system == 'Windows'" }, 69 | { index = "pytorch-cu124", marker = "platform_system == 'Linux'" }, 70 | ] 71 | torchvision = [ 72 | { index = "pytorch-cu124", marker = "platform_system == 'Windows'" }, 73 | { index = "pytorch-cu124", marker = "platform_system == 'Linux'" }, 74 | ] 75 | 76 | 77 | [[tool.uv.index]] 78 | url = "https://pypi.tuna.tsinghua.edu.cn/simple/" 79 | default = true 80 | 81 | 
[[tool.uv.index]] 82 | name = "pytorch-cu124" 83 | url = "https://download.pytorch.org/whl/cu124" 84 | explicit = true 85 | 86 | [tool.setuptools.packages.find] 87 | where = ["."] # 表示在项目根目录开始查找 88 | include = ["weclone*"] # 只包含名为 weclone 的目录及其子包 89 | exclude = ["*tests*", "*archive*"] # 可以选择性排除其他模式,比如测试目录 90 | 91 | 92 | [tool.pyright] 93 | typeCheckingMode = "basic" 94 | include = ["weclone/data"] 95 | exclude = ["**/archive", "**/tests"] 96 | ignore = ["**/archive"] 97 | 98 | reportMissingImports = "error" 99 | reportMissingTypeStubs = false 100 | 101 | pythonVersion = "3.10" 102 | pythonPlatform = "Linux" 103 | 104 | [tool.ruff] 105 | exclude = [ 106 | "**/archive", 107 | "**/tests", 108 | "weclone-audio/src/server未完工", 109 | "weclone-audio/src/Spark-TTS", 110 | ] 111 | line-length = 120 112 | 113 | lint.ignore = ["F403", "F405", "E501", "E402"] 114 | lint.select = [ 115 | "F", # Pyflakes 116 | "W", # pycodestyle warnings 117 | "E", # pycodestyle errors 118 | "ASYNC", # flake8-async 119 | "C4", # flake8-comprehensions 120 | "Q", # flake8-quotes 121 | ] 122 | target-version = "py310" 123 | 124 | [tool.pytest.ini_options] 125 | addopts = "-x" 126 | -------------------------------------------------------------------------------- /settings.template.jsonc: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.21", 3 | "common_args": { 4 | "model_name_or_path": "./Qwen2.5-7B-Instruct", 5 | "adapter_name_or_path": "./model_output", //同时做为train_sft_args的output_dir 6 | "template": "qwen", 7 | "default_system": "请你扮演一名人类,不要说自己是人工智能", 8 | "finetuning_type": "lora", 9 | "trust_remote_code": true 10 | }, 11 | "cli_args": { 12 | "full_log": false 13 | }, 14 | "make_dataset_args": { 15 | //数据处理配置 16 | "include_type": [ 17 | "text", 18 | // "image" 19 | ], 20 | "blocked_words": [ // 禁用词 21 | "例如 姓名", 22 | "例如 密码", 23 | "//....." 
24 | ], 25 | "single_combine_strategy": "time_window", // 单人组成单句策略 26 | "qa_match_strategy": "time_window", // 组成qa策略 27 | "single_combine_time_window": 2, // 单人组成单句时间窗口(分钟), 28 | "qa_match_time_window": 5, // 组成qa时间窗口(分钟), 29 | "combine_msg_max_length": 256, // 组合后消息最大长度 配合cutoff_len 使用 30 | "prompt_with_history": false, // 是否在prompt中包含历史对话 31 | "clean_dataset": { 32 | "enable_clean": false, 33 | "clean_strategy": "llm", 34 | "llm": { 35 | "accept_score": 2, //可以接受的llm打分阈值,1分最差,5分最好,低于此分数的数据不会用于训练 36 | } 37 | }, 38 | "online_llm_clear": false, 39 | "base_url": "https://xxx/v1", 40 | "llm_api_key": "xxxxx", 41 | "model_name": "xxx", //建议使用参数较大的模型,例如DeepSeek-V3 42 | "clean_batch_size": 10 43 | }, 44 | "train_pt_args": { 45 | //预训练微调配置 46 | "stage": "pt", 47 | "dataset": "wechat-pt", 48 | "dataset_dir": "./dataset/res_csv/pt", 49 | "lora_target": "q_proj,v_proj", 50 | "lora_rank": 2, 51 | "lora_dropout": 0.1, 52 | "output_dir": "model_output", 53 | "overwrite_cache": true, 54 | "per_device_train_batch_size": 1, 55 | "gradient_accumulation_steps": 1, 56 | "lr_scheduler_type": "cosine", 57 | "logging_steps": 10, 58 | "save_steps": 1000, 59 | "learning_rate": 0.001, 60 | "num_train_epochs": 30, 61 | "plot_loss": true, 62 | "fp16": true 63 | }, 64 | "train_sft_args": { 65 | //微调配置 66 | "stage": "sft", 67 | "dataset": "wechat-sft", 68 | "dataset_dir": "./dataset/res_csv/sft", 69 | "use_fast_tokenizer": true, 70 | "lora_target": "q_proj,v_proj", 71 | "lora_rank": 4, 72 | "lora_dropout": 0.3, 73 | "weight_decay": 0.1, 74 | "overwrite_cache": true, 75 | "per_device_train_batch_size": 8, 76 | "gradient_accumulation_steps": 4, 77 | "lr_scheduler_type": "cosine", 78 | "cutoff_len": 256, 79 | "logging_steps": 10, 80 | "save_steps": 100, 81 | "learning_rate": 1e-4, 82 | "warmup_ratio": 0.1, 83 | "num_train_epochs": 2, 84 | "plot_loss": true, 85 | "fp16": true, 86 | "flash_attn": "fa2", 87 | // "deepspeed": "ds_config.json" //多卡训练 88 | }, 89 | "infer_args": { 90 | "repetition_penalty": 1.2, 91 | "temperature": 0.5, 92 | "max_length": 50, 93 | "top_p": 0.65 94 | } 95 | } 96 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/tests/__init__.py -------------------------------------------------------------------------------- /tests/full_pipe.jsonc: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.2", 3 | "common_args": { 4 | "model_name_or_path": "./Qwen2.5-3B-Instruct", 5 | "adapter_name_or_path": "./model_output", //同时做为train_sft_args的output_dir 6 | "template": "qwen", 7 | "default_system": "请你扮演一名人类,不要说自己是人工智能", 8 | "finetuning_type": "lora", 9 | "trust_remote_code": true 10 | }, 11 | "cli_args": { 12 | "full_log": false 13 | }, 14 | "make_dataset_args": { 15 | //数据处理配置 16 | "include_type": [ 17 | "文本" 18 | ], 19 | "blocked_words": [ // 禁用词 20 | "例如 姓名", 21 | "例如 密码", 22 | "//....." 
23 | ], 24 | "single_combine_strategy": "time_window", // 单人组成单句策略 25 | "qa_match_strategy": "time_window", // 组成qa策略 26 | "single_combine_time_window": 2, // 单人组成单句时间窗口(分钟), 27 | "qa_match_time_window": 5, // 组成qa时间窗口(分钟), 28 | "combine_msg_max_length": 256, // 组合后消息最大长度 配合cutoff_len 使用 29 | "prompt_with_history": false, // 是否在prompt中包含历史对话 30 | "clean_dataset": { 31 | "enable_clean": true, 32 | "clean_strategy": "llm", 33 | "llm": { 34 | "accept_score": 2, //可以接受的llm打分阈值,1分最差,5分最好,低于此分数的数据不会用于训练 35 | } 36 | } 37 | }, 38 | "train_pt_args": { 39 | //预训练微调配置 40 | "stage": "pt", 41 | "dataset": "wechat-pt", 42 | "dataset_dir": "./dataset/res_csv/pt", 43 | "lora_target": "q_proj,v_proj", 44 | "lora_rank": 2, 45 | "lora_dropout": 0.1, 46 | "output_dir": "model_output", 47 | "overwrite_cache": true, 48 | "per_device_train_batch_size": 1, 49 | "gradient_accumulation_steps": 1, 50 | "lr_scheduler_type": "cosine", 51 | "logging_steps": 10, 52 | "save_steps": 1000, 53 | "learning_rate": 0.001, 54 | "num_train_epochs": 30, 55 | "plot_loss": true, 56 | "fp16": true 57 | }, 58 | "train_sft_args": { 59 | //微调配置 60 | "stage": "sft", 61 | "dataset": "wechat-sft", 62 | "dataset_dir": "./dataset/res_csv/sft", 63 | "use_fast_tokenizer": true, 64 | "lora_target": "q_proj,v_proj", 65 | "lora_rank": 4, 66 | "lora_dropout": 0.3, 67 | "weight_decay": 0.1, 68 | "overwrite_cache": true, 69 | "per_device_train_batch_size": 8, 70 | "gradient_accumulation_steps": 4, 71 | "lr_scheduler_type": "cosine", 72 | "cutoff_len": 256, 73 | "logging_steps": 5, 74 | "save_steps": 10, 75 | "learning_rate": 1e-4, 76 | "warmup_ratio": 0.1, 77 | "num_train_epochs": 1, 78 | "plot_loss": true, 79 | "fp16": true, 80 | "flash_attn": "fa2", 81 | // "deepspeed": "ds_config.json" //多卡训练 82 | }, 83 | "infer_args": { 84 | "repetition_penalty": 1.2, 85 | "temperature": 0.5, 86 | "max_length": 50, 87 | "top_p": 0.65 88 | } 89 | } -------------------------------------------------------------------------------- /tests/test_full_pipe.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from unittest import mock 3 | import sys 4 | import os 5 | import shutil 6 | import functools 7 | import subprocess 8 | import time 9 | from typing import Union, Optional, cast 10 | from weclone.utils.log import logger 11 | 12 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 13 | PROJECT_ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 14 | server_process: Optional[subprocess.Popen] = None 15 | 16 | test_logger = logger.bind() 17 | test_logger.remove() 18 | test_logger.add( 19 | sys.stderr, 20 | format="{message}", 21 | colorize=True, 22 | level="INFO", 23 | ) 24 | 25 | def print_test_header(test_name: str): 26 | line_length = 100 27 | test_logger.info("\n" + "─" * line_length) 28 | title = f" Testing Phase: {test_name} " 29 | padding_total = line_length - len(title) 30 | padding_left = padding_total // 2 31 | padding_right = padding_total - padding_left 32 | test_logger.info(" " * padding_left + title + " " * padding_right) 33 | test_logger.info("─" * line_length) 34 | 35 | def setup_make_dataset_test_data(): 36 | PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) 37 | DATASET_CSV_DIR = os.path.join(PROJECT_ROOT, "dataset", "csv") 38 | 39 | TESTS_DIR = os.path.dirname(__file__) 40 | TEST_DATA_PERSON_DIR = os.path.join(TESTS_DIR, "tests_data", "test_person") 41 | 42 | os.makedirs(DATASET_CSV_DIR, exist_ok=True) 43 | 44 | if 
os.path.exists(DATASET_CSV_DIR) and os.listdir(DATASET_CSV_DIR): 45 | if all(f.startswith('.') or f.lower() == 'readme.md' for f in os.listdir(DATASET_CSV_DIR)): 46 | for item_name in os.listdir(TEST_DATA_PERSON_DIR): 47 | source_item_path = os.path.join(TEST_DATA_PERSON_DIR, item_name) 48 | if os.path.isfile(source_item_path) and item_name.lower().endswith('.csv'): 49 | destination_item_path = os.path.join(DATASET_CSV_DIR, item_name) 50 | shutil.copy2(source_item_path, destination_item_path) 51 | 52 | 53 | def run_cli_command(command: list[str], timeout: int | None = None, background: bool = False) -> Union[subprocess.CompletedProcess, subprocess.Popen]: 54 | """Execute a CLI command and return the result. 55 | 56 | Args: 57 | command: List of commands to execute. 58 | timeout: Timeout in seconds. 59 | background: Whether to run in the background. 60 | 61 | Returns: 62 | If background=True, returns a Popen object; otherwise, returns a CompletedProcess object. 63 | """ 64 | env = os.environ.copy() 65 | env["WECLONE_CONFIG_PATH"] = "tests/full_pipe.jsonc" # Set environment variable 66 | 67 | if background: 68 | process = subprocess.Popen( 69 | [sys.executable, "-m", "weclone.cli"] + command, 70 | stderr=subprocess.PIPE, 71 | stdout=subprocess.PIPE, 72 | text=True, 73 | cwd=PROJECT_ROOT_DIR, 74 | env=env 75 | ) 76 | time.sleep(2) 77 | return process 78 | else: 79 | process = subprocess.run( 80 | [sys.executable, "-m", "weclone.cli"] + command, 81 | stderr=None, 82 | stdout=None, 83 | text=True, 84 | cwd=PROJECT_ROOT_DIR, # Execute in the project root directory 85 | timeout=timeout, 86 | env=env # Pass the modified environment variables 87 | ) 88 | return process 89 | 90 | @pytest.mark.order(1) 91 | def test_cli_make_dataset(): 92 | """Test the make-dataset command.""" 93 | print_test_header("make-dataset") 94 | setup_make_dataset_test_data() 95 | result = run_cli_command(["make-dataset"]) 96 | assert result.returncode == 0, "make-dataset command execution failed" 97 | 98 | @pytest.mark.order(2) 99 | def test_cli_train_sft(): 100 | """Test the train-sft command.""" 101 | print_test_header("train-sft") 102 | try: 103 | result = run_cli_command(["train-sft"]) 104 | assert result.returncode == 0, "train-sft command failed or did not fail fast as expected" 105 | except subprocess.TimeoutExpired: 106 | test_logger.info("train-sft command terminated due to timeout, which is acceptable in testing, indicating the command has started execution.") 107 | pass 108 | except Exception as e: 109 | pytest.fail(f"An unexpected error occurred during train-sft command execution: {e}") 110 | 111 | @pytest.mark.order(3) 112 | def test_cli_webchat_demo(): 113 | """Test the webchat-demo command.""" 114 | print_test_header("webchat-demo") 115 | 116 | with mock.patch("weclone.eval.web_demo.main") as mock_main: 117 | mock_main.return_value = None 118 | try: 119 | result = run_cli_command(["webchat-demo"], timeout=5) 120 | assert result.returncode == 0, "webchat-demo command execution failed" 121 | except subprocess.TimeoutExpired: 122 | pass 123 | 124 | @pytest.mark.order(4) 125 | def test_cli_server(): 126 | """Test the server command. 127 | 128 | Start the server in the background, without blocking subsequent tests. 
129 | """ 130 | print_test_header("server (background)") 131 | global server_process 132 | server_process = cast(subprocess.Popen, run_cli_command(["server"], background=True)) 133 | assert server_process.poll() is None, "Server startup failed" 134 | test_logger.info("服务器已在后台启动") 135 | 136 | @pytest.mark.order(5) 137 | def test_cli_test_model(): 138 | """Test the test-model command. 139 | 140 | Use the server for testing, and shut down the server after the test is complete. 141 | """ 142 | print_test_header("test-model") 143 | try: 144 | result = run_cli_command(["test-model"]) 145 | assert result.returncode == 0, "test-model command execution failed" 146 | finally: 147 | global server_process 148 | if server_process is not None and server_process.poll() is None: 149 | test_logger.info("测试完成,正在关闭服务器...") 150 | server_process.terminate() 151 | server_process.wait(timeout=5) 152 | if server_process.poll() is None: 153 | server_process.kill() # Force kill if the process hasn't terminated 154 | test_logger.info("服务器已关闭") 155 | -------------------------------------------------------------------------------- /weclone-audio/README.md: -------------------------------------------------------------------------------- 1 | # WeClone-audio 模块 2 | 3 | WeClone-audio 是一个使用微信语音消息克隆声音的模块,使用模型实现高质量语音合成。 4 | ### 显存需求 5 | **Spark-TTS** 推荐 6 | - **0.5B 模型**: 约 4GB 显存 7 | 8 | **Llasa** (已弃用) 9 | - **3B 模型**: 约 16GB 显存 10 | - **1B 模型**: 约 9GB 显存 11 | 12 | 13 | 14 | 15 | ## 1. 导出微信语音数据 16 | 17 | ### 1.1 准备工作 18 | - 使用 [PyWxDump](https://github.com/xaoyaoo/PyWxDump) 提取微信聊天记录 19 | - 下载软件并解密数据库 20 | - 点击聊天备份,导出类型选择"解密文件" 21 | 22 | ### 1.2 环境配置 23 | 语音导出仅支持Windows环境 24 | WeClone Audio使用uv作为包管理器。 25 | ```bash 26 | # 为 PyWxDump 创建 Python 环境和安装依赖 27 | # 28 | uv venv .venv-wx --python=3.10 29 | .venv-wx\Scripts\activate 30 | uv pip install pywxdump 31 | ``` 32 | 33 | ### 1.3 导出语音文件 34 | ```bash 35 | python weclone-audio/src/get_sample_audio.py --db-path "导出数据库路径" --MsgSvrID "导出聊天记录的MsgSvrID字段" 36 | ``` 37 | 38 | ## 2. 语音合成推理 39 | ### Spark-TTS模型 40 | 41 | **环境安装** 42 | 可不创建新环境,直接安装`sparktts`依赖组到WeClone共主环境 43 | 44 | ```bash 45 | uv venv .venv-sparktts --python=3.10 46 | source .venv-sparktts/bin/activate 47 | uv pip install --group sparktts -e . 
48 | 49 | git clone https://github.com/SparkAudio/Spark-TTS.git weclone-audio/src/Spark-TTS 50 | ``` 51 | 52 | 53 | **模型下载** 54 | 55 | 通过python下载: 56 | ```python 57 | from huggingface_hub import snapshot_download 58 | 59 | # 假设此 Python 代码在 weclone-audio 目录下运行 模型将下载到 weclone-audio/pretrained_models/Spark-TTS-0.5B 60 | snapshot_download("SparkAudio/Spark-TTS-0.5B", local_dir="pretrained_models/Spark-TTS-0.5B") 61 | ``` 62 | 63 | 或通过git下载: 64 | ```bash 65 | # 假设当前在 weclone-audio 目录 66 | mkdir -p pretrained_models 67 | 68 | # Make sure you have git-lfs installed (https://git-lfs.com) 69 | git lfs install 70 | git clone https://huggingface.co/SparkAudio/Spark-TTS-0.5B pretrained_models/Spark-TTS-0.5B 71 | ``` 72 | 使用代码推理 73 | ```python 74 | import os 75 | import SparkTTS 76 | import soundfile as sf 77 | import torch 78 | 79 | from SparkTTS import SparkTTS 80 | 81 | # 假设此 Python 代码在 weclone-audio 目录下运行 82 | # 模型路径相对于当前目录 83 | model_path = "pretrained_models/Spark-TTS-0.5B" 84 | sample_audio = "sample.wav" 85 | output_audio = "output.wav" 86 | 87 | model = SparkTTS(model_path, "cuda") 88 | 89 | with torch.no_grad(): 90 | wav = model.inference( 91 | text="晚上好啊,小可爱们,该睡觉了哦", 92 | prompt_speech_path=sample_audio, # 使用相对路径 93 | prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。", 94 | ) 95 | sf.write(output_audio, wav, samplerate=16000) # 使用相对路径 96 | ``` 97 | ### Llasa模型 (已弃用) 98 | ### 2.1 环境配置 99 | ```bash 100 | # 创建并配置推理环境 101 | ## 可不创建新环境,与LLaMA-Factory环境共用 102 | uv venv .venv-xcodec --python=3.9 103 | source .venv-xcodec/bin/activate 104 | uv pip install --group xcodec -e . 105 | # 退出环境 106 | deactivate 107 | 108 | # 系统依赖安装(如果需要) 109 | sudo apt install python3-dev 110 | sudo apt install build-essential 111 | ``` 112 | 113 | ### 2.2 使用代码推理 114 | 如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。 115 | ```python 116 | import os 117 | import soundfile as sf 118 | # 假设 text_to_speech.py 位于 src/ 或其他可导入的位置 119 | from text_to_speech import TextToSpeech 120 | 121 | 122 | sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本 123 | # 假设此 Python 代码在 weclone-audio 目录下运行 124 | # 示例音频路径相对于当前目录 125 | sample_audio_path = "sample.wav" 126 | output_audio = "output.wav" 127 | 128 | 129 | tts = TextToSpeech(sample_audio_path, sample_audio_text) 130 | target_text = "晚上好啊" # 生成目标文本 131 | result = tts.infer(target_text) 132 | sf.write(output_audio, result[1], result[0]) # 使用相对路径 133 | ``` 134 | 135 | -------------------------------------------------------------------------------- /weclone-audio/src/Llasa/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import soundfile as sf 3 | from text_to_speech import TextToSpeech 4 | 5 | 6 | sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" # 示例音频文本 7 | sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") # 示例音频路径 8 | tts = TextToSpeech(sample_audio_path, sample_audio_text) 9 | target_text = "晚上好啊" # 生成目标文本 10 | result = tts.infer(target_text) 11 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) # 保存生成音频 12 | 13 | -------------------------------------------------------------------------------- /weclone-audio/src/Llasa/text_to_speech.py: -------------------------------------------------------------------------------- 1 | import os 2 | from transformers import AutoTokenizer, AutoModelForCausalLM 3 | import torch 4 | import soundfile as sf 5 | from xcodec2.modeling_xcodec2 import XCodec2Model 6 | import torchaudio 7 | 8 | 9 | class TextToSpeech: 10 
| def __init__(self, sample_audio_path, sample_audio_text): 11 | self.sample_audio_text = sample_audio_text 12 | # 初始化模型 13 | llasa_3b = "HKUSTAudio/Llasa-3B" 14 | xcodec2 = "HKUSTAudio/xcodec2" 15 | 16 | self.tokenizer = AutoTokenizer.from_pretrained(llasa_3b) 17 | self.llasa_3b_model = AutoModelForCausalLM.from_pretrained( 18 | llasa_3b, 19 | trust_remote_code=True, 20 | device_map="auto", 21 | ) 22 | self.llasa_3b_model.eval() 23 | 24 | self.xcodec_model = XCodec2Model.from_pretrained(xcodec2) 25 | self.xcodec_model.eval().cuda() 26 | 27 | # 处理音频 28 | waveform, sample_rate = torchaudio.load(sample_audio_path) 29 | if len(waveform[0]) / sample_rate > 15: 30 | print("已将音频裁剪至前15秒。") 31 | waveform = waveform[:, : sample_rate * 15] 32 | 33 | # 检查音频是否为立体声 34 | if waveform.size(0) > 1: 35 | waveform_mono = torch.mean(waveform, dim=0, keepdim=True) 36 | else: 37 | waveform_mono = waveform 38 | 39 | self.prompt_wav = torchaudio.transforms.Resample( 40 | orig_freq=sample_rate, new_freq=16000 41 | )(waveform_mono) 42 | 43 | # Encode the prompt wav 44 | vq_code_prompt = self.xcodec_model.encode_code(input_waveform=self.prompt_wav) 45 | vq_code_prompt = vq_code_prompt[0, 0, :] 46 | self.speech_ids_prefix = self.ids_to_speech_tokens(vq_code_prompt) 47 | self.speech_end_id = self.tokenizer.convert_tokens_to_ids("<|SPEECH_GENERATION_END|>") 48 | 49 | def ids_to_speech_tokens(self, speech_ids): 50 | speech_tokens_str = [] 51 | for speech_id in speech_ids: 52 | speech_tokens_str.append(f"<|s_{speech_id}|>") 53 | return speech_tokens_str 54 | 55 | def extract_speech_ids(self, speech_tokens_str): 56 | speech_ids = [] 57 | for token_str in speech_tokens_str: 58 | if token_str.startswith("<|s_") and token_str.endswith("|>"): 59 | num_str = token_str[4:-2] 60 | num = int(num_str) 61 | speech_ids.append(num) 62 | else: 63 | print(f"Unexpected token: {token_str}") 64 | return speech_ids 65 | 66 | @torch.inference_mode() 67 | def infer(self, target_text): 68 | if len(target_text) == 0: 69 | return None 70 | elif len(target_text) > 300: 71 | print("文本过长,请保持在300字符以内。") 72 | target_text = target_text[:300] 73 | 74 | input_text = self.sample_audio_text + " " + target_text 75 | 76 | formatted_text = ( 77 | f"<|TEXT_UNDERSTANDING_START|>{input_text}<|TEXT_UNDERSTANDING_END|>" 78 | ) 79 | 80 | chat = [ 81 | { 82 | "role": "user", 83 | "content": "Convert the text to speech:" + formatted_text, 84 | }, 85 | { 86 | "role": "assistant", 87 | "content": "<|SPEECH_GENERATION_START|>" 88 | + "".join(self.speech_ids_prefix), 89 | }, 90 | ] 91 | 92 | input_ids = self.tokenizer.apply_chat_template( 93 | chat, tokenize=True, return_tensors="pt", continue_final_message=True 94 | ) 95 | input_ids = input_ids.to("cuda") 96 | 97 | outputs = self.llasa_3b_model.generate( 98 | input_ids, 99 | max_length=2048, 100 | eos_token_id=self.speech_end_id, 101 | do_sample=True, 102 | top_p=1, 103 | temperature=0.8, 104 | ) 105 | generated_ids = outputs[0][input_ids.shape[1] - len(self.speech_ids_prefix): -1] 106 | 107 | speech_tokens = self.tokenizer.batch_decode( 108 | generated_ids, skip_special_tokens=True 109 | ) 110 | 111 | speech_tokens = self.extract_speech_ids(speech_tokens) 112 | speech_tokens = torch.tensor(speech_tokens).cuda().unsqueeze(0).unsqueeze(0) 113 | 114 | gen_wav = self.xcodec_model.decode_code(speech_tokens) 115 | gen_wav = gen_wav[:, :, self.prompt_wav.shape[1]:] 116 | 117 | return (16000, gen_wav[0, 0, :].cpu().numpy()) 118 | 119 | 120 | if __name__ == "__main__": 121 | # 如果遇到问题,请尝试将参考音频转换为WAV或MP3格式,将其裁剪至15秒以内,并缩短提示文本。 
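    # Note: tts.infer() below returns a (sample_rate, waveform) tuple (16000 Hz plus a NumPy array),
    # which is why the results are saved with sf.write(path, result[1], result[0]).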
122 | sample_audio_text = "对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。" 123 | sample_audio_path = os.path.join(os.path.dirname(__file__), "sample.wav") 124 | 125 | tts = TextToSpeech(sample_audio_path, sample_audio_text) 126 | target_text = "晚上好啊,吃了吗您" 127 | result = tts.infer(target_text) 128 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), result[1], result[0]) 129 | target_text = "我是老北京正黄旗!" 130 | result = tts.infer(target_text) 131 | sf.write(os.path.join(os.path.dirname(__file__), "output1.wav"), result[1], result[0]) 132 | -------------------------------------------------------------------------------- /weclone-audio/src/SparkTTS.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | from typing import Tuple 4 | from pathlib import Path 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | import os 7 | import sys 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "./Spark-TTS"))) 9 | from sparktts.utils.file import load_config 10 | from sparktts.models.audio_tokenizer import BiCodecTokenizer 11 | from sparktts.utils.token_parser import LEVELS_MAP, GENDER_MAP, TASK_TOKEN_MAP 12 | 13 | 14 | class SparkTTS: 15 | """ 16 | Spark-TTS for text-to-speech generation. 17 | """ 18 | 19 | def __init__(self, model_dir: Path, device: torch.device = torch.device("cuda:0")): 20 | """ 21 | Initializes the SparkTTS model with the provided configurations and device. 22 | 23 | Args: 24 | model_dir (Path): Directory containing the model and config files. 25 | device (torch.device): The device (CPU/GPU) to run the model on. 26 | """ 27 | self.device = device 28 | self.model_dir = model_dir 29 | self.configs = load_config(f"{model_dir}/config.yaml") 30 | self.sample_rate = self.configs["sample_rate"] 31 | self._initialize_inference() 32 | 33 | def _initialize_inference(self): 34 | """Initializes the tokenizer, model, and audio tokenizer for inference.""" 35 | self.tokenizer = AutoTokenizer.from_pretrained(f"{self.model_dir}/LLM") 36 | self.model = AutoModelForCausalLM.from_pretrained(f"{self.model_dir}/LLM") 37 | self.audio_tokenizer = BiCodecTokenizer(self.model_dir, device=self.device) 38 | self.model.to(self.device) 39 | 40 | def process_prompt( 41 | self, 42 | text: str, 43 | prompt_speech_path: Path, 44 | prompt_text: str = None, 45 | ) -> Tuple[str, torch.Tensor]: 46 | """ 47 | Process input for voice cloning. 48 | 49 | Args: 50 | text (str): The text input to be converted to speech. 51 | prompt_speech_path (Path): Path to the audio file used as a prompt. 52 | prompt_text (str, optional): Transcript of the prompt audio. 
53 | 54 | Return: 55 | Tuple[str, torch.Tensor]: Input prompt; global tokens 56 | """ 57 | 58 | global_token_ids, semantic_token_ids = self.audio_tokenizer.tokenize( 59 | prompt_speech_path 60 | ) 61 | global_tokens = "".join( 62 | [f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze()] 63 | ) 64 | 65 | # Prepare the input tokens for the model 66 | if prompt_text is not None: 67 | semantic_tokens = "".join( 68 | [f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze()] 69 | ) 70 | inputs = [ 71 | TASK_TOKEN_MAP["tts"], 72 | "<|start_content|>", 73 | prompt_text, 74 | text, 75 | "<|end_content|>", 76 | "<|start_global_token|>", 77 | global_tokens, 78 | "<|end_global_token|>", 79 | "<|start_semantic_token|>", 80 | semantic_tokens, 81 | ] 82 | else: 83 | inputs = [ 84 | TASK_TOKEN_MAP["tts"], 85 | "<|start_content|>", 86 | text, 87 | "<|end_content|>", 88 | "<|start_global_token|>", 89 | global_tokens, 90 | "<|end_global_token|>", 91 | ] 92 | 93 | inputs = "".join(inputs) 94 | 95 | return inputs, global_token_ids 96 | 97 | def process_prompt_control( 98 | self, 99 | gender: str, 100 | pitch: str, 101 | speed: str, 102 | text: str, 103 | ): 104 | """ 105 | Process input for voice creation. 106 | 107 | Args: 108 | gender (str): female | male. 109 | pitch (str): very_low | low | moderate | high | very_high 110 | speed (str): very_low | low | moderate | high | very_high 111 | text (str): The text input to be converted to speech. 112 | 113 | Return: 114 | str: Input prompt 115 | """ 116 | assert gender in GENDER_MAP.keys() 117 | assert pitch in LEVELS_MAP.keys() 118 | assert speed in LEVELS_MAP.keys() 119 | 120 | gender_id = GENDER_MAP[gender] 121 | pitch_level_id = LEVELS_MAP[pitch] 122 | speed_level_id = LEVELS_MAP[speed] 123 | 124 | pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>" 125 | speed_label_tokens = f"<|speed_label_{speed_level_id}|>" 126 | gender_tokens = f"<|gender_{gender_id}|>" 127 | 128 | attribte_tokens = "".join( 129 | [gender_tokens, pitch_label_tokens, speed_label_tokens] 130 | ) 131 | 132 | control_tts_inputs = [ 133 | TASK_TOKEN_MAP["controllable_tts"], 134 | "<|start_content|>", 135 | text, 136 | "<|end_content|>", 137 | "<|start_style_label|>", 138 | attribte_tokens, 139 | "<|end_style_label|>", 140 | ] 141 | 142 | return "".join(control_tts_inputs) 143 | 144 | @torch.no_grad() 145 | def inference( 146 | self, 147 | text: str, 148 | prompt_speech_path: Path = None, 149 | prompt_text: str = None, 150 | gender: str = None, 151 | pitch: str = None, 152 | speed: str = None, 153 | temperature: float = 0.8, 154 | top_k: float = 50, 155 | top_p: float = 0.95, 156 | ) -> torch.Tensor: 157 | """ 158 | Performs inference to generate speech from text, incorporating prompt audio and/or text. 159 | 160 | Args: 161 | text (str): The text input to be converted to speech. 162 | prompt_speech_path (Path): Path to the audio file used as a prompt. 163 | prompt_text (str, optional): Transcript of the prompt audio. 164 | gender (str): female | male. 165 | pitch (str): very_low | low | moderate | high | very_high 166 | speed (str): very_low | low | moderate | high | very_high 167 | temperature (float, optional): Sampling temperature for controlling randomness. Default is 0.8. 168 | top_k (float, optional): Top-k sampling parameter. Default is 50. 169 | top_p (float, optional): Top-p (nucleus) sampling parameter. Default is 0.95. 170 | 171 | Returns: 172 | torch.Tensor: Generated waveform as a tensor. 
173 | """ 174 | if gender is not None: 175 | prompt = self.process_prompt_control(gender, pitch, speed, text) 176 | 177 | else: 178 | prompt, global_token_ids = self.process_prompt( 179 | text, prompt_speech_path, prompt_text 180 | ) 181 | model_inputs = self.tokenizer([prompt], return_tensors="pt").to(self.device) 182 | 183 | # Generate speech using the model 184 | generated_ids = self.model.generate( 185 | **model_inputs, 186 | max_new_tokens=3000, 187 | do_sample=True, 188 | top_k=top_k, 189 | top_p=top_p, 190 | temperature=temperature, 191 | ) 192 | 193 | # Trim the output tokens to remove the input tokens 194 | generated_ids = [ 195 | output_ids[len(input_ids):] 196 | for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) 197 | ] 198 | 199 | # Decode the generated tokens into text 200 | predicts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] 201 | 202 | # Extract semantic token IDs from the generated text 203 | pred_semantic_ids = ( 204 | torch.tensor([int(token) for token in re.findall(r"bicodec_semantic_(\d+)", predicts)]) 205 | .long() 206 | .unsqueeze(0) 207 | ) 208 | 209 | if gender is not None: 210 | global_token_ids = ( 211 | torch.tensor([int(token) for token in re.findall(r"bicodec_global_(\d+)", predicts)]) 212 | .long() 213 | .unsqueeze(0) 214 | .unsqueeze(0) 215 | ) 216 | 217 | # Convert semantic tokens back to waveform 218 | wav = self.audio_tokenizer.detokenize( 219 | global_token_ids.to(self.device).squeeze(0), 220 | pred_semantic_ids.to(self.device), 221 | ) 222 | 223 | return wav 224 | -------------------------------------------------------------------------------- /weclone-audio/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone-audio/src/__init__.py -------------------------------------------------------------------------------- /weclone-audio/src/get_sample_audio.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from pywxdump.db import MediaHandler 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser(description="Extract audio from WeChat database") 7 | parser.add_argument("--db-path", type=str, required=True, 8 | help="Path to WeChat database file") 9 | parser.add_argument("--MsgSvrID", type=str, required=True, 10 | help="Message server ID of the audio") 11 | parser.add_argument("--save-path", type=str, 12 | default=os.path.join(os.path.dirname(__file__), "sample.wav"), 13 | help="Path to save the audio file (default: sample.wav in script directory)") 14 | parser.add_argument("--rate", type=int, default=24000, 15 | help="Sample rate for audio conversion (default: 24000)") 16 | 17 | args = parser.parse_args() 18 | 19 | config = { 20 | "key": "test1", 21 | "type": "sqlite", 22 | "path": args.db_path, 23 | } 24 | 25 | t1 = MediaHandler(config) 26 | t1.get_audio( 27 | MsgSvrID=args.MsgSvrID, 28 | is_play=True, 29 | is_wave=True, 30 | save_path=args.save_path, 31 | rate=args.rate, 32 | ) 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /weclone-audio/src/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import soundfile as sf 3 | import torch 4 | 5 | from SparkTTS import SparkTTS 6 | 7 | model = SparkTTS("weclone-audio/pretrained_models/Spark-TTS-0.5B", "cuda") 8 
| 9 | 10 | with torch.no_grad(): 11 | wav = model.inference( 12 | text="晚上好啊,小可爱们,该睡觉了哦", 13 | prompt_speech_path=os.path.join(os.path.dirname(__file__), "sample.wav"), 14 | prompt_text="对,这就是我万人敬仰的太乙真人,虽然有点婴儿肥,但也掩不住我逼人的帅气。", 15 | ) 16 | sf.write(os.path.join(os.path.dirname(__file__), "output.wav"), wav, samplerate=16000) 17 | print("生成成功!") 18 | -------------------------------------------------------------------------------- /weclone-audio/src/sample.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone-audio/src/sample.wav -------------------------------------------------------------------------------- /weclone-audio/src/server未完工/.env.example: -------------------------------------------------------------------------------- 1 | API_KEY=your_api_key_here 2 | PORT=5050 3 | 4 | DEFAULT_VOICE=en-US-AvaNeural 5 | DEFAULT_RESPONSE_FORMAT=mp3 6 | DEFAULT_SPEED=1.0 7 | 8 | DEFAULT_LANGUAGE=en-US 9 | 10 | REQUIRE_API_KEY=True 11 | 12 | REMOVE_FILTER=False 13 | 14 | EXPAND_API=True -------------------------------------------------------------------------------- /weclone-audio/src/server未完工/handle_text.py: -------------------------------------------------------------------------------- 1 | import re 2 | import emoji 3 | 4 | def prepare_tts_input_with_context(text: str) -> str: 5 | """ 6 | Prepares text for a TTS API by cleaning Markdown and adding minimal contextual hints 7 | for certain Markdown elements like headers. Preserves paragraph separation. 8 | 9 | Args: 10 | text (str): The raw text containing Markdown or other formatting. 11 | 12 | Returns: 13 | str: Cleaned text with contextual hints suitable for TTS input. 14 | """ 15 | 16 | # Remove emojis 17 | text = emoji.replace_emoji(text, replace='') 18 | 19 | # Add context for headers 20 | def header_replacer(match): 21 | level = len(match.group(1)) # Number of '#' symbols 22 | header_text = match.group(2).strip() 23 | if level == 1: 24 | return f"Title — {header_text}\n" 25 | elif level == 2: 26 | return f"Section — {header_text}\n" 27 | else: 28 | return f"Subsection — {header_text}\n" 29 | 30 | text = re.sub(r"^(#{1,6})\s+(.*)", header_replacer, text, flags=re.MULTILINE) 31 | 32 | # Announce links (currently commented out for potential future use) 33 | # text = re.sub(r"\[([^\]]+)\]\((https?:\/\/[^\)]+)\)", r"\1 (link: \2)", text) 34 | 35 | # Remove links while keeping the link text 36 | text = re.sub(r"\[([^\]]+)\]\([^\)]+\)", r"\1", text) 37 | 38 | # Describe inline code 39 | text = re.sub(r"`([^`]+)`", r"code snippet: \1", text) 40 | 41 | # Remove bold/italic symbols but keep the content 42 | text = re.sub(r"(\*\*|__|\*|_)", '', text) 43 | 44 | # Remove code blocks (multi-line) with a description 45 | text = re.sub(r"```([\s\S]+?)```", r"(code block omitted)", text) 46 | 47 | # Remove image syntax but add alt text if available 48 | text = re.sub(r"!\[([^\]]*)\]\([^\)]+\)", r"Image: \1", text) 49 | 50 | # Remove HTML tags 51 | text = re.sub(r"]+(>|$)", '', text) 52 | 53 | # Normalize line breaks 54 | text = re.sub(r"\n{2,}", '\n\n', text) # Ensure consistent paragraph separation 55 | 56 | # Replace multiple spaces within lines 57 | text = re.sub(r" {2,}", ' ', text) 58 | 59 | # Trim leading and trailing whitespace from the whole text 60 | text = text.strip() 61 | 62 | return text 63 | -------------------------------------------------------------------------------- 
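A minimal usage sketch of `prepare_tts_input_with_context` from `handle_text.py` above; the sample Markdown string and the shown output are illustrative rather than taken from the project:

```python
# Illustrative sketch: flattening Markdown before handing text to the TTS backend,
# the same way server.py calls this helper. The input string here is made up.
from handle_text import prepare_tts_input_with_context

raw = "## Quick start\nSee the **docs** at [WeClone](https://example.com) and run `weclone-cli server`."
print(prepare_tts_input_with_context(raw))
# Roughly expected output: the header gains a "Section" hint, the link keeps only its
# text, and the inline code is announced as a "code snippet":
#   Section — Quick start
#
#   See the docs at WeClone and run code snippet: weclone-cli server.
```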
/weclone-audio/src/server未完工/requirements.txt: -------------------------------------------------------------------------------- 1 | flask 2 | gevent 3 | python-dotenv 4 | edge-tts 5 | emoji -------------------------------------------------------------------------------- /weclone-audio/src/server未完工/server.py: -------------------------------------------------------------------------------- 1 | # server.py 2 | 3 | from flask import Flask, request, send_file, jsonify 4 | from gevent.pywsgi import WSGIServer 5 | from dotenv import load_dotenv 6 | import os 7 | 8 | from handle_text import prepare_tts_input_with_context 9 | from tts_handler import generate_speech, get_models, get_voices 10 | from utils import getenv_bool, require_api_key, AUDIO_FORMAT_MIME_TYPES 11 | 12 | app = Flask(__name__) 13 | load_dotenv() 14 | 15 | API_KEY = os.getenv('API_KEY', 'your_api_key_here') 16 | PORT = int(os.getenv('PORT', 5050)) 17 | 18 | DEFAULT_VOICE = os.getenv('DEFAULT_VOICE', 'en-US-AvaNeural') 19 | DEFAULT_RESPONSE_FORMAT = os.getenv('DEFAULT_RESPONSE_FORMAT', 'mp3') 20 | DEFAULT_SPEED = float(os.getenv('DEFAULT_SPEED', 1.0)) 21 | 22 | REMOVE_FILTER = getenv_bool('REMOVE_FILTER', False) 23 | EXPAND_API = getenv_bool('EXPAND_API', True) 24 | 25 | # DEFAULT_MODEL = os.getenv('DEFAULT_MODEL', 'tts-1') 26 | 27 | @app.route('/v1/audio/speech', methods=['POST']) 28 | @app.route('/audio/speech', methods=['POST']) # Add this line for the alias 29 | @require_api_key 30 | def text_to_speech(): 31 | data = request.json 32 | if not data or 'input' not in data: 33 | return jsonify({"error": "Missing 'input' in request body"}), 400 34 | 35 | text = data.get('input') 36 | 37 | if not REMOVE_FILTER: 38 | text = prepare_tts_input_with_context(text) 39 | 40 | # model = data.get('model', DEFAULT_MODEL) 41 | voice = data.get('voice', DEFAULT_VOICE) 42 | 43 | response_format = data.get('response_format', DEFAULT_RESPONSE_FORMAT) 44 | speed = float(data.get('speed', DEFAULT_SPEED)) 45 | 46 | mime_type = AUDIO_FORMAT_MIME_TYPES.get(response_format, "audio/mpeg") 47 | 48 | # Generate the audio file in the specified format with speed adjustment 49 | output_file_path = generate_speech(text, voice, response_format, speed) 50 | 51 | # Return the file with the correct MIME type 52 | return send_file(output_file_path, mimetype=mime_type, as_attachment=True, download_name=f"speech.{response_format}") 53 | 54 | @app.route('/v1/models', methods=['GET', 'POST']) 55 | @app.route('/models', methods=['GET', 'POST']) 56 | @require_api_key 57 | def list_models(): 58 | return jsonify({"data": get_models()}) 59 | 60 | @app.route('/v1/voices', methods=['GET', 'POST']) 61 | @app.route('/voices', methods=['GET', 'POST']) 62 | @require_api_key 63 | def list_voices(): 64 | specific_language = None 65 | 66 | data = request.args if request.method == 'GET' else request.json 67 | if data and ('language' in data or 'locale' in data): 68 | specific_language = data.get('language') if 'language' in data else data.get('locale') 69 | 70 | return jsonify({"voices": get_voices(specific_language)}) 71 | 72 | @app.route('/v1/voices/all', methods=['GET', 'POST']) 73 | @app.route('/voices/all', methods=['GET', 'POST']) 74 | @require_api_key 75 | def list_all_voices(): 76 | return jsonify({"voices": get_voices('all')}) 77 | 78 | """ 79 | Support for ElevenLabs and Azure AI Speech 80 | (currently in beta) 81 | """ 82 | 83 | # http://localhost:5050/elevenlabs/v1/text-to-speech 84 | # http://localhost:5050/elevenlabs/v1/text-to-speech/en-US-AndrewNeural 85 | 
@app.route('/elevenlabs/v1/text-to-speech/', methods=['POST']) 86 | @require_api_key 87 | def elevenlabs_tts(voice_id): 88 | if not EXPAND_API: 89 | return jsonify({"error": f"Endpoint not allowed"}), 500 90 | 91 | # Parse the incoming JSON payload 92 | try: 93 | payload = request.json 94 | if not payload or 'text' not in payload: 95 | return jsonify({"error": "Missing 'text' in request body"}), 400 96 | except Exception as e: 97 | return jsonify({"error": f"Invalid JSON payload: {str(e)}"}), 400 98 | 99 | text = payload['text'] 100 | 101 | if not REMOVE_FILTER: 102 | text = prepare_tts_input_with_context(text) 103 | 104 | voice = voice_id # ElevenLabs uses the voice_id in the URL 105 | 106 | # Use default settings for edge-tts 107 | response_format = 'mp3' 108 | speed = DEFAULT_SPEED # Optional customization via payload.get('speed', DEFAULT_SPEED) 109 | 110 | # Generate speech using edge-tts 111 | try: 112 | output_file_path = generate_speech(text, voice, response_format, speed) 113 | except Exception as e: 114 | return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500 115 | 116 | # Return the generated audio file 117 | return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3") 118 | 119 | # tts.speech.microsoft.com/cognitiveservices/v1 120 | # https://{region}.tts.speech.microsoft.com/cognitiveservices/v1 121 | # http://localhost:5050/azure/cognitiveservices/v1 122 | @app.route('/azure/cognitiveservices/v1', methods=['POST']) 123 | @require_api_key 124 | def azure_tts(): 125 | if not EXPAND_API: 126 | return jsonify({"error": f"Endpoint not allowed"}), 500 127 | 128 | # Parse the SSML payload 129 | try: 130 | ssml_data = request.data.decode('utf-8') 131 | if not ssml_data: 132 | return jsonify({"error": "Missing SSML payload"}), 400 133 | 134 | # Extract the text and voice from SSML 135 | from xml.etree import ElementTree as ET 136 | root = ET.fromstring(ssml_data) 137 | text = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').text 138 | voice = root.find('.//{http://www.w3.org/2001/10/synthesis}voice').get('name') 139 | except Exception as e: 140 | return jsonify({"error": f"Invalid SSML payload: {str(e)}"}), 400 141 | 142 | # Use default settings for edge-tts 143 | response_format = 'mp3' 144 | speed = DEFAULT_SPEED 145 | 146 | if not REMOVE_FILTER: 147 | text = prepare_tts_input_with_context(text) 148 | 149 | # Generate speech using edge-tts 150 | try: 151 | output_file_path = generate_speech(text, voice, response_format, speed) 152 | except Exception as e: 153 | return jsonify({"error": f"TTS generation failed: {str(e)}"}), 500 154 | 155 | # Return the generated audio file 156 | return send_file(output_file_path, mimetype="audio/mpeg", as_attachment=True, download_name="speech.mp3") 157 | 158 | print(f" Edge TTS (Free Azure TTS) Replacement for OpenAI's TTS API") 159 | print(f" ") 160 | print(f" * Serving OpenAI Edge TTS") 161 | print(f" * Server running on http://localhost:{PORT}") 162 | print(f" * TTS Endpoint: http://localhost:{PORT}/v1/audio/speech") 163 | print(f" ") 164 | 165 | if __name__ == '__main__': 166 | http_server = WSGIServer(('0.0.0.0', PORT), app) 167 | http_server.serve_forever() 168 | -------------------------------------------------------------------------------- /weclone-audio/src/server未完工/tts_handler.py: -------------------------------------------------------------------------------- 1 | import edge_tts 2 | import asyncio 3 | import tempfile 4 | import subprocess 5 | import os 6 | from pathlib 
import Path 7 | 8 | # Language default (environment variable) 9 | DEFAULT_LANGUAGE = os.getenv('DEFAULT_LANGUAGE', 'en-US') 10 | 11 | # OpenAI voice names mapped to edge-tts equivalents 12 | voice_mapping = { 13 | 'alloy': 'en-US-AvaNeural', 14 | 'echo': 'en-US-AndrewNeural', 15 | 'fable': 'en-GB-SoniaNeural', 16 | 'onyx': 'en-US-EricNeural', 17 | 'nova': 'en-US-SteffanNeural', 18 | 'shimmer': 'en-US-EmmaNeural' 19 | } 20 | 21 | def is_ffmpeg_installed(): 22 | """Check if FFmpeg is installed and accessible.""" 23 | try: 24 | subprocess.run(['ffmpeg', '-version'], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 25 | return True 26 | except (subprocess.CalledProcessError, FileNotFoundError): 27 | return False 28 | 29 | async def _generate_audio(text, voice, response_format, speed): 30 | """Generate TTS audio and optionally convert to a different format.""" 31 | # Determine if the voice is an OpenAI-compatible voice or a direct edge-tts voice 32 | edge_tts_voice = voice_mapping.get(voice, voice) # Use mapping if in OpenAI names, otherwise use as-is 33 | 34 | # Generate the TTS output in mp3 format first 35 | temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") 36 | 37 | # Convert speed to SSML rate format 38 | try: 39 | speed_rate = speed_to_rate(speed) # Convert speed value to "+X%" or "-X%" 40 | except Exception as e: 41 | print(f"Error converting speed: {e}. Defaulting to +0%.") 42 | speed_rate = "+0%" 43 | 44 | # Generate the MP3 file 45 | communicator = edge_tts.Communicate(text=text, voice=edge_tts_voice, rate=speed_rate) 46 | await communicator.save(temp_output_file.name) 47 | 48 | # If the requested format is mp3, return the generated file directly 49 | if response_format == "mp3": 50 | return temp_output_file.name 51 | 52 | # Check if FFmpeg is installed 53 | if not is_ffmpeg_installed(): 54 | print("FFmpeg is not available. 
Returning unmodified mp3 file.") 55 | return temp_output_file.name 56 | 57 | # Create a new temporary file for the converted output 58 | converted_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=f".{response_format}") 59 | 60 | # Build the FFmpeg command 61 | ffmpeg_command = [ 62 | "ffmpeg", 63 | "-i", temp_output_file.name, # Input file 64 | "-c:a", { 65 | "aac": "aac", 66 | "mp3": "libmp3lame", 67 | "wav": "pcm_s16le", 68 | "opus": "libopus", 69 | "flac": "flac" 70 | }.get(response_format, "aac"), # Default to AAC if unknown 71 | "-b:a", "192k" if response_format != "wav" else None, # Bitrate not needed for WAV 72 | "-f", { 73 | "aac": "mp4", # AAC in MP4 container 74 | "mp3": "mp3", 75 | "wav": "wav", 76 | "opus": "ogg", 77 | "flac": "flac" 78 | }.get(response_format, response_format), # Default to matching format 79 | "-y", # Overwrite without prompt 80 | converted_output_file.name # Output file 81 | ] 82 | 83 | try: 84 | # Run FFmpeg command and ensure no errors occur 85 | subprocess.run(ffmpeg_command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 86 | except subprocess.CalledProcessError as e: 87 | raise RuntimeError(f"FFmpeg error during audio conversion: {e}") 88 | 89 | # Clean up the original temporary file 90 | Path(temp_output_file.name).unlink(missing_ok=True) 91 | 92 | return converted_output_file.name 93 | 94 | def generate_speech(text, voice, response_format, speed=1.0): 95 | return asyncio.run(_generate_audio(text, voice, response_format, speed)) 96 | 97 | def get_models(): 98 | return [ 99 | {"id": "tts-1", "name": "Text-to-speech v1"}, 100 | {"id": "tts-1-hd", "name": "Text-to-speech v1 HD"} 101 | ] 102 | 103 | async def _get_voices(language=None): 104 | # List all voices, filter by language if specified 105 | all_voices = await edge_tts.list_voices() 106 | language = language or DEFAULT_LANGUAGE # Use default if no language specified 107 | filtered_voices = [ 108 | {"name": v['ShortName'], "gender": v['Gender'], "language": v['Locale']} 109 | for v in all_voices if language == 'all' or language is None or v['Locale'] == language 110 | ] 111 | return filtered_voices 112 | 113 | def get_voices(language=None): 114 | return asyncio.run(_get_voices(language)) 115 | 116 | def speed_to_rate(speed: float) -> str: 117 | """ 118 | Converts a multiplicative speed value to the edge-tts "rate" format. 119 | 120 | Args: 121 | speed (float): The multiplicative speed value (e.g., 1.5 for +50%, 0.5 for -50%). 122 | 123 | Returns: 124 | str: The formatted "rate" string (e.g., "+50%" or "-50%"). 
125 | """ 126 | if speed < 0 or speed > 2: 127 | raise ValueError("Speed must be between 0 and 2 (inclusive).") 128 | 129 | # Convert speed to percentage change 130 | percentage_change = (speed - 1) * 100 131 | 132 | # Format with a leading "+" or "-" as required 133 | return f"{percentage_change:+.0f}%" 134 | -------------------------------------------------------------------------------- /weclone-audio/src/server未完工/utils.py: -------------------------------------------------------------------------------- 1 | # utils.py 2 | 3 | from flask import request, jsonify 4 | from functools import wraps 5 | import os 6 | from dotenv import load_dotenv 7 | 8 | load_dotenv() 9 | 10 | def getenv_bool(name: str, default: bool = False) -> bool: 11 | return os.getenv(name, str(default)).lower() in ("yes", "y", "true", "1", "t") 12 | 13 | API_KEY = os.getenv('API_KEY', 'your_api_key_here') 14 | REQUIRE_API_KEY = getenv_bool('REQUIRE_API_KEY', True) 15 | 16 | def require_api_key(f): 17 | @wraps(f) 18 | def decorated_function(*args, **kwargs): 19 | if not REQUIRE_API_KEY: 20 | return f(*args, **kwargs) 21 | auth_header = request.headers.get('Authorization') 22 | if not auth_header or not auth_header.startswith('Bearer '): 23 | return jsonify({"error": "Missing or invalid API key"}), 401 24 | token = auth_header.split('Bearer ')[1] 25 | if token != API_KEY: 26 | return jsonify({"error": "Invalid API key"}), 401 27 | return f(*args, **kwargs) 28 | return decorated_function 29 | 30 | # Mapping of audio format to MIME type 31 | AUDIO_FORMAT_MIME_TYPES = { 32 | "mp3": "audio/mpeg", 33 | "opus": "audio/ogg", 34 | "aac": "audio/aac", 35 | "flac": "audio/flac", 36 | "wav": "audio/wav", 37 | "pcm": "audio/L16" 38 | } 39 | -------------------------------------------------------------------------------- /weclone/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/__init__.py -------------------------------------------------------------------------------- /weclone/cli.py: -------------------------------------------------------------------------------- 1 | import click 2 | import commentjson 3 | from pathlib import Path 4 | import os 5 | import sys 6 | import functools 7 | 8 | from weclone.utils.log import logger, capture_output 9 | from weclone.utils.config import load_config 10 | 11 | cli_config: dict | None = None 12 | 13 | try: 14 | import tomllib # type: ignore Python 3.11+ 15 | except ImportError: 16 | import tomli as tomllib 17 | 18 | 19 | def clear_argv(func): 20 | """ 21 | 装饰器:在调用被装饰函数前,清理 sys.argv,只保留脚本名。调用后恢复原始 sys.argv。 22 | 用于防止参数被 Hugging Face HfArgumentParser 解析造成 ValueError。 23 | """ 24 | 25 | @functools.wraps(func) 26 | def wrapper(*args, **kwargs): 27 | original_argv = sys.argv.copy() 28 | sys.argv = [original_argv[0]] # 只保留脚本名 29 | try: 30 | return func(*args, **kwargs) 31 | finally: 32 | sys.argv = original_argv # 恢复原始 sys.argv 33 | 34 | return wrapper 35 | 36 | 37 | def apply_common_decorators(capture_output_enabled=False): 38 | """ 39 | A unified decorator for applications 40 | """ 41 | 42 | def decorator(original_cmd_func): 43 | @functools.wraps(original_cmd_func) 44 | def new_runtime_wrapper(*args, **kwargs): 45 | if cli_config and cli_config.get("full_log", False): 46 | return capture_output(original_cmd_func)(*args, **kwargs) 47 | else: 48 | return original_cmd_func(*args, **kwargs) 49 | 50 | func_with_clear_argv = clear_argv(new_runtime_wrapper) 51 | 52 
| return functools.wraps(original_cmd_func)(func_with_clear_argv) 53 | 54 | return decorator 55 | 56 | 57 | @click.group() 58 | def cli(): 59 | """WeClone: 从聊天记录创造数字分身的一站式解决方案""" 60 | _check_project_root() 61 | _check_versions() 62 | global cli_config 63 | cli_config = load_config(arg_type="cli_args") 64 | 65 | 66 | @cli.command("make-dataset", help="处理聊天记录CSV文件,生成问答对数据集。") 67 | @apply_common_decorators() 68 | def qa_generator(): 69 | """处理聊天记录CSV文件,生成问答对数据集。""" 70 | from weclone.data.qa_generator import DataProcessor 71 | 72 | processor = DataProcessor() 73 | processor.main() 74 | 75 | 76 | @cli.command("train-sft", help="使用准备好的数据集对模型进行微调。") 77 | @apply_common_decorators() 78 | def train_sft(): 79 | """使用准备好的数据集对模型进行微调。""" 80 | from weclone.train.train_sft import main as train_sft_main 81 | 82 | train_sft_main() 83 | 84 | 85 | @cli.command("webchat-demo", help="启动 Web UI 与微调后的模型进行交互测试。") # 命令名修改为 web-demo 86 | @apply_common_decorators() 87 | def web_demo(): 88 | """启动 Web UI 与微调后的模型进行交互测试。""" 89 | from weclone.eval.web_demo import main as web_demo_main 90 | 91 | web_demo_main() 92 | 93 | 94 | # TODO 添加评估功能 @cli.command("eval-model", help="使用从训练数据中划分出来的验证集评估。") 95 | @apply_common_decorators() 96 | def eval_model(): 97 | """使用从训练数据中划分出来的验证集评估。""" 98 | from weclone.eval.eval_model import main as evaluate_main 99 | 100 | evaluate_main() 101 | 102 | 103 | @cli.command("test-model", help="使用常见聊天问题测试模型。") 104 | @apply_common_decorators() 105 | def test_model(): 106 | """测试""" 107 | from weclone.eval.test_model import main as test_main 108 | 109 | test_main() 110 | 111 | 112 | @cli.command("server", help="启动API服务,提供模型推理接口。") 113 | @apply_common_decorators() 114 | def server(): 115 | """启动API服务,提供模型推理接口。""" 116 | from weclone.server.api_service import main as server_main 117 | 118 | server_main() 119 | 120 | 121 | def _check_project_root(): 122 | """检查当前目录是否为项目根目录,并验证项目名称。""" 123 | project_root_marker = "pyproject.toml" 124 | current_dir = Path(os.getcwd()) 125 | pyproject_path = current_dir / project_root_marker 126 | 127 | if not pyproject_path.is_file(): 128 | logger.error(f"未在当前目录找到 {project_root_marker} 文件。") 129 | logger.error("请确保在WeClone项目根目录下运行此命令。") 130 | sys.exit(1) 131 | 132 | try: 133 | with open(pyproject_path, "rb") as f: 134 | pyproject_data = tomllib.load(f) 135 | project_name = pyproject_data.get("project", {}).get("name") 136 | if project_name != "WeClone": 137 | logger.error("请确保在正确的 WeClone 项目根目录下运行。") 138 | sys.exit(1) 139 | except tomllib.TOMLDecodeError as e: 140 | logger.error(f"错误:无法解析 {pyproject_path} 文件: {e}") 141 | sys.exit(1) 142 | except Exception as e: 143 | logger.error(f"读取或处理 {pyproject_path} 时发生意外错误: {e}") 144 | sys.exit(1) 145 | 146 | 147 | def _check_versions(): 148 | """比较本地 settings.jsonc 版本和 pyproject.toml 中的配置文件指南版本""" 149 | if tomllib is None: # Skip check if toml parser failed to import 150 | return 151 | 152 | ROOT_DIR = Path(__file__).parent.parent 153 | SETTINGS_PATH = ROOT_DIR / "settings.jsonc" 154 | PYPROJECT_PATH = ROOT_DIR / "pyproject.toml" 155 | 156 | settings_version = None 157 | config_guide_version = None 158 | config_changelog = None 159 | 160 | if SETTINGS_PATH.exists(): 161 | try: 162 | with open(SETTINGS_PATH, "r", encoding="utf-8") as f: 163 | settings_data = commentjson.load(f) 164 | settings_version = settings_data.get("version") 165 | except Exception as e: 166 | logger.error(f"错误:无法读取或解析 {SETTINGS_PATH}: {e}") 167 | logger.error("请确保 settings.jsonc 文件存在且格式正确。") 168 | sys.exit(1) 169 | else: 170 | logger.error(f"错误:未找到配置文件 
{SETTINGS_PATH}。") 171 | logger.error("请确保 settings.jsonc 文件位于项目根目录。") 172 | sys.exit(1) 173 | 174 | if PYPROJECT_PATH.exists(): 175 | try: 176 | with open(PYPROJECT_PATH, "rb") as f: # tomllib 需要二进制模式 177 | pyproject_data = tomllib.load(f) 178 | weclone_tool_data = pyproject_data.get("tool", {}).get("weclone", {}) 179 | config_guide_version = weclone_tool_data.get("config_version") 180 | config_changelog = weclone_tool_data.get("config_changelog", "N/A") 181 | except Exception as e: 182 | logger.warning(f"警告:无法读取或解析 {PYPROJECT_PATH}: {e}。无法检查配置文件是否为最新。") 183 | else: 184 | logger.warning(f"警告:未找到文件 {PYPROJECT_PATH}。无法检查配置文件是否为最新。") 185 | 186 | if not settings_version: 187 | logger.error(f"错误:在 {SETTINGS_PATH} 中未找到 'version' 字段。") 188 | logger.error("请从 settings.template.json 复制或更新您的 settings.jsonc 文件。") 189 | sys.exit(1) 190 | 191 | if config_guide_version: 192 | if settings_version != config_guide_version: 193 | logger.warning( 194 | f"警告:您的 settings.jsonc 文件版本 ({settings_version}) 与项目建议的配置版本 ({config_guide_version}) 不一致。" 195 | ) 196 | logger.warning("这可能导致意外行为或错误。请从 settings.template.json 复制或更新您的 settings.jsonc 文件。") 197 | # TODO 根据版本号打印更新日志 198 | logger.warning(f"配置文件更新日志:\n{config_changelog}") 199 | elif PYPROJECT_PATH.exists(): # 如果文件存在但未读到版本 200 | logger.warning( 201 | f"警告:在 {PYPROJECT_PATH} 的 [tool.weclone] 下未找到 'config_version' 字段。" 202 | "无法确认您的 settings.jsonc 是否为最新配置版本。" 203 | ) 204 | 205 | 206 | if __name__ == "__main__": 207 | cli() 208 | -------------------------------------------------------------------------------- /weclone/core/inference/offline_infer.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import List, Optional, Union 3 | 4 | 5 | from llamafactory.data import get_dataset, get_template_and_fix_tokenizer 6 | from llamafactory.extras.constants import IGNORE_INDEX 7 | from llamafactory.extras.misc import get_device_count 8 | from llamafactory.extras.packages import is_vllm_available 9 | from llamafactory.hparams import get_infer_args 10 | from llamafactory.model import load_tokenizer 11 | from pydantic import BaseModel 12 | from vllm.sampling_params import GuidedDecodingParams 13 | 14 | 15 | from vllm import LLM, SamplingParams 16 | from vllm.lora.request import LoRARequest 17 | 18 | 19 | # 这里不需要写太好,transforms库后续更新自带vllm 20 | 21 | 22 | def vllm_infer( 23 | inputs: Union[str, List[str]], 24 | model_name_or_path: str, 25 | adapter_name_or_path: Optional[str] = None, 26 | dataset: str = "alpaca_en_demo", 27 | dataset_dir: str = "data", 28 | template: str = "default", 29 | cutoff_len: int = 2048, 30 | max_samples: Optional[int] = None, 31 | vllm_config: str = "{}", 32 | save_name: str = "generated_predictions.jsonl", 33 | temperature: float = 0.95, 34 | top_p: float = 0.7, 35 | top_k: int = 50, 36 | guided_decoding_class: Optional[type[BaseModel]] = None, 37 | bad_words: Optional[List[str]] = None, 38 | logprobs: Optional[int] = None, 39 | max_new_tokens: int = 1024, 40 | repetition_penalty: float = 1.0, 41 | skip_special_tokens: bool = True, 42 | seed: Optional[int] = None, 43 | pipeline_parallel_size: int = 1, 44 | image_max_pixels: int = 768 * 768, 45 | image_min_pixels: int = 32 * 32, 46 | ): 47 | r"""Perform batch generation using vLLM engine, which supports tensor parallelism.""" 48 | if pipeline_parallel_size > get_device_count(): 49 | raise ValueError("Pipeline parallel size should be smaller than the number of gpus.") 50 | 51 | model_args, data_args, _, generating_args = get_infer_args( 52 | 
dict( 53 | model_name_or_path=model_name_or_path, 54 | adapter_name_or_path=adapter_name_or_path, 55 | dataset=dataset, 56 | dataset_dir=dataset_dir, 57 | template=template, 58 | cutoff_len=cutoff_len, 59 | max_samples=max_samples, 60 | preprocessing_num_workers=16, 61 | vllm_config=vllm_config, 62 | temperature=temperature, 63 | top_p=top_p, 64 | top_k=top_k, 65 | max_new_tokens=max_new_tokens, 66 | repetition_penalty=repetition_penalty, 67 | ) 68 | ) 69 | 70 | tokenizer_module = load_tokenizer(model_args) 71 | tokenizer = tokenizer_module["tokenizer"] 72 | template_obj = get_template_and_fix_tokenizer(tokenizer, data_args) 73 | template_obj.mm_plugin.expand_mm_tokens = False # for vllm generate 74 | 75 | if guided_decoding_class: 76 | json_schema = guided_decoding_class.model_json_schema() 77 | guided_decoding_params = GuidedDecodingParams(json=json_schema) 78 | else: 79 | guided_decoding_params = None 80 | 81 | sampling_params = SamplingParams( 82 | repetition_penalty=generating_args.repetition_penalty or 1.0, # repetition_penalty must > 0 83 | temperature=generating_args.temperature, 84 | top_p=generating_args.top_p or 1.0, # top_p must > 0 85 | top_k=generating_args.top_k or -1, # top_k must > 0 86 | stop_token_ids=template_obj.get_stop_token_ids(tokenizer), 87 | max_tokens=generating_args.max_new_tokens, 88 | skip_special_tokens=skip_special_tokens, 89 | seed=seed, 90 | guided_decoding=guided_decoding_params, 91 | bad_words=bad_words, 92 | ) 93 | if model_args.adapter_name_or_path is not None: 94 | lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0]) 95 | else: 96 | lora_request = None 97 | 98 | engine_args = { 99 | "model": model_args.model_name_or_path, 100 | "trust_remote_code": True, 101 | "dtype": model_args.infer_dtype, 102 | "max_model_len": cutoff_len + max_new_tokens, 103 | # "tensor_parallel_size": 1, 104 | # "pipeline_parallel_size": pipeline_parallel_size, 105 | # "data_parallel_size": get_device_count(), // vllm0.8.5版本支持DP 106 | "disable_log_stats": True, 107 | "enable_lora": model_args.adapter_name_or_path is not None, 108 | "enable_prefix_caching": True, # 是否启用前缀缓存 109 | "gpu_memory_utilization": 0.95, 110 | # "quantization": "bitsandbytes", # 是否启用vllm的 bitsandbytes 的量化加载 111 | # "load_format": "bitsandbytes", 112 | } 113 | if template_obj.mm_plugin.__class__.__name__ != "BasePlugin": 114 | engine_args["limit_mm_per_prompt"] = {"image": 4, "video": 2, "audio": 2} 115 | 116 | if isinstance(model_args.vllm_config, dict): 117 | engine_args.update(model_args.vllm_config) 118 | 119 | results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request) 120 | return results 121 | -------------------------------------------------------------------------------- /weclone/core/inference/online_infer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import requests 4 | from openai import OpenAI 5 | 6 | class OnlineLLM: 7 | def __init__(self, api_key: str, base_url: str,model_name: str,default_system: str): 8 | self.api_key = api_key 9 | self.base_url = base_url 10 | self.model_name = model_name 11 | self.default_system = default_system 12 | self.client = OpenAI( 13 | api_key=self.api_key, 14 | base_url=self.base_url 15 | ) 16 | 17 | 18 | def chat(self,prompt_text, 19 | temperature: float = 0.7, 20 | max_tokens: int = 1024, 21 | top_p: float = 0.95, 22 | stream: bool = False, 23 | enable_thinking: bool = False): 24 | messages = [ 25 | {"role": "system", "content": 
self.default_system}, 26 | {"role": "user", "content": prompt_text}, 27 | ] 28 | response = self.client.chat.completions.create( 29 | model=self.model_name, 30 | messages=messages, 31 | stream=stream, 32 | temperature = temperature, 33 | max_tokens=max_tokens, 34 | top_p=top_p, 35 | # enable_thinking=enable_thinking 适配Qwen3动态开启推理 36 | 37 | ) 38 | 39 | return response 40 | 41 | -------------------------------------------------------------------------------- /weclone/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/data/__init__.py -------------------------------------------------------------------------------- /weclone/data/chat_parsers/wechat_parser.py: -------------------------------------------------------------------------------- 1 | class WeChatParser: 2 | def decrypt_wechat_image(self, encrypted_path, output_path): 3 | """解密微信加密的图片文件""" 4 | pass 5 | 6 | def parse_chat_records(self, db_path): 7 | """解析聊天记录数据库""" 8 | pass 9 | -------------------------------------------------------------------------------- /weclone/data/clean/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/data/clean/__init__.py -------------------------------------------------------------------------------- /weclone/data/clean/get_score.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | # TODO 未使用 5 | def adjust_score_tiered( 6 | initial_score: int, probabilities: list[float], thresholds: list[float], downgrade_levels: list[int] 7 | ) -> int: 8 | """ 9 | 根据大模型给出评分时的概率,对原始评分进行分级置信度调整。 10 | 11 | Args: 12 | initial_score: 大模型给出的原始评分 (整数 1 到 5)。 13 | probabilities: 包含 5 个评分 (1 到 5) 概率的列表。 14 | 例如 [P(1), P(2), P(3), P(4), P(5)]。 15 | thresholds: 一个降序排列的概率阈值列表,定义置信度区间边界。 16 | 例如 [0.6, 0.3]。 17 | downgrade_levels: 与 thresholds 对应的降级幅度列表,长度比 thresholds 多 1。 18 | 定义了每个置信度区间的降级数。例如 [0, 1, 2]。 19 | 20 | Returns: 21 | 经过置信度调整后的最终评分 (整数 1 到 5)。 22 | 23 | Raises: 24 | ValueError: 如果输入参数不合法(例如概率列表长度不对,阈值未降序等)。 25 | """ 26 | # --- 输入校验 --- 27 | if not (1 <= initial_score <= 5): 28 | raise ValueError("initial_score 必须在 1 到 5 之间。") 29 | if len(probabilities) != 5: 30 | raise ValueError("probabilities 列表必须包含 5 个元素。") 31 | # 检查概率和是否接近 1 (允许小的浮点误差) 32 | if not math.isclose(sum(probabilities), 1.0, abs_tol=1e-6): 33 | print(f"警告: 概率之和 {sum(probabilities)} 不接近 1.0。请检查概率来源。") # 打印警告而非直接报错 34 | # raise ValueError("probabilities 中元素的和必须接近 1.0。") 35 | if len(downgrade_levels) != len(thresholds) + 1: 36 | raise ValueError("downgrade_levels 的长度必须比 thresholds 的长度多 1。") 37 | if any(thresholds[i] < thresholds[i + 1] for i in range(len(thresholds) - 1)): 38 | raise ValueError("thresholds 列表必须是降序排列的。") 39 | if any(level < 0 for level in downgrade_levels): 40 | raise ValueError("downgrade_levels 中的降级幅度不能为负数。") 41 | 42 | # --- 算法核心 --- 43 | # 1. 获取选中分数的概率 44 | # 列表索引从0开始,所以评分 s 对应的索引是 s-1 45 | try: 46 | p_chosen = probabilities[initial_score - 1] 47 | except IndexError: 48 | # 这个错误理论上不应发生,因为 initial_score 已校验在 1-5 之间 49 | raise ValueError(f"无法从 probabilities 列表获取索引 {initial_score - 1} 的值。") 50 | 51 | # 2. 
确定降级幅度 52 | downgrade = downgrade_levels[-1] # 默认为最低置信度区间的降级幅度 53 | # 遍历阈值列表 (从高到低) 54 | for i in range(len(thresholds)): 55 | if p_chosen >= thresholds[i]: 56 | downgrade = downgrade_levels[i] # 找到对应的置信度区间 57 | break # 停止遍历 58 | 59 | # 3. 计算调整后的评分 60 | preliminary_score = initial_score - downgrade 61 | adjusted_score = max(1, preliminary_score) # 确保分数不低于 1 62 | 63 | # 4. 返回结果 64 | return adjusted_score 65 | -------------------------------------------------------------------------------- /weclone/data/clean/strategies.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pandas as pd 3 | from abc import ABC, abstractmethod 4 | from dataclasses import dataclass 5 | from typing import Any, Dict, List, Union 6 | from langchain_core.prompts import PromptTemplate 7 | from weclone.data.models import QaPair, CutMessage, QaPairScore 8 | from weclone.prompts.clean_data import CLEAN_PROMPT 9 | import os 10 | from weclone.utils.log import logger 11 | 12 | 13 | @dataclass 14 | class CleaningStrategy(ABC): 15 | """数据清洗策略的抽象基类""" 16 | 17 | make_dataset_config: Dict 18 | 19 | @abstractmethod 20 | def clean(self, data: Any) -> Any: 21 | """ 22 | 执行数据清洗操作。 23 | 24 | Args: 25 | data: 需要清洗的数据。 26 | 27 | Returns: 28 | 清洗后的数据。 29 | """ 30 | pass 31 | 32 | 33 | @dataclass 34 | class LLMCleaningStrategy(CleaningStrategy): 35 | """使用大模型进行数据清洗的策略""" 36 | 37 | 38 | def judge(self, data: List[QaPair]) -> None: 39 | """ 40 | 调用llm打分,并将分数直接赋值给传入的QaPair。 41 | """ 42 | from weclone.core.inference.offline_infer import vllm_infer 43 | logger.info("开始使用llm对数据打分") 44 | inputs = [] 45 | prompt_template = PromptTemplate.from_template(CLEAN_PROMPT) 46 | for qa in data: 47 | inputs.append(prompt_template.invoke({"id": qa.id, "Q": qa.instruction, "A": qa.output}).text) # type: ignore 48 | outputs = vllm_infer( 49 | inputs, 50 | self.make_dataset_config["model_name_or_path"], 51 | template=self.make_dataset_config["template"], 52 | temperature=0, 53 | guided_decoding_class=QaPairScore, 54 | repetition_penalty=1.2, 55 | bad_words=[r"\n"], 56 | ) 57 | 58 | parsed_scores: List[QaPairScore] = [] 59 | for result in outputs: 60 | try: 61 | score_data = json.loads(result.outputs[0].text) 62 | qa_score = QaPairScore(**score_data) 63 | parsed_scores.append(qa_score) 64 | except json.JSONDecodeError: 65 | logger.error(f"Error decoding JSON: {result.outputs[0].text}") 66 | 67 | score_map = {score.id: score.score for score in parsed_scores} 68 | for qa in data: 69 | if qa.id in score_map: 70 | qa.score = score_map[qa.id] 71 | else: 72 | logger.warning(f"Warning: Score not found for QaPair with id {qa.id}. 
Assigning default score.") 73 | 74 | scores = [qa.score for qa in data if qa.score is not None] 75 | score_series = pd.Series(scores) 76 | score_counts = score_series.value_counts().sort_index() 77 | score_percentages = score_series.value_counts(normalize=True).sort_index() * 100 78 | pd.set_option("display.unicode.east_asian_width", True) # 尝试修正对齐问题 79 | distribution_df = pd.DataFrame( # 合并数量和百分比到一个 DataFrame 中以便打印 80 | { 81 | "数量": score_counts, 82 | "占比(%)": score_percentages.round(2), 83 | } 84 | ) 85 | distribution_df.index.name = "分数" # 给第一列加上列名:分数 86 | printable_df_str = distribution_df.reset_index().to_string(index=False) 87 | logger.success(f"llm打分分数分布情况:\n{printable_df_str}") 88 | 89 | def clean(self) -> str: 90 | """ 91 | 清洗 SFT 数据并返回清洗后的文件路径。 92 | 如果未启用清洗,则返回原始路径。 93 | """ 94 | config = self.make_dataset_config 95 | dataset_dir = config["dataset_dir"] 96 | dataset_info_path = os.path.join(dataset_dir, "dataset_info.json") 97 | 98 | sft_json_path = os.path.join(dataset_dir, "sft-my.json") 99 | output_json_path = os.path.join(dataset_dir, "sft-my-l.json") 100 | accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1) 101 | 102 | if not config.get("clean_dataset", {}).get("enable_clean"): 103 | logger.info("未启用清洗功能") 104 | self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") 105 | return sft_json_path 106 | 107 | try: 108 | with open(sft_json_path, 'r', encoding='utf-8') as f: 109 | data = json.load(f) 110 | filtered_data = [item for item in data if item.get("score", 0) >= accept_score] 111 | 112 | with open(output_json_path, 'w', encoding='utf-8') as f: 113 | json.dump(filtered_data, f, ensure_ascii=False, indent=4) 114 | 115 | logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 条数据") 116 | self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json") 117 | return output_json_path 118 | 119 | except Exception as e: 120 | logger.error(f"清洗数据失败,使用原始数据: {str(e)}") 121 | self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") 122 | return sft_json_path 123 | 124 | def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str): 125 | """ 126 | 修改 dataset_info.json 文件中的 file_name 字段 127 | """ 128 | try: 129 | with open(dataset_info_path, "r", encoding="utf-8") as f: 130 | dataset_info = json.load(f) 131 | 132 | # 更新所有支持的数据集的 file_name 133 | for key in ["wechat-sft", "wechat-sft-with-history"]: 134 | if key in dataset_info: 135 | dataset_info[key]["file_name"] = new_file_name 136 | 137 | # 写回文件 138 | with open(dataset_info_path, "w", encoding="utf-8") as f: 139 | json.dump(dataset_info, f, indent=4, ensure_ascii=False) 140 | 141 | logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}") 142 | 143 | except Exception as e: 144 | logger.warning(f"无法更新 dataset_info.json: {e}") 145 | -------------------------------------------------------------------------------- /weclone/data/clean/strategies_online.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import pandas as pd 4 | from tqdm import tqdm 5 | from abc import ABC, abstractmethod 6 | from dataclasses import dataclass 7 | from typing import Any, Dict, List 8 | from langchain_core.prompts import PromptTemplate 9 | from weclone.data.models import QaPair, QaPairScore 10 | from weclone.prompts.clean_data import CLEAN_PROMPT,ONLINE_LLM_CLEAN_PROMPT 11 | from weclone.core.inference.online_infer import OnlineLLM 12 | from weclone.utils.log 
import logger 13 | import os 14 | 15 | @dataclass 16 | class CleaningStrategy(ABC): 17 | """数据清洗策略的抽象基类""" 18 | 19 | make_dataset_config: Dict 20 | 21 | @abstractmethod 22 | def clean(self, data: Any) -> Any: 23 | pass 24 | 25 | @dataclass 26 | class OlineLLMCleaningStrategy(CleaningStrategy): 27 | """使用大模型进行数据清洗的策略""" 28 | 29 | def judge(self, data: List[QaPair]) -> None: 30 | logger.info("开始使用在线模型对数据打分") 31 | 32 | logger.info(f"使用模型 {self.make_dataset_config.get('model_name', '')}") 33 | 34 | client = OnlineLLM( 35 | api_key = self.make_dataset_config.get("llm_api_key"), 36 | base_url = self.make_dataset_config.get("base_url"), 37 | model_name = self.make_dataset_config.get("model_name"), 38 | default_system = self.make_dataset_config.get("default_system") 39 | ) 40 | prompt_template = PromptTemplate.from_template(ONLINE_LLM_CLEAN_PROMPT) 41 | 42 | parsed_scores = [] 43 | clean_batch_size = int(self.make_dataset_config.get("clean_batch_size", 10)) 44 | for i in tqdm(range(0, len(data), clean_batch_size), desc="在线模型评分进度"): 45 | batch = data[i : i + clean_batch_size] 46 | # 构造当前批次的 qa_list 47 | qa_list = [ 48 | {"id": qa.id, "Q": qa.instruction, "A": qa.output} 49 | for qa in batch 50 | ] 51 | qa_list_json = json.dumps(qa_list, ensure_ascii=False) 52 | # 填充模板 53 | prompt_text = prompt_template.invoke({ 54 | "qa_list": qa_list_json 55 | }).text 56 | try: 57 | response = client.chat(prompt_text) 58 | result_text = response.choices[0].message.content 59 | # print("大模型返回:",result_text) 60 | # 如果有 ,只保留 之后的内容 61 | if "" in result_text: 62 | result_text = result_text.split("", 1)[1] 63 | # 去掉开头和结尾的 ```json 或 ``` 等代码块标记 64 | result_text = re.sub(r"^```json\s*|```$", "", result_text.strip(), flags=re.MULTILINE) 65 | # 如果偶尔的几次解析失败就跳过 66 | try: 67 | score_list = json.loads(result_text) 68 | except json.JSONDecodeError as e: 69 | logger.error(f"JSON 解析失败,跳过本批次: {e}\n内容:{result_text}") 70 | continue 71 | 72 | for item in score_list: 73 | parsed_scores.append(QaPairScore(**item)) 74 | except Exception as e: 75 | ids_in_batch = [qa["id"] for qa in qa_list] 76 | logger.error(f"调用在线模型或解析结果失败,当前 batch QA ID 列表: {ids_in_batch},错误信息: {str(e)}") 77 | 78 | score_map = {score.id: score.score for score in parsed_scores} 79 | for qa in data: 80 | if qa.id in score_map: 81 | qa.score = score_map[qa.id] 82 | else: 83 | logger.warning(f"未获取到QA ID {qa.id}的分数,默认赋值0") 84 | qa.score = 0 85 | 86 | # 统计分数分布,打印日志(和本地版本保持一致) 87 | scores = [qa.score for qa in data if qa.score is not None] 88 | score_series = pd.Series(scores) 89 | score_counts = score_series.value_counts().sort_index() 90 | score_percentages = score_series.value_counts(normalize=True).sort_index() * 100 91 | pd.set_option("display.unicode.east_asian_width", True) 92 | distribution_df = pd.DataFrame({ 93 | "数量": score_counts, 94 | "占比(%)": score_percentages.round(2), 95 | }) 96 | distribution_df.index.name = "分数" 97 | printable_df_str = distribution_df.reset_index().to_string(index=False) 98 | logger.success(f"在线模型打分分数分布情况:\n{printable_df_str}") 99 | 100 | def clean(self) -> str: 101 | """ 102 | 清洗 SFT 数据并返回清洗后的文件路径。 103 | 如果未启用清洗,则返回原始路径。 104 | """ 105 | config = self.make_dataset_config 106 | dataset_dir = config["dataset_dir"] 107 | dataset_info_path = os.path.join(dataset_dir, "dataset_info.json") 108 | 109 | sft_json_path = os.path.join(dataset_dir, "sft-my.json") 110 | output_json_path = os.path.join(dataset_dir, "sft-my-l.json") 111 | accept_score = config.get("clean_dataset", {}).get("llm", {}).get("accept_score", 1) 112 | 113 | if not 
config.get("clean_dataset", {}).get("enable_clean"): 114 | logger.info("未启用清洗功能") 115 | self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") 116 | return sft_json_path 117 | 118 | try: 119 | with open(sft_json_path, 'r', encoding='utf-8') as f: 120 | data = json.load(f) 121 | filtered_data = [item for item in data if item.get("score", 0) >= accept_score] 122 | 123 | with open(output_json_path, 'w', encoding='utf-8') as f: 124 | json.dump(filtered_data, f, ensure_ascii=False, indent=4) 125 | 126 | logger.success(f"已筛出低于{accept_score}分的数据,共保留 {len(filtered_data)} 条数据") 127 | self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my-l.json") 128 | return output_json_path 129 | 130 | except Exception as e: 131 | logger.error(f"清洗数据失败,使用原始数据: {str(e)}") 132 | self._update_dataset_info_file(dataset_info_path, new_file_name="sft-my.json") 133 | return sft_json_path 134 | 135 | def _update_dataset_info_file(self, dataset_info_path: str, new_file_name: str): 136 | """ 137 | 修改 dataset_info.json 文件中的 file_name 字段 138 | """ 139 | try: 140 | with open(dataset_info_path, "r", encoding="utf-8") as f: 141 | dataset_info = json.load(f) 142 | 143 | # 更新所有支持的数据集的 file_name 144 | for key in ["wechat-sft", "wechat-sft-with-history"]: 145 | if key in dataset_info: 146 | dataset_info[key]["file_name"] = new_file_name 147 | 148 | # 写回文件 149 | with open(dataset_info_path, "w", encoding="utf-8") as f: 150 | json.dump(dataset_info, f, indent=4, ensure_ascii=False) 151 | 152 | logger.info(f"已更新 dataset_info.json 中的 file_name 为 {new_file_name}") 153 | 154 | except Exception as e: 155 | logger.warning(f"无法更新 dataset_info.json: {e}") 156 | -------------------------------------------------------------------------------- /weclone/data/models.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from pandas import Timestamp 3 | from pydantic import BaseModel 4 | 5 | 6 | @dataclass 7 | class ChatMessage: 8 | id: int 9 | MsgSvrID: int 10 | type_name: str 11 | is_sender: int 12 | talker: str 13 | room_name: str 14 | msg: str 15 | src: str 16 | CreateTime: Timestamp 17 | 18 | 19 | @dataclass 20 | class CutMessage: 21 | is_sender: int 22 | cut_type: str 23 | CreateTime: Timestamp 24 | 25 | 26 | @dataclass 27 | class QaPair: 28 | id: int 29 | system: str 30 | instruction: str 31 | output: str 32 | history: list[list[str]] 33 | time: Timestamp 34 | score: int 35 | 36 | 37 | class QaPairScore(BaseModel): 38 | id: int 39 | score: int 40 | 41 | 42 | skip_type_list = [ 43 | "添加好友", 44 | "推荐公众号", 45 | "动画表情", 46 | "位置", 47 | "文件", 48 | "位置共享", 49 | "接龙", 50 | "引用回复", 51 | "视频号直播或直播回放", 52 | "用户上传的GIF表情", 53 | "文件(猜)", 54 | "群公告", 55 | "视频号直播或直播回放等", 56 | "游戏相关", 57 | "转账", 58 | "赠送红包封面", 59 | "语音通话", 60 | "企业微信打招呼(猜)", 61 | "企业微信添加好友(猜)", 62 | "系统通知", 63 | "消息撤回1", 64 | "拍一拍", 65 | "消息撤回5", 66 | "消息撤回6", 67 | "消息撤回33", 68 | "消息撤回36", 69 | "消息撤回57", 70 | "邀请加群", 71 | "未知-11000,0", 72 | ] 73 | # 没处理的类型 74 | unprocessed_type_list = [] 75 | -------------------------------------------------------------------------------- /weclone/data/qa_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import subprocess 4 | from typing import Dict, List, Union 5 | import re 6 | 7 | import pandas as pd 8 | import json 9 | from pandas import Timestamp 10 | from llamafactory.extras.packages import is_vllm_available 11 | 12 | from weclone.data.clean.strategies import 
LLMCleaningStrategy 13 | from weclone.data.clean.strategies_online import OlineLLMCleaningStrategy 14 | from weclone.utils.config import load_config 15 | from weclone.utils.log import logger 16 | from weclone.data.models import ChatMessage, CutMessage, skip_type_list, QaPair 17 | from weclone.data.strategies import TimeWindowStrategy, LLMStrategy 18 | 19 | 20 | class DataProcessor: 21 | def __init__(self): 22 | self.config = load_config(arg_type="make_dataset") 23 | self.csv_folder = "./dataset/csv" 24 | self.system_prompt = self.config["default_system"] 25 | self.cut_type_list = [ 26 | "图片", 27 | "视频", 28 | "合并转发的聊天记录", 29 | "语音", 30 | "(分享)音乐", 31 | "(分享)卡片式链接", 32 | "(分享)笔记", 33 | "(分享)小程序", 34 | "(分享)收藏夹", 35 | "(分享)小说(猜)", 36 | "(分享)视频号名片", 37 | "(分享)视频号视频", 38 | "粘贴的文本", # 无法解析的分享链接 39 | ] 40 | 41 | # blocked_words 42 | config_blocked_words = self.config.get("blocked_words", []) 43 | file_blocked_words = [] 44 | try: 45 | with open("./dataset/blocked_words.json", encoding="utf-8") as f: 46 | file_blocked_words = json.load(f).get("blocked_words", []) 47 | except (FileNotFoundError, json.JSONDecodeError): 48 | pass 49 | 50 | self.blocked_words = list(set(config_blocked_words + file_blocked_words)) 51 | # logger.info(f"聊天记录禁用词: {self.blocked_words}") 52 | 53 | if self.config["single_combine_strategy"] == "time_window": 54 | self.single_combine_strategy = TimeWindowStrategy( 55 | time_window=self.config["single_combine_time_window"] * 60, 56 | is_single_chat=True, 57 | ) 58 | elif self.config["single_combine_strategy"] == "llm": 59 | self.single_combine_strategy = LLMStrategy( 60 | is_single_chat=True, 61 | ) 62 | 63 | if self.config["qa_match_strategy"] == "time_window": 64 | self.qa_match_strategy = TimeWindowStrategy( 65 | time_window=self.config["qa_match_time_window"] * 60, 66 | is_single_chat=False, 67 | ) 68 | elif self.config["qa_match_strategy"] == "llm": 69 | self.qa_match_strategy = LLMStrategy(is_single_chat=False) 70 | 71 | clean_dataset_config = self.config.get("clean_dataset", {}) 72 | enable_clean = clean_dataset_config.get("enable_clean", False) 73 | 74 | if enable_clean: 75 | if self.config.get("prompt_with_history", False): 76 | logger.warning("开启 prompt_with_history 不支持 clean_dataset 功能") 77 | exit() 78 | 79 | if not is_vllm_available() and not self.config.get("online_llm_clear"): 80 | logger.warning("vLLM 不可用,暂不清洗数据集。") 81 | clean_dataset_config["enable_clean"] = False 82 | 83 | if self.config.get("clean_dataset", {}).get("enable_clean", False): 84 | if self.config.get("clean_dataset", {}).get("clean_strategy", "llm") == "llm": 85 | if self.config.get("online_llm_clear"): 86 | self.clean_strategy = OlineLLMCleaningStrategy(make_dataset_config=self.config) 87 | else: 88 | self.clean_strategy = LLMCleaningStrategy(make_dataset_config=self.config) 89 | self.c = self.config 90 | 91 | def main(self): 92 | if not os.path.exists(self.csv_folder) or not os.listdir(self.csv_folder): 93 | logger.error(f"错误:目录 '{self.csv_folder}' 不存在或为空,请检查路径并确保其中包含 CSV 聊天数据文件。") 94 | return 95 | 96 | csv_files = self.get_csv_files() 97 | logger.info(f"共发现 {len(csv_files)} 个 CSV 文件,开始处理") 98 | message_list: List[ChatMessage] = [] 99 | for csv_file in csv_files: 100 | logger.debug(f"开始处理 CSV 文件: {csv_file}") 101 | chat_messages = self.load_csv(csv_file) 102 | message_list.extend(self.group_consecutive_messages(messages=chat_messages)) 103 | # self.process_by_msgtype(chat_message) 104 | logger.debug(f"处理完成: {csv_file},共加载 {len(chat_messages)} 条消息") 105 | qa_res = self.match_qa(message_list) 106 | 
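        # With prompt_with_history enabled, match_qa keeps CutMessage markers as
        # conversation boundaries and add_history_to_qa collapses each segment into
        # a single sample whose history field holds the earlier turns; without it,
        # the markers are dropped and every QaPair is kept as a flat sample.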
if self.c["prompt_with_history"]: 107 | qa_res = self.add_history_to_qa(qa_res) 108 | else: 109 | qa_res = [item for item in qa_res if isinstance(item, QaPair)] 110 | 111 | if self.c.get("clean_dataset", {}).get("enable_clean", False): 112 | self.clean_strategy.judge(qa_res) 113 | # qa_res = self.clean_strategy.clean(qa_res) 114 | self.save_result(qa_res) 115 | self._execute_length_cdf_script() 116 | 117 | logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 ./dataset/res_csv/sft/sft-my.json") 118 | 119 | def _execute_length_cdf_script(self): 120 | """执行 length_cdf.py 脚本来计算cutoff_len。""" 121 | try: 122 | python_executable = sys.executable 123 | # 脚本路径是相对于项目根目录的 124 | script_path = os.path.join("weclone", "utils", "length_cdf.py") 125 | 126 | command_parts = [ 127 | python_executable, 128 | script_path, 129 | f'--model_name_or_path="{self.c["model_name_or_path"]}"', 130 | f'--dataset="{self.c["dataset"]}"', 131 | f'--dataset_dir="{self.c["dataset_dir"]}"', 132 | f'--template="{self.c["template"]}"', 133 | f"--interval={self.c['cutoff_len']}", 134 | ] 135 | 136 | child_env = os.environ.copy() 137 | child_env["CUDA_VISIBLE_DEVICES"] = "0" 138 | child_env["LLAMAFACTORY_VERBOSITY"] = "ERROR" 139 | 140 | process = subprocess.Popen( 141 | command_parts, 142 | env=child_env, 143 | stdout=None, # 使用 None 表示使用父进程的标准输出(即终端) 144 | stderr=None, # 使用 None 表示使用父进程的标准错误(即终端) 145 | text=True, 146 | bufsize=1, # 行缓冲 147 | ) 148 | return_code = process.wait() 149 | if return_code != 0: 150 | logger.error(f"命令 '{' '.join(command_parts)}' 执行失败,返回码 {return_code}") 151 | except FileNotFoundError: 152 | # command_parts[0] 是 python_executable, command_parts[1] 是 script_path 153 | logger.error(f"命令执行失败: 找不到可执行文件 '{command_parts[0]}' 或脚本 '{command_parts[1]}'") 154 | except KeyError as e: 155 | logger.error(f"执行 length_cdf.py 脚本失败:配置项缺失 {str(e)}") 156 | except Exception as e: 157 | logger.error(f"执行 length_cdf.py 脚本时发生未知错误: {str(e)}") 158 | 159 | def get_csv_files(self): 160 | """遍历文件夹获取所有CSV文件路径,并按文件名中的起始序号排序""" 161 | 162 | csv_files = [] 163 | for chat_obj_folder in os.listdir(self.csv_folder): 164 | chat_obj_folder_path = os.path.join(self.csv_folder, chat_obj_folder) 165 | for csvfile in os.listdir(chat_obj_folder_path): 166 | if not csvfile.endswith(".csv"): 167 | continue 168 | csvfile_path = os.path.join(chat_obj_folder_path, csvfile) 169 | csv_files.append(csvfile_path) 170 | # 提取文件名中的起始数字,比如 wxid_..._0_5000.csv → 0 171 | pattern = re.compile(r"_(\d+)_\d+\.csv$") 172 | 173 | def extract_start(fp: str) -> int: 174 | name = os.path.basename(fp) 175 | m = pattern.search(name) 176 | return int(m.group(1)) if m else 0 177 | 178 | # 按起始数字升序排序 179 | csv_files.sort(key=extract_start) 180 | return csv_files 181 | 182 | def match_qa(self, messages: List[ChatMessage]) -> List[Union[QaPair, CutMessage]]: 183 | """ 184 | 匹配问答对 185 | 186 | Args: 187 | messages: 消息列表 188 | 189 | Returns: 190 | List[Union[QaPair, CutMessage]]: 包含指令和输出的问答对列表 191 | """ 192 | # 状态定义 193 | WAITING_INSTRUCTION = "waiting_instruction" # 等待指令 194 | WAITING_RESPONSE = "waiting_response" # 等待回复 195 | 196 | current_state = WAITING_INSTRUCTION 197 | qa_res: List[Union[QaPair, CutMessage]] = [] 198 | last_message = None 199 | current_instruction = None 200 | qa_id_counter = 0 201 | 202 | for msg in messages: 203 | if isinstance(msg, CutMessage): 204 | current_state = WAITING_INSTRUCTION 205 | current_instruction = None 206 | last_message = None 207 | if self.c["prompt_with_history"]: 208 | qa_res.append(msg) 209 | continue 210 | 211 | if current_state == 
WAITING_INSTRUCTION: 212 | if msg.is_sender == 0: # 收到对方消息 213 | current_instruction = msg.msg 214 | last_message = msg 215 | current_state = WAITING_RESPONSE 216 | 217 | elif current_state == WAITING_RESPONSE: 218 | if msg.is_sender == 0: # 收到对方消息 219 | current_instruction = msg.msg 220 | last_message = msg 221 | # 状态保持不变 222 | else: # 自己的回复 使用策略判断是否属于同一对话 223 | if last_message and self.qa_match_strategy.is_same_conversation([last_message], msg): 224 | assert current_instruction is not None, ( 225 | "current_instruction should not be None when creating a QA pair" 226 | ) 227 | qa_pair = QaPair( 228 | id=qa_id_counter, 229 | system=self.system_prompt, 230 | instruction=current_instruction, 231 | output=msg.msg, 232 | history=[], # No history in this context yet 233 | time=msg.CreateTime, # Use the response message time 234 | score=0, # Default score 235 | ) 236 | qa_res.append(qa_pair) 237 | qa_id_counter += 1 # 增加计数器 238 | else: 239 | if self.c["prompt_with_history"]: 240 | qa_res.append( 241 | CutMessage( 242 | is_sender=msg.is_sender, 243 | cut_type=msg.type_name, 244 | CreateTime=msg.CreateTime, 245 | ) 246 | ) 247 | # 无论是否匹配,都重置状态 248 | current_state = WAITING_INSTRUCTION 249 | current_instruction = None 250 | last_message = None 251 | 252 | return qa_res 253 | 254 | # TODO: need review 255 | def add_history_to_qa(self, qa_res: List[Union[QaPair, CutMessage]]) -> List[QaPair]: 256 | """ 257 | Adds conversation history to QaPair objects. 258 | 259 | Args: 260 | qa_res: A list containing QaPair and CutMessage objects. 261 | 262 | Returns: 263 | A list of QaPair objects with history populated. 264 | """ 265 | qa_res_with_history: List[QaPair] = [] 266 | current_history: List[List[str]] = [] 267 | last_timestamp: Timestamp = None # type: ignore 268 | 269 | for item in qa_res: 270 | if isinstance(item, CutMessage): 271 | if current_history: 272 | instruction = current_history[-1][0] 273 | output = current_history[-1][1] 274 | history = current_history[:-1] 275 | qa_pair_with_history = QaPair( 276 | id=-1, 277 | system=self.system_prompt, 278 | instruction=instruction, 279 | output=output, 280 | history=history, 281 | time=last_timestamp, 282 | score=0, 283 | ) 284 | qa_res_with_history.append(qa_pair_with_history) 285 | current_history = [] 286 | last_timestamp = None # type: ignore 287 | elif isinstance(item, QaPair): 288 | current_history.append([item.instruction, item.output]) 289 | last_timestamp = item.time 290 | 291 | if current_history: 292 | instruction = current_history[-1][0] 293 | output = current_history[-1][1] 294 | history = current_history[:-1] 295 | # Ensure last_timestamp is not None before assignment 296 | final_timestamp_end = last_timestamp 297 | assert final_timestamp_end is not None, "Timestamp cannot be None for the final QaPair" 298 | qa_pair_with_history = QaPair( 299 | id=-1, 300 | system=self.system_prompt, 301 | instruction=instruction, 302 | output=output, 303 | history=history, 304 | time=final_timestamp_end, 305 | score=0, 306 | ) 307 | qa_res_with_history.append(qa_pair_with_history) 308 | 309 | return qa_res_with_history 310 | 311 | def group_consecutive_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]: 312 | """ 313 | 将同一个人连续发送的多条消息组合成一条消息,遇到cut_type添加cut 314 | 315 | Args: 316 | messages: 消息列表 317 | 318 | Returns: 319 | List[ChatMessage]: 组合后的消息列表 320 | """ 321 | if not messages: 322 | return [] 323 | 324 | def _combine_text(messages: List[ChatMessage]) -> ChatMessage: 325 | """ 326 | 合并多条消息为一条 327 | 328 | Args: 329 | messages: 要合并的消息列表 330 
| 331 | Returns: 332 | ChatMessage: 合并后的消息 333 | """ 334 | base_msg = messages[0] 335 | combined_content = messages[0].msg 336 | 337 | for i in messages[1:]: 338 | content = i.msg 339 | if not content: 340 | continue 341 | 342 | if combined_content and combined_content[-1] not in ["。", "!", "?", "…", ",", "."]: 343 | combined_content += "," 344 | 345 | combined_content += content 346 | if len(combined_content) > self.c["combine_msg_max_length"]: 347 | logger.warning( 348 | f"组合后消息长度超过{self.c['combine_msg_max_length']}将截断:\n {combined_content[:50]}" 349 | ) 350 | combined_content = combined_content[: self.c["combine_msg_max_length"]] 351 | 352 | combined_message = ChatMessage( 353 | id=base_msg.id, 354 | MsgSvrID=base_msg.MsgSvrID, 355 | type_name=base_msg.type_name, 356 | is_sender=base_msg.is_sender, 357 | talker=base_msg.talker, 358 | room_name=base_msg.room_name, 359 | msg=combined_content, 360 | src=base_msg.src, 361 | CreateTime=messages[-1].CreateTime, # 使用最后一条消息的时间 362 | ) 363 | 364 | return combined_message 365 | 366 | def _create_cut_message(message: ChatMessage) -> CutMessage: 367 | return CutMessage( 368 | is_sender=message.is_sender, 369 | cut_type=message.type_name, 370 | CreateTime=message.CreateTime, 371 | ) 372 | 373 | def _combine_current_group(group): 374 | """ 375 | 处理当前消息组并添加到grouped_messages 376 | 377 | Args: 378 | group: 当前消息组 379 | """ 380 | if len(group) > 1: 381 | combined_msg = _combine_text(group) 382 | grouped_messages.append(combined_msg) 383 | else: 384 | grouped_messages.append(group[0]) 385 | 386 | grouped_messages = [] 387 | current_group = [] 388 | 389 | for _, current_msg in enumerate(messages): 390 | if current_msg.type_name in self.cut_type_list: 391 | if current_group: 392 | # 当前组有消息,合并当前组,并添加一条cut 393 | _combine_current_group(current_group) 394 | current_group = [] 395 | 396 | cut_msg = _create_cut_message(current_msg) 397 | grouped_messages.append(cut_msg) 398 | else: 399 | # 当前组没消息,检查上一个组 400 | if grouped_messages: 401 | if not isinstance(grouped_messages[-1], CutMessage): 402 | cut_msg = _create_cut_message(current_msg) 403 | grouped_messages.append(cut_msg) 404 | # 如果上一个组没消息或最后一条是CutMessage,直接continue 405 | continue 406 | 407 | if not current_group: 408 | current_group = [current_msg] 409 | continue 410 | 411 | last_msg = current_group[-1] 412 | 413 | # 判断是否是同一个人的连续消息 414 | if ( 415 | current_msg.is_sender == last_msg.is_sender 416 | and current_msg.talker == last_msg.talker 417 | and self.single_combine_strategy.is_same_conversation([last_msg], current_msg) 418 | ): 419 | current_group.append(current_msg) 420 | else: 421 | # 不是同一个人的消息,处理当前组并开始新组 422 | _combine_current_group(current_group) 423 | # 开始新组 424 | current_group = [current_msg] 425 | 426 | # 处理最后一组消息 427 | if current_group: 428 | _combine_current_group(current_group) 429 | 430 | return grouped_messages 431 | 432 | def process_by_msgtype(self, chat_message: ChatMessage): 433 | if chat_message.type_name == "文本": 434 | self.process_text(chat_message) 435 | # elif chat_message.type_name == "图片": 436 | # self.process_image(chat_message) 437 | 438 | def load_csv(self, file_path) -> List[ChatMessage]: 439 | """ 440 | 做整体第一次预处理,过滤不符合条件的行 441 | """ 442 | df = pd.read_csv(file_path, encoding="utf-8", dtype={"msg": str}) 443 | 444 | df = df[~df["type_name"].isin(values=skip_type_list)] 445 | 446 | # 如果type_name为文本 并且msg 包含 手机号、身份证号、邮箱、网址则删除这行 447 | for i in df.index: 448 | if df.loc[i, "type_name"] == "文本": 449 | msg_str = str(df.loc[i, "msg"]) 450 | if ( 451 | re.search(r"1\d{10}", msg_str) 452 | 
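                        # Privacy filter: text rows matching a mainland mobile number,
                        # an 18-digit ID number, an e-mail address, a URL, or escaped
                        # unicode sequences are dropped entirely, as are rows that
                        # contain any configured blocked word.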
or re.search(r"\d{18}", msg_str) 453 | or re.search(r"\w+@\w+", msg_str) 454 | or "http" in msg_str 455 | or r"\\xa0" in msg_str 456 | or r"\\u" in msg_str 457 | ): 458 | df = df.drop(index=i) 459 | continue 460 | for blocked_word in self.blocked_words: 461 | if blocked_word in msg_str: 462 | df = df.drop(index=i) 463 | break 464 | else: 465 | df.loc[i, "msg"] = "" 466 | 467 | df = df.dropna(how="all") 468 | # 时间格式 2021-07-07 10:27:23 469 | # 遍历行 相同is_sender的行合并msg()遇到不同is_sender就重新开始 470 | df["CreateTime"] = pd.to_datetime(df["CreateTime"]) 471 | 472 | return [ChatMessage(*row) for row in df.values] 473 | 474 | def process_text(self, chat_message: ChatMessage): 475 | pass 476 | 477 | def save_result(self, qa_res: List[QaPair]): 478 | """ 479 | Saves the list of QaPair objects to a JSON file after converting them to dictionaries. 480 | 481 | Args: 482 | qa_res: A list of QaPair objects. 483 | """ 484 | processed_qa_res = [] 485 | for idx, item in enumerate(qa_res): 486 | item_dict = { 487 | "id": idx, 488 | "system": item.system, 489 | "instruction": item.instruction, 490 | "output": item.output, 491 | "history": item.history, 492 | "time": item.time.isoformat() if item.time else None, 493 | "score": item.score, 494 | } 495 | processed_qa_res.append(item_dict) 496 | 497 | output_path = "./dataset/res_csv/sft/sft-my.json" 498 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 499 | with open(output_path, "w", encoding="utf-8") as f: 500 | json.dump(processed_qa_res, f, ensure_ascii=False, indent=4) 501 | logger.success(f"聊天记录处理成功,共{len(qa_res)}条,保存到 {output_path}") 502 | 503 | 504 | if __name__ == "__main__": 505 | processor = DataProcessor() 506 | processor.main() 507 | -------------------------------------------------------------------------------- /weclone/data/strategies.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List 3 | from .models import ChatMessage 4 | from abc import ABC, abstractmethod 5 | 6 | 7 | @dataclass 8 | class ConversationStrategy(ABC): 9 | """对话策略的抽象基类""" 10 | 11 | is_single_chat: bool 12 | 13 | @abstractmethod 14 | def is_same_conversation( 15 | self, history_msg: List[ChatMessage], current_msg: ChatMessage 16 | ) -> bool: 17 | """判断两条消息是否属于同一个对话""" 18 | pass 19 | 20 | 21 | @dataclass 22 | class TimeWindowStrategy(ConversationStrategy): 23 | """基于时间窗口的判断策略""" 24 | 25 | time_window: int # 时间窗口(分钟) 26 | 27 | def is_same_conversation( 28 | self, history_msg: List[ChatMessage], current_msg: ChatMessage 29 | ) -> bool: 30 | time_diff = abs( 31 | (current_msg.CreateTime - history_msg[-1].CreateTime) 32 | ).total_seconds() 33 | return time_diff <= self.time_window 34 | 35 | 36 | @dataclass 37 | class LLMStrategy(ConversationStrategy): 38 | """基于大模型判断策略""" 39 | 40 | def is_same_conversation( 41 | self, history_msg: List[ChatMessage], current_msg: ChatMessage 42 | ) -> bool: 43 | # 修复user_id错误,使用talker字段代替user_id 44 | return current_msg.talker == history_msg[-1].talker if history_msg else False 45 | 46 | 47 | @dataclass 48 | class CompositeStrategy(ConversationStrategy): 49 | """组合多个策略的复合策略""" 50 | 51 | strategies: List[ConversationStrategy] 52 | require_all: bool = True # True表示所有策略都满足,False表示任一策略满足即可 53 | 54 | def is_same_conversation( 55 | self, history_msg: List[ChatMessage], current_msg: ChatMessage 56 | ) -> bool: 57 | results = [ 58 | s.is_same_conversation(history_msg, current_msg) for s in self.strategies 59 | ] 60 | return all(results) if self.require_all else 
any(results) 61 | -------------------------------------------------------------------------------- /weclone/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/eval/__init__.py -------------------------------------------------------------------------------- /weclone/eval/cli_demo.py: -------------------------------------------------------------------------------- 1 | from llamafactory.chat import ChatModel 2 | from llamafactory.extras.misc import torch_gc 3 | 4 | 5 | def main(): 6 | try: 7 | import platform 8 | 9 | if platform.system() != "Windows": 10 | import readline # noqa: F401 11 | except ImportError: 12 | print("Install `readline` for a better experience.") 13 | 14 | chat_model = ChatModel() 15 | messages = [] 16 | print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.") 17 | 18 | while True: 19 | try: 20 | query = input("\nUser: ") 21 | except UnicodeDecodeError: 22 | print("Detected decoding error at the inputs, please set the terminal encoding to utf-8.") 23 | continue 24 | except Exception: 25 | raise 26 | 27 | if query.strip() == "exit": 28 | break 29 | 30 | if query.strip() == "clear": 31 | messages = [] 32 | torch_gc() 33 | print("History has been removed.") 34 | continue 35 | 36 | messages.append({"role": "user", "content": query}) 37 | print("Assistant: ", end="", flush=True) 38 | 39 | response = "" 40 | for new_text in chat_model.stream_chat(messages): 41 | print(new_text, end="", flush=True) 42 | response += new_text 43 | print() 44 | messages.append({"role": "assistant", "content": response}) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /weclone/eval/eval_model.py: -------------------------------------------------------------------------------- 1 | from llamafactory.eval.evaluator import Evaluator 2 | 3 | 4 | def main(): 5 | evaluator = Evaluator() 6 | evaluator.eval() 7 | 8 | 9 | if __name__ == "__main__": 10 | main() 11 | -------------------------------------------------------------------------------- /weclone/eval/test_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import openai 3 | from openai import OpenAI # 导入 OpenAI 类 4 | 5 | from tqdm import tqdm 6 | from typing import List, Dict, cast # 导入 cast 7 | from openai.types.chat import ChatCompletionMessageParam # 导入消息参数类型 8 | 9 | from weclone.utils.config import load_config 10 | 11 | config = load_config("web_demo") 12 | 13 | config = { 14 | "default_prompt": config["default_system"], 15 | "model": "gpt-3.5-turbo", 16 | "history_len": 15, 17 | } 18 | 19 | config = type("Config", (object,), config)() 20 | 21 | # 初始化 OpenAI 客户端 22 | client = OpenAI( 23 | api_key="""sk-test""", 24 | base_url="http://127.0.0.1:8005/v1" 25 | ) 26 | 27 | 28 | def handler_text(content: str, history: list, config): 29 | messages = [{"role": "system", "content": f"{config.default_prompt}"}] 30 | for item in history: 31 | messages.append(item) 32 | messages.append({"role": "user", "content": content}) 33 | history.append({"role": "user", "content": content}) 34 | try: 35 | # 使用新的 API 调用方式 36 | # 将 messages 转换为正确的类型 37 | typed_messages = cast(List[ChatCompletionMessageParam], messages) 38 | response = client.chat.completions.create( 39 | model=config.model, 40 | messages=typed_messages, # 
传递转换后的列表 41 | max_tokens=50 42 | ) 43 | except openai.APIError as e: 44 | history.pop() 45 | return "AI接口出错,请重试\n" + str(e) 46 | 47 | resp = str(response.choices[0].message.content) # type: ignore 48 | resp = resp.replace("\n ", "") 49 | history.append({"role": "assistant", "content": resp}) 50 | return resp 51 | 52 | 53 | def main(): 54 | test_list = json.loads(open("dataset/test_data.json", "r", encoding="utf-8").read())["questions"] 55 | res = [] 56 | for questions in tqdm(test_list, desc=" Testing..."): 57 | history = [] 58 | for q in questions: 59 | handler_text(q, history=history, config=config) 60 | res.append(history) 61 | 62 | res_file = open("test_result-my.txt", "w") 63 | for r in res: 64 | for i in r: 65 | res_file.write(i["content"] + "\n") 66 | res_file.write("\n") 67 | 68 | 69 | if __name__ == "__main__": 70 | main() 71 | -------------------------------------------------------------------------------- /weclone/eval/web_demo.py: -------------------------------------------------------------------------------- 1 | from llamafactory.webui.interface import create_web_demo 2 | from weclone.utils.config import load_config 3 | 4 | 5 | def main(): 6 | config = load_config("web_demo") 7 | demo = create_web_demo() 8 | demo.queue() 9 | demo.launch(server_name="0.0.0.0", share=True, inbrowser=True) 10 | 11 | 12 | if __name__ == "__main__": 13 | main() 14 | -------------------------------------------------------------------------------- /weclone/prompts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/prompts/__init__.py -------------------------------------------------------------------------------- /weclone/prompts/clean_data.py: -------------------------------------------------------------------------------- 1 | CLEAN_PROMPT = """ 2 | # 角色 3 | 你是一个数据质量评估员。 4 | 5 | # 任务 6 | 你的任务是评估下面提供的【回答 A】相对于【问题/上下文 Q】的**逻辑性**和**相关性**。目标是识别并帮助过滤掉那些回答与问题**明显不匹配**、**逻辑严重混乱**的数据对。请根据以下核心评估点给出一个1到5的整数分数,并将该分数与原始 `id` 一起输出。 7 | 8 | **重要考量:** 9 | 1. **简短回答的有效性:** 请注意,诸如“好的”、“是的”、“收到”、“嗯”、“知道了”等简短的肯定、确认或应答,在合适的语境下是完全**有逻辑且相关的**。**不要仅仅因为回答简短就将其评为低分。** 只有当这类简短回答与【问题/上下文 Q】**明显不符**时,才应考虑低分。 10 | 2. **处理错别字和自我纠正:** 聊天记录中可能包含常见的打字错误(错别字)或用户先打错字随后又自行纠正的情况(例如,发送“我想去1楼”紧接着又发送“*2楼”进行更正)。在评估时,请**聚焦于用户想要表达的最终意图和信息的核心内容**,而**不应仅仅因为存在错别字或纠正过程就判定为低质量**。。 11 | 12 | 13 | # 核心评估点 (请在心中衡量) 14 | 1. **相关性 (Relevance):** 【回答 A】是否直接回应或恰当地衔接了【问题/上下文 Q】?它是在回答问题,还是完全跑题了?只有当【回答 A】与【问题/上下文 Q】**明显矛盾**、**完全不着边际**(即使考虑上下文也无法合理化),或简短回答**明显不适用于**该【问题/上下文 Q】时,才给予低分。 15 | 2. **逻辑性 (Coherence):** 【回答 A】本身是否符合基本的逻辑?结合【问题/上下文 Q】来看,这个问答对是否构成了一个符合逻辑的交流片段?是否存在明显的矛盾、混乱的内容?只有当【回答 A】**自身逻辑混乱**、**与Q存在无法解释的矛盾**时,才给予低分。 16 | 17 | # 评分标准 (1-5分) 18 | * **1分 (极差):** 完全不相关;逻辑严重混乱/矛盾。 19 | * **2分 (差):** 相关性很低;存在明显的逻辑问题或不连贯。 20 | * **3分 (中等):** 相关性一般(可能部分跑题或回应不充分);逻辑上勉强说得通但不够流畅或有瑕疵。 21 | * **4分 (良好):** 相关性好,回答了问题或恰当衔接;逻辑清晰。 22 | * **5分 (优秀):** 相关性强,回应精准;逻辑严谨流畅。 23 | 24 | # 输入数据 25 | ```json 26 | {{ 27 | "id": "{id}", 28 | "Q": "{Q}", 29 | "A": "{A}" 30 | }} 31 | 32 | # 输出要求 33 | 请严格按照以下 JSON 格式输出,包含原始的 id 和你给出的1到5的整数评分 score,不要包含任何其他文字、解释或标签。 34 | {{ 35 | "id": "<这里填入输入数据中的id值>", 36 | "score": <这里填入1到5的整数评分> 37 | }} 38 | """ 39 | 40 | ONLINE_LLM_CLEAN_PROMPT = """ 41 | # 角色 42 | 你是一个数据质量评估员。 43 | 44 | # 任务 45 | 你的任务是评估下面提供的【回答 A】相对于【问题/上下文 Q】的**逻辑性**和**相关性**。目标是识别并帮助过滤掉那些回答与问题**明显不匹配**、**逻辑严重混乱**的数据对。请根据以下核心评估点给出一个1到5的整数分数,并将该分数与原始 `id` 一起输出。 46 | 47 | **重要考量:** 48 | 1. 
**简短回答的有效性:** 请注意,诸如“好的”、“是的”、“收到”、“嗯”、“知道了”等简短的肯定、确认或应答,在合适的语境下是完全**有逻辑且相关的**。**不要仅仅因为回答简短就将其评为低分。** 只有当这类简短回答与【问题/上下文 Q】**明显不符**时,才应考虑低分。 49 | 2. **处理错别字和自我纠正:** 聊天记录中可能包含常见的打字错误(错别字)或用户先打错字随后又自行纠正的情况(例如,发送“我想去1楼”紧接着又发送“*2楼”进行更正)。在评估时,请**聚焦于用户想要表达的最终意图和信息的核心内容**,而**不应仅仅因为存在错别字或纠正过程就判定为低质量**。。 50 | 51 | 52 | # 核心评估点 (请在心中衡量) 53 | 1. **相关性 (Relevance):** 【回答 A】是否直接回应或恰当地衔接了【问题/上下文 Q】?它是在回答问题,还是完全跑题了?只有当【回答 A】与【问题/上下文 Q】**明显矛盾**、**完全不着边际**(即使考虑上下文也无法合理化),或简短回答**明显不适用于**该【问题/上下文 Q】时,才给予低分。 54 | 2. **逻辑性 (Coherence):** 【回答 A】本身是否符合基本的逻辑?结合【问题/上下文 Q】来看,这个问答对是否构成了一个符合逻辑的交流片段?是否存在明显的矛盾、混乱的内容?只有当【回答 A】**自身逻辑混乱**、**与Q存在无法解释的矛盾**时,才给予低分。 55 | 56 | # 评分标准 (1-5分) 57 | * **1分 (极差):** 完全不相关;逻辑严重混乱/矛盾。 58 | * **2分 (差):** 相关性很低;存在明显的逻辑问题或不连贯。 59 | * **3分 (中等):** 相关性一般(可能部分跑题或回应不充分);逻辑上勉强说得通但不够流畅或有瑕疵。 60 | * **4分 (良好):** 相关性好,回答了问题或恰当衔接;逻辑清晰。 61 | * **5分 (优秀):** 相关性强,回应精准;逻辑严谨流畅。 62 | 63 | # 输入数据 64 | ```json 65 | {qa_list} 66 | 67 | # 输出要求 68 | 请严格按照以下 JSON 格式输出,包含原始的 id 和你给出的1到5的整数评分 score,不要包含任何其他文字、解释或标签! 69 | [ 70 | {{ 71 | "id": "<这里填入第1条输入数据中的id值>", 72 | "score": <1-5的整数评分> 73 | }}, 74 | {{ 75 | "id": "<这里填入第2条输入数据中的id值>", 76 | "score": <1-5的整数评分> 77 | }} 78 | … 79 | ] 80 | """ -------------------------------------------------------------------------------- /weclone/server/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/server/__init__.py -------------------------------------------------------------------------------- /weclone/server/api_service.py: -------------------------------------------------------------------------------- 1 | import os 2 | import uvicorn 3 | from llamafactory.chat import ChatModel 4 | from llamafactory.api.app import create_app 5 | from weclone.utils.config import load_config 6 | 7 | 8 | 9 | def main(): 10 | config = load_config("api_service") 11 | chat_model = ChatModel(config) 12 | app = create_app(chat_model) 13 | print("Visit http://localhost:{}/docs for API document.".format(os.environ.get("API_PORT", 8005))) 14 | uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("API_PORT", 8005)), workers=1) 15 | 16 | 17 | if __name__ == "__main__": 18 | main() 19 | -------------------------------------------------------------------------------- /weclone/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/train/__init__.py -------------------------------------------------------------------------------- /weclone/train/export_model.py: -------------------------------------------------------------------------------- 1 | from llamafactory.train.tuner import export_model 2 | 3 | 4 | def main(): 5 | export_model() 6 | 7 | 8 | if __name__ == "__main__": 9 | main() 10 | -------------------------------------------------------------------------------- /weclone/train/train_pt.py: -------------------------------------------------------------------------------- 1 | from llamafactory.train.tuner import run_exp 2 | from weclone.utils.config import load_config 3 | 4 | config = load_config("train_pt") 5 | run_exp(config) 6 | -------------------------------------------------------------------------------- /weclone/train/train_sft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | from llamafactory.train.tuner import run_exp 5 
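# Overall flow of this entry point: load the train_sft and make_dataset configs,
# warn if training would run on the CPU, let LLMCleaningStrategy.clean() choose the
# dataset file (sft-my-l.json filtered by accept_score when cleaning is enabled,
# otherwise the original sft-my.json), then hand the merged arguments to
# LLaMA-Factory's run_exp.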
| from llamafactory.extras.misc import get_current_device 6 | from weclone.utils.config import load_config 7 | from weclone.utils.log import logger 8 | from weclone.data.clean.strategies import LLMCleaningStrategy 9 | 10 | def main(): 11 | train_config = load_config(arg_type="train_sft") 12 | dataset_config = load_config(arg_type="make_dataset") 13 | 14 | device = get_current_device() 15 | if device == "cpu": 16 | logger.warning("请注意你正在使用CPU训练,非Mac设备可能会出现问题") 17 | 18 | cleaner = LLMCleaningStrategy(make_dataset_config=dataset_config) 19 | cleaned_data_path = cleaner.clean() 20 | 21 | if not os.path.exists(cleaned_data_path): 22 | logger.error(f"错误:文件 '{cleaned_data_path}' 不存在,请确保数据处理步骤已正确生成该文件。") 23 | sys.exit(1) 24 | 25 | formatted_config = json.dumps(train_config, indent=4, ensure_ascii=False) 26 | logger.info(f"微调配置:\n{formatted_config}") 27 | 28 | run_exp(train_config) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /weclone/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xming521/WeClone/94aa25bdea5cb2c9d6303de6c7bcd2c1a152a34b/weclone/utils/__init__.py -------------------------------------------------------------------------------- /weclone/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import commentjson 3 | import sys 4 | 5 | from .log import logger 6 | from .tools import dict_to_argv 7 | 8 | 9 | def load_config(arg_type: str): 10 | config_path = os.environ.get("WECLONE_CONFIG_PATH", "./settings.jsonc") 11 | logger.info(f"Loading configuration from: {config_path}") # Add logging to see which file is loaded 12 | try: 13 | with open(config_path, "r", encoding="utf-8") as f: 14 | s_config: dict = commentjson.load(f) 15 | except FileNotFoundError: 16 | logger.error(f"Configuration file not found: {config_path}") 17 | sys.exit(1) # Exit if config file is not found 18 | except Exception as e: 19 | logger.error(f"Error loading configuration file {config_path}: {e}") 20 | sys.exit(1) 21 | 22 | if arg_type == "cli_args": 23 | config = s_config["cli_args"] 24 | elif arg_type == "web_demo" or arg_type == "api_service": 25 | # infer_args和common_args求并集 26 | config = {**s_config["infer_args"], **s_config["common_args"]} 27 | elif arg_type == "train_pt": 28 | config = {**s_config["train_pt_args"], **s_config["common_args"]} 29 | elif arg_type == "train_sft": 30 | config = {**s_config["train_sft_args"], **s_config["common_args"]} 31 | if s_config["make_dataset_args"]["prompt_with_history"]: 32 | dataset_info_path = os.path.join(config["dataset_dir"], "dataset_info.json") 33 | dataset_info = commentjson.load(open(dataset_info_path, "r", encoding="utf-8"))[config["dataset"]] 34 | if dataset_info["columns"].get("history") is None: 35 | logger.warning(f"{config['dataset']}数据集不包history字段,尝试使用wechat-sft-with-history数据集") 36 | config["dataset"] = "wechat-sft-with-history" 37 | 38 | elif arg_type == "make_dataset": 39 | config = {**s_config["make_dataset_args"], **s_config["common_args"]} 40 | config["dataset"] = s_config["train_sft_args"]["dataset"] 41 | config["dataset_dir"] = s_config["train_sft_args"]["dataset_dir"] 42 | config["cutoff_len"] = s_config["train_sft_args"]["cutoff_len"] 43 | else: 44 | raise ValueError("暂不支持的参数类型") 45 | 46 | if "train" in arg_type: 47 | config["output_dir"] = config["adapter_name_or_path"] 48 | 
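        # For training runs the adapter path from settings doubles as the output
        # directory; the original adapter_name_or_path key is removed below so it is
        # not forwarded as a CLI argument.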
config.pop("adapter_name_or_path") 49 | config["do_train"] = True 50 | 51 | sys.argv += dict_to_argv(config) 52 | 53 | return config 54 | -------------------------------------------------------------------------------- /weclone/utils/length_cdf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 the LlamaFactory team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections import defaultdict 16 | 17 | import fire 18 | from tqdm import tqdm 19 | from weclone.utils.log import logger 20 | 21 | from llamafactory.data import get_dataset, get_template_and_fix_tokenizer 22 | from llamafactory.hparams import get_train_args 23 | from llamafactory.model import load_tokenizer 24 | 25 | 26 | def length_cdf( 27 | model_name_or_path: str = "./Qwen2.5-7B-Instruct", 28 | dataset: str = "wechat-sft", 29 | dataset_dir: str = "./dataset/res_csv/sft", 30 | template: str = "qwen", 31 | interval: int = 256, 32 | ): 33 | r"""Calculate the distribution of the input lengths in the dataset. 34 | 35 | Usage: export CUDA_VISIBLE_DEVICES=0 36 | python length_cdf.py --model_name_or_path path_to_model --dataset alpaca_en_demo --template default 37 | """ 38 | logger.info("开始计算cutoff_len......") 39 | 40 | model_args, data_args, training_args, _, _ = get_train_args( 41 | { 42 | "stage": "sft", 43 | "model_name_or_path": model_name_or_path, 44 | "dataset": dataset, 45 | "dataset_dir": dataset_dir, 46 | "template": template, 47 | "cutoff_len": 1_000_000, 48 | "preprocessing_num_workers": 16, 49 | "output_dir": "dummy_dir", 50 | "overwrite_cache": True, 51 | "do_train": True, 52 | } 53 | ) 54 | tokenizer_module = load_tokenizer(model_args) 55 | template = get_template_and_fix_tokenizer(tokenizer_module["tokenizer"], data_args) # type: ignore 56 | trainset = get_dataset(template, model_args, data_args, training_args, "sft", **tokenizer_module)["train_dataset"] # type: ignore 57 | total_num = len(trainset) # type: ignore 58 | length_dict = defaultdict(int) 59 | for sample in tqdm(trainset["input_ids"], desc="Collecting lengths"): # type: ignore 60 | length_dict[len(sample) // interval * interval] += 1 61 | 62 | length_tuples = list(length_dict.items()) 63 | length_tuples.sort() 64 | count_accu, prob_accu = 0, 0 65 | logger.info(" cutoff_len设置建议:") 66 | for length, count in length_tuples: 67 | count_accu += count 68 | prob_accu += count / total_num * 100 69 | logger.success(f"{count_accu:d} ({prob_accu:.2f}%) samples have length < {length + interval}.") 70 | 71 | 72 | if __name__ == "__main__": 73 | fire.Fire(length_cdf) 74 | -------------------------------------------------------------------------------- /weclone/utils/log.py: -------------------------------------------------------------------------------- 1 | from loguru import logger 2 | import sys 3 | from functools import wraps 4 | 5 | logger.remove() 6 | 7 | logger.add( 8 | sys.stderr, 9 | format="[WeClone] {level.name[0]} | {time:HH:mm:ss} | {message}", 10 | 
colorize=True, 11 | level="INFO", 12 | ) 13 | 14 | logger.add( 15 | "logs/weclone.log", # 日志文件路径 16 | rotation="1 day", # 每天轮换一个新的日志文件 17 | retention="7 days", # 保留最近7天的日志文件 18 | compression="zip", # 压缩旧的日志文件 19 | level="DEBUG", # 文件日志级别 20 | format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}", # 日志格式 21 | encoding="utf-8", # 文件编码 22 | enqueue=True, # 异步写入,避免阻塞 23 | ) 24 | 25 | 26 | def capture_output(func): 27 | @wraps(func) 28 | def wrapper(*args, **kwargs): 29 | log_sink_buffer = [] 30 | 31 | def list_sink(message): 32 | log_sink_buffer.append(message.record["message"]) 33 | 34 | sink_id = logger.add(list_sink, format="{message}", level="INFO") 35 | 36 | original_stdout = sys.stdout 37 | original_stderr = sys.stderr 38 | 39 | class OutputTeeToGlobalLog: 40 | def __init__(self, original_stream, log_method): 41 | self.original_stream = original_stream 42 | self.log_method = log_method 43 | self.current_line_content = "" # Represents the current state of the line to be logged 44 | 45 | def write(self, data_chunk): 46 | self.original_stream.write(data_chunk) # Pass through to console 47 | 48 | if data_chunk.endswith("\\r") and "\\n" not in data_chunk: 49 | self.current_line_content = data_chunk[:-1] # Store without the trailing \\r 50 | return 51 | 52 | full_buffer = self.current_line_content + data_chunk 53 | lines_to_process = full_buffer.split("\\n") 54 | 55 | for i in range(len(lines_to_process) - 1): 56 | line = lines_to_process[i] 57 | final_content_of_line = line 58 | last_cr = line.rfind("\\r") 59 | if last_cr != -1: 60 | final_content_of_line = line[last_cr + 1 :] 61 | 62 | escaped_log = final_content_of_line.replace("{", "{{").replace("}", "}}") 63 | if final_content_of_line.strip() or line: 64 | self.log_method(escaped_log, raw=True) 65 | 66 | self.current_line_content = lines_to_process[-1] 67 | 68 | def flush(self): 69 | self.original_stream.flush() 70 | if self.current_line_content: 71 | final_content_of_line = self.current_line_content 72 | last_cr = self.current_line_content.rfind("\\r") 73 | if last_cr != -1: 74 | final_content_of_line = self.current_line_content[last_cr + 1 :] 75 | 76 | escaped_log = final_content_of_line.replace("{", "{{").replace("}", "}}") 77 | if final_content_of_line.strip() or self.current_line_content: 78 | self.log_method(escaped_log, raw=True) 79 | self.current_line_content = "" 80 | 81 | sys.stdout = OutputTeeToGlobalLog(original_stdout, logger.opt(raw=True).info) 82 | sys.stderr = OutputTeeToGlobalLog(original_stderr, logger.opt(raw=True).error) 83 | 84 | try: 85 | func(*args, **kwargs) 86 | finally: 87 | sys.stdout = original_stdout 88 | sys.stderr = original_stderr 89 | logger.remove(sink_id) 90 | 91 | return wrapper 92 | -------------------------------------------------------------------------------- /weclone/utils/tools.py: -------------------------------------------------------------------------------- 1 | def dict_to_argv(d): 2 | argv = [] 3 | for k, v in d.items(): 4 | argv.append("--" + k) 5 | if v is not None: 6 | argv.append(str(v)) 7 | return argv 8 | 9 | 10 | --------------------------------------------------------------------------------
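A note on how settings values reach LLaMA-Factory: weclone.utils.config.load_config merges the relevant sections of settings.jsonc and appends the result to sys.argv via dict_to_argv, so components that parse sys.argv can pick the values up as if they were command-line flags (weclone/cli.py's clear_argv temporarily strips sys.argv for the same reason: to keep HfArgumentParser from choking on unexpected tokens). A minimal sketch of that conversion; the key/value pairs below are illustrative only, not read from settings.jsonc:

from weclone.utils.tools import dict_to_argv

# Hypothetical config fragment for illustration
config = {
    "model_name_or_path": "./Qwen2.5-7B-Instruct",
    "do_train": True,
    "adapter_name_or_path": None,  # None values become bare flags with no value
}

print(dict_to_argv(config))
# ['--model_name_or_path', './Qwen2.5-7B-Instruct', '--do_train', 'True', '--adapter_name_or_path']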