├── .github
└── ISSUE_TEMPLATE
│ ├── bug_report.yaml
│ ├── config.yml
│ └── feature_request.yml
├── .gitignore
├── LICENSE
├── MODEL_LICENSE
├── PROJECT.md
├── README.md
├── README_en.md
├── cli_demo.py
├── examples
├── ad-writing-2.png
├── blog-outline.png
├── comments-writing.png
├── email-writing-1.png
├── email-writing-2.png
├── information-extraction.png
├── role-play.png
├── self-introduction.png
├── sport.png
└── tour-guide.png
├── improve
├── README.md
└── data_sample.jsonl
├── limitations
├── factual_error.png
├── math_error.png
├── self-confusion_google.jpg
├── self-confusion_openai.jpg
└── self-confusion_tencent.jpg
├── model_api.py
├── ptuning
├── README.md
├── README_en.md
├── arguments.py
├── deepspeed.json
├── ds_train_finetune.sh
├── evaluate.sh
├── evaluate_finetune.sh
├── main.py
├── train.sh
├── train_chat.sh
├── trainer.py
├── trainer_seq2seq.py
├── web_demo.py
└── web_demo.sh
├── requirements.txt
├── resources
├── WECHAT.md
├── cli-demo.png
├── web-demo.gif
├── web-demo.png
└── wechat.jpg
├── start.sh
├── utils.py
├── web_demo.py
├── web_index.html
└── web_ui.py
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: 🐞 Bug/Help
2 | description: File a bug/issue
3 | title: "[BUG/Help]
"
4 | labels: []
5 | body:
6 | - type: checkboxes
7 | attributes:
8 | label: Is there an existing issue for this?
9 | description: Please search to see if an issue already exists for the bug you encountered.
10 | options:
11 | - label: I have searched the existing issues
12 | required: true
13 | - type: textarea
14 | attributes:
15 | label: Current Behavior
16 | description: |
17 | A concise description of what you're experiencing, with screenshot attached if possible.
18 | Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
19 | validations:
20 | required: true
21 | - type: textarea
22 | attributes:
23 | label: Expected Behavior
24 | description: A concise description of what you expected to happen.
25 | validations:
26 | required: false
27 | - type: textarea
28 | attributes:
29 | label: Steps To Reproduce
30 | description: Steps to reproduce the behavior.
31 | placeholder: |
32 | 1. In this environment...
33 | 2. With this config...
34 | 3. Run '...'
35 | 4. See error...
36 | validations:
37 | required: true
38 | - type: textarea
39 | attributes:
40 | label: Environment
41 | description: |
42 | examples:
43 | - **OS**: Ubuntu 20.04
44 | - **Python**: 3.8
45 | - **Transformers**: 4.26.1
46 | - **PyTorch**: 1.12
47 | - **CUDA Support**: True
48 | value: |
49 | - OS:
50 | - Python:
51 | - Transformers:
52 | - PyTorch:
53 | - CUDA Support (`python -c "import torch; print(torch.cuda.is_available())"`) :
54 | render: markdown
55 | validations:
56 | required: true
57 | - type: textarea
58 | attributes:
59 | label: Anything else?
60 | description: |
61 | Links? References? Anything that will give us more context about the issue you are encountering!
62 | validations:
63 | required: false
64 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.yml:
--------------------------------------------------------------------------------
1 | name: Feature request
2 | description: Suggest an idea for this project
3 | title: "[Feature] "
4 | labels: []
5 | body:
6 | - type: textarea
7 | attributes:
8 | label: Is your feature request related to a problem? Please describe.
9 | description: |
10 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
11 | validations:
12 | required: false
13 | - type: textarea
14 | attributes:
15 | label: Solutions
16 | description: |
17 | Describe the solution you'd like
18 | A clear and concise description of what you want to happen.
19 | validations:
20 | required: true
21 | - type: textarea
22 | attributes:
23 | label: Additional context
24 | description: Add any other context or screenshots about the feature request here.
25 | validations:
26 | required: false
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | history/
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 |
132 | # Mac system file
133 | model/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright Zhengxiao Du
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | The ChatGLM-6B License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means the ChatGLM-6B Model Team that distributes its Software.
6 |
7 | “Software” means the ChatGLM-6B model parameters made available under this license.
8 |
9 | 2. License Grant
10 |
11 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
12 |
13 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
14 |
15 | 3. Restriction
16 |
17 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
18 |
19 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
20 |
21 | 4. Disclaimer
22 |
23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24 |
25 | 5. Limitation of Liability
26 |
27 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
28 |
29 | 6. Dispute Resolution
30 |
31 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
32 |
33 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at glm-130b@googlegroups.com.
34 |
--------------------------------------------------------------------------------
/PROJECT.md:
--------------------------------------------------------------------------------
1 | # 友情链接
2 |
3 | 对 ChatGLM 进行加速或者重新实现的开源项目:
4 | * [SwissArmyTransformer](https://github.com/THUDM/SwissArmyTransformer): 一个Transformer统一编程框架,ChatGLM-6B已经在SAT中进行实现并可以进行P-tuning微调。
5 | * [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): 一个基于 MNN 的 ChatGLM-6B C++ 推理实现,支持根据显存大小自动分配计算任务给 GPU 和 CPU
6 | * [JittorLLMs](https://github.com/Jittor/JittorLLMs):最低3G显存或者没有显卡都可运行 ChatGLM-6B FP16, 支持Linux、windows、Mac部署
7 |
8 |
9 |
10 | 基于或使用了 ChatGLM-6B 的开源项目:
11 | * [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): 支持ChatGLM-6B的学术写作与编程工具箱,具有模块化和多线程调用LLM的特点,可并行调用多种LLM。
12 | * [闻达](https://github.com/l15y/wenda):大型语言模型调用平台,基于 ChatGLM-6B 实现了类 ChatPDF 功能
13 | * [glm-bot](https://github.com/initialencounter/glm-bot):将ChatGLM接入Koishi可在各大聊天平台上调用ChatGLM
14 | * [Chinese-LangChain](https://github.com/yanqiangmiffy/Chinese-LangChain):中文langchain项目,基于ChatGLM-6b+langchain实现本地化知识库检索与智能答案生成,增加web search功能、知识库选择功能和支持知识增量更新
15 | * [bibliothecarius](https://github.com/coderabbit214/bibliothecarius):快速构建服务以集成您的本地数据和AI模型,支持ChatGLM等本地化模型接入。
16 | * [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM):基于 langchain 的 ChatGLM 应用,实现基于可扩展知识库的问答
17 | * [ChatGLM-web](https://github.com/NCZkevin/chatglm-web):基于FastAPI和Vue3搭建的ChatGLM演示网站(支持chatglm流式输出、前端调整模型参数、上下文选择、保存图片、知识库问答等功能)
18 | * [ChatGLM-6B-Engineering](https://github.com/LemonQu-GIT/ChatGLM-6B-Engineering):基于 ChatGLM-6B 后期调教,网络爬虫及 [Stable Diffusion](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 实现的网络搜索及图片生成
19 | * [ChatGLM-OpenAI-API](https://github.com/ninehills/chatglm-openai-api): 将 ChatGLM-6B 封装为 OpenAI API 风格,并通过 ngrok/cloudflare 对外提供服务,从而将 ChatGLM 快速集成到 OpenAI 的各种生态中。
20 |
21 | 对 ChatGLM-6B 进行微调的开源项目:
22 | * [InstructGLM](https://github.com/yanqiangmiffy/InstructGLM):基于ChatGLM-6B进行指令学习,汇总开源中英文指令数据,基于Lora进行指令数据微调,开放了Alpaca、Belle微调后的Lora权重,修复web_demo重复问题
23 | * [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning):实现了ChatGLM-6B模型的监督微调和完整RLHF训练,汇总10余种指令数据集和3种微调方案,实现了4/8比特量化和模型权重融合,提供微调模型快速部署方法。
24 | * [ChatGLM-Finetuning](https://github.com/liucongg/ChatGLM-Finetuning):基于ChatGLM-6B模型,进行下游具体任务微调,涉及Freeze、Lora、P-tuning等,并进行实验效果对比。
25 | * [ChatGLM-Tuning](https://github.com/mymusise/ChatGLM-Tuning): 基于 LoRA 对 ChatGLM-6B 进行微调。类似的项目还包括 [Humanable ChatGLM/GPT Fine-tuning | ChatGLM 微调](https://github.com/hscspring/hcgf)
26 |
27 |
28 | 针对 ChatGLM-6B 的教程/文档:
29 | * [Windows部署文档](https://github.com/ZhangErling/ChatGLM-6B/blob/main/deployment_windows.md)
30 | * [搭建深度学习docker容器以运行 ChatGLM-6B - Luck_zy](https://www.luckzym.com/tags/ChatGLM-6B/)
31 |
32 | 如果你有其他好的项目/教程的话,欢迎参照上述格式添加到 README 中并提出 [Pull Request](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork)。
33 |
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # audioConversation-ChatGLM
2 |
3 | 该项目是基于ChatGLM实现的一个语音对话工具(英语),其中涉及到的AI工具基本都是开源的,整个过程不需要翻墙,不需要使用openAI的key,只需要本地GPU大于等于12G即可。
4 |
5 | ### 环境安装
6 |
7 | 首先需要使用以下命令安装ffmpeg工具
8 | ```shell
9 | sudo apt install ffmpeg
10 | ```
11 |
12 | 使用 pip 安装依赖:`pip install -r requirements.txt`,其中 `transformers` 库版本推荐为 `4.27.1`,但理论上不低于 `4.23.1` 即可。
13 |
14 | 此外,如果需要在 cpu 上运行量化后的模型,还需要安装 `gcc` 与 `openmp`。多数 Linux 发行版默认已安装。对于 Windows ,可在安装 [TDM-GCC](https://jmeubank.github.io/tdm-gcc/) 时勾选 `openmp`。 Windows 测试环境 `gcc` 版本为 `TDM-GCC 10.3.0`, Linux 为 `gcc 11.3.0`。
15 |
16 | 如果你的内存不足,可以直接加载量化后的模型:
17 |
18 | ```python
19 | # INT8 量化的模型将"THUDM/chatglm-6b-int4"改为"THUDM/chatglm-6b-int8"
20 | model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4",trust_remote_code=True).float()
21 | ```
22 |
23 | 如果遇到了报错 `Could not find module 'nvcuda.dll'` 或者 `RuntimeError: Unknown platform: darwin` (MacOS) ,请[从本地加载模型](README.md#从本地加载模型)
24 |
25 | ### Mac 上的 GPU 加速
26 | 对于搭载了Apple Silicon的Mac(以及MacBook),可以使用 MPS 后端来在 GPU 上运行 ChatGLM-6B。需要参考 Apple 的 [官方说明](https://developer.apple.com/metal/pytorch) 安装 PyTorch-Nightly。
27 |
28 | 目前在 MacOS 上只支持[从本地加载模型](README.md#从本地加载模型)。将代码中的模型加载改为从本地加载,并使用 mps 后端
29 | ```python
30 | model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
31 | ```
32 | 即可使用在 Mac 上使用 GPU 加速模型推理。
33 |
34 |
35 | ### 本地启动
36 | 在Linux环境下运行start.sh脚本即可,然后使用chrome浏览器或者MicroSoft EDGE浏览器打开web_index.html,给予语音权限,即可开始对话。
37 |
38 | ### 未来规划
39 | 当前项目并不完善,前端页面很简陋,后续会优化页面展示以及功能,包括对话历史等。同时,也会增加其他模型的接入(主要是各类小模型),最终希望能够在不使用GPU或者使用低性能GPU的情况下依旧能够实现流畅对话的能力。
40 |
41 | ## 协议
42 |
43 | 本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
44 |
45 |
46 |
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 | # ChatGLM-6B
2 |
3 |
4 | 🌐 Blog • 🤗 HF Repo • 🐦 Twitter • 📃 [GLM@ACL 22] [GitHub] • 📃 [GLM-130B@ICLR 23] [GitHub]
5 |
6 |
7 | 👋 Join our Slack and WeChat
8 |
9 |
10 | ## Introduction
11 |
12 | ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level).
13 |
14 | ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
15 |
16 | In order to facilitate downstream developers to customize the model for their own application scenarios, we also implements an parameter-efficient tuning method based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2)[(Guidelines)](ptuning/README_en.md). Tuning requires at least 7GB of GPU memory at INT4 quantization level.
17 |
18 | Try the [online demo](https://huggingface.co/spaces/ysharma/ChatGLM-6b_Gradio_Streaming) on Huggingface Spaces.
19 |
20 | ## Projects
21 | Open source projects that accelerate ChatGLM:
22 | * [ChatGLM-MNN](https://github.com/wangzhaode/ChatGLM-MNN): An MNN-based implementation of ChatGLM-6B C++ inference, which supports automatic allocation of computing tasks to GPU and CPU according to the size of GPU memory
23 | * [JittorLLMs](https://github.com/Jittor/JittorLLMs): Running ChatGLM-6B in FP16 with a minimum of 3G GPU memory or no GPU at all, with Linux, windows, and Mac support
24 |
25 | Open source projects using ChatGLM-6B:
26 | * [langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM): ChatGLM application based on langchain, realizing Q&A based on extensible knowledge base
27 | * [Wenda](https://github.com/l15y/wenda): Large-scale language model call platform, based on ChatGLM-6B to achieve ChatPDF-like functions
28 | * [chatgpt_academic](https://github.com/binary-husky/chatgpt_academic): An academic writing and programming toolbox that supports ChatGLM-6B. It has the characteristics of modularization and multi-thread calling LLM, and can call multiple LLMs in parallel.
29 | * [glm-bot](https://github.com/initialencounter/glm-bot): Connect ChatGLM to Koishi to call ChatGLM on major chat platforms
30 |
31 | Example projects supporting online training of ChatGLM-6B and related applications:
32 | * [ChatGLM-6B deployment and fine-tuning tutorial](https://www.heywhale.com/mw/project/6436d82948f7da1fee2be59e)
33 | * [ChatGLM-6B combined with langchain to implement local knowledge base QA Bot](https://www.heywhale.com/mw/project/643977aa446c45f4592a1e59)
34 |
35 | Third-party evaluation:
36 | * [Measuring Massive Multitask Chinese Understanding](https://arxiv.org/abs/2304.12986)
37 |
38 | For more open source projects, see [PROJECT.md](PROJECT.md)
39 |
40 | ## Getting Started
41 |
42 | ### Hardware Requirements
43 |
44 | | **Quantization Level** | **GPU Memory** |
45 | |------------------------|----------------|
46 | | FP16(no quantization) | 13 GB |
47 | | INT8 | 10 GB |
48 | | INT4 | 6 GB |
49 |
50 | ### Environment Setup
51 |
52 | Install the requirements with pip: `pip install -r requirements.txt`. `transformers` library version is recommended to be `4.27.1`, but theoretically any version no lower than `4.23.1` is acceptable.
53 |
54 | In addition, if you need to run the quantified model on the CPU, you also need to install `gcc` and `openmp`. Most Linux distributions are installed by default. For Windows, you can check `openmp` when installing [TDM-GCC](https://jmeubank.github.io/tdm-gcc/). On Windows testing environment, the `gcc` version is `TDM-GCC 10.3.0`, and on Linux is `gcc 11.3.0`.
55 |
56 | ### Usage
57 |
58 | Generate dialogue with the following code
59 |
60 | ```python
61 | >>> from transformers import AutoTokenizer, AutoModel
62 | >>> tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
63 | >>> model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
64 | >>> model = model.eval()
65 | >>> response, history = model.chat(tokenizer, "你好", history=[])
66 | >>> print(response)
67 | 你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。
68 | >>> response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
69 | >>> print(response)
70 | 晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:
71 |
72 | 1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。
73 | 2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。
74 | 3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。
75 | 4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。
76 | 5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。
77 | 6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。
78 |
79 | 如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。
80 | ```
81 | The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can add the `revision="v0.1.0"` parameter in the `from_pretrained` call. `v0.1.0` is the latest version number. For a complete list of versions, see [Change Log](https://huggingface.co/THUDM/chatglm-6b#change-log).
82 |
83 | ### Load the model locally
84 | The above code will automatically download the model implementation and checkpoints by [transformers](https://github.com/huggingface/transformers). The full model implementation can be found at [Hugging Face Hub](https://huggingface.co/THUDM/chatglm-6b). If your network environment is poor, downloading model parameters may take a long time or even fail. At this point, you can download the model to the local first, and then load it from the local.
85 |
86 | To download models from Hugging Face Hub, you need to [install Git LFS](https://docs.github.com/zh/repositories/working-with-files/managing-large-files/installing-git-large-file-storage) , then run
87 | ```Shell
88 | git clone https://huggingface.co/THUDM/chatglm-6b
89 | ```
90 |
91 | After downloading the model locally, replace `THUDM/chatglm-6b` in the above code with the path of your local `chatglm-6b` folder to load the model locally.
92 |
93 | **Optional**: The implementation of the model is still in development. If you want to fix the used model implementation to ensure compatibility, you can execute
94 | ```Shell
95 | git checkout v0.1.0
96 | ```
97 |
98 | ## Demo & API
99 |
100 | We provide a Web demo based on [Gradio](https://gradio.app) and a command line demo in the repo. First clone our repo with:
101 |
102 | ```shell
103 | git clone https://github.com/THUDM/ChatGLM-6B
104 | cd ChatGLM-6B
105 | ```
106 |
107 | ### Web Demo
108 |
109 | 
110 |
111 | Install Gradio `pip install gradio`,and run [web_demo.py](web_demo.py):
112 |
113 | ```shell
114 | python web_demo.py
115 | ```
116 |
117 | The program runs a web server and outputs the URL. Open the URL in the browser to use the web demo.
118 |
119 | Thanks to [@AdamBear](https://github.com/AdamBear) for implementing a web demo based on Streamlit, see [#117](https://github.com/THUDM/ChatGLM-6B/pull/117 ).
120 |
121 | #### CLI Demo
122 |
123 | 
124 |
125 | Run [cli_demo.py](cli_demo.py) in the repo:
126 |
127 | ```shell
128 | python cli_demo.py
129 | ```
130 |
131 | The command runs an interactive program in the shell. Type your instruction in the shell and hit enter to generate the response. Type `clear` to clear the dialogue history and `stop` to terminate the program.
132 |
133 | ## API Deployment
134 | First install the additional dependency `pip install fastapi uvicorn`. The run [api.py](model_api.py) in the repo.
135 | ```shell
136 | python model_api.py
137 | ```
138 | By default the api runs at the`8000`port of the local machine. You can call the API via
139 | ```shell
140 | curl -X POST "http://127.0.0.1:8000" \
141 | -H 'Content-Type: application/json' \
142 | -d '{"prompt": "你好", "history": []}'
143 | ```
144 | The returned value is
145 | ```shell
146 | {
147 | "response":"你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。",
148 | "history":[["你好","你好👋!我是人工智能助手 ChatGLM-6B,很高兴见到你,欢迎问我任何问题。"]],
149 | "status":200,
150 | "time":"2023-03-23 21:38:40"
151 | }
152 | ```
153 |
154 | ## Deployment
155 |
156 | ### Quantization
157 |
158 | By default, the model parameters are loaded with FP16 precision, which require about 13GB of GPU memory. It your GPU memory is limited, you can try to load the model parameters with quantization:
159 |
160 | ```python
161 | # Change according to your hardware. Only support 4/8 bit quantization now.
162 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(8).cuda()
163 | ```
164 |
165 | After 2 to 3 rounds of dialogue, the GPU memory usage is about 10GB under 8-bit quantization, and only 6GB under 4-bit quantization. As the number of dialogue rounds increases, the corresponding GPU memory consumption also increases. Due to the use of relative position encoding, ChatGLM-6B theoretically supports an infinitely long context-length, but the performance will gradually decline after the total length exceeds 2048 (training length).
166 |
167 | Model quantization brings a certain performance decline. After testing, ChatGLM-6B can still perform natural and smooth generation under 4-bit quantization. using [GPT-Q](https://arxiv.org/abs/2210.17323) etc. The quantization scheme can further compress the quantization accuracy/improve the model performance under the same quantization accuracy. You are welcome to submit corresponding Pull Requests.
168 |
169 | The quantization costs about 13GB of CPU memory to load the FP16 model. If your CPU memory is limited, you can directly load the quantized model, which costs only 5.2GB CPU memory:
170 | ```python
171 | # For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
172 | model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).half().cuda()
173 | ```
174 |
175 | ### CPU Deployment
176 |
177 | If your computer is not equipped with GPU, you can also conduct inference on CPU, but the inference speed is slow (and taking about 32GB of memory):
178 |
179 | ```python
180 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()
181 | ```
182 |
183 | If your CPU memory is limited, you can directly load the quantized model:
184 | ```python
185 | # For INT8-quantized model, change "chatglm-6b-int4" to "chatglm-6b-int8"
186 | model = AutoModel.from_pretrained("THUDM/chatglm-6b-int4", trust_remote_code=True).float()
187 | ```
188 |
189 | If your encounter the error `Could not find module 'nvcuda.dll'` or `RuntimeError: Unknown platform: darwin`(MacOS), please [load the model locally](README_en.md#load-the-model-locally).
190 |
191 | ### GPU Inference on Mac
192 | For Macs (and MacBooks) with Apple Silicon, it is possible to use the MPS backend to run ChatGLM-6B on the GPU. First, you need to refer to Apple's [official instructions](https://developer.apple.com/metal/pytorch) to install PyTorch-Nightly.
193 |
194 | Currently you must [load the model locally](README_en.md#load-the-model-locally) on MacOS. Change the code to load the model from your local path, and use the mps backend:
195 | ```python
196 | model = AutoModel.from_pretrained("your local path", trust_remote_code=True).half().to('mps')
197 | ```
198 | Then you can use GPU-accelerated model inference on Mac.
199 |
200 | ### Multi-GPU Deployment
201 | If you have multiple GPUs, but the memory size of each GPU is not sufficient to accommodate the entire model, you can split the model across multiple GPUs.
202 |
203 | First, install accelerate: `pip install accelerate`, and then load the model using the following method:
204 | ```python
205 | from utils import load_model_on_gpus
206 | model = load_model_on_gpus("THUDM/chatglm-6b", num_gpus=2)
207 | ```
208 |
209 | This will deploy the model onto two GPUs for inference. You can change `num_gpus` to the number of GPUs you want to use. By default, the model is split evenly, but you can also specify the `device_map` parameter to customize the splitting.
210 |
211 | ## Parameter-efficient Tuning
212 | Parameter-efficient tuning based on [P-tuning v2](https://github.com/THUDM/P-tuning-v2). See [ptuning/README.md](ptuning/README.md) for details on how to use it.
213 |
214 | ## Update
215 | **[2023/04/16]** Added INT8 quantized model [ChatGLM-6B-INT8](https://huggingface.co/THUDM/chatglm-6b-int8). Added multi-GPU deployment (thanks to [@Cherrysaber](https://github.com/Cherrysaber)).
216 |
217 | **[2023/04/06]** Improve the web demo interface (thanks to [@tuteng0915](https://github.com/tuteng0915)). Remove the image tokens in the embedding layer to reduce the memory usage (need to update the model files `pytorch_model-00001-of-00008.bin` and `pytorch_model-00008-of-00008.bin`, thanks to [@silverriver](https:/ /github.com/silverriver) for proposing the idea). Removed dependency on `icetk` (need to update model file `ice_text.model`).
218 |
219 | **[2023/03/31]** Added a parameter-efficient tuning implementation based on [P-Tuning-v2](https://github.com/THUDM/P-tuning-v2). The minimum INT4 quantization level only needs 7GB GPU memory is enough for model tuning. See [Parameter-efficient tuning method](ptuning/README.md) for details.
220 |
221 | **[2023/03/23]** Add API deployment, thanks to [@LemonQu-GIT](https://github.com/LemonQu-GIT). Add embedding-quantized model [ChatGLM-6B-INT4-QE](https://huggingface.co/THUDM/chatglm-6b-int4-qe). Add support for GPU inference on Mac with Apple Silicon.
222 |
223 | **[2023/03/19]** Add streaming output function `stream_chat`, already applied in web and CLI demo. Fix Chinese punctuations in output. Add quantized model [ChatGLM-6B-INT4](https://huggingface.co/THUDM/chatglm-6b-int4).
224 |
225 | ## ChatGLM-6B Examples
226 |
227 | The following are some Chinese examples with `web_demo.py`. Welcome to explore more possibility with ChatGLM-6B.
228 |
229 | Self Cognition
230 |
231 | 
232 |
233 |
234 |
235 | Outline
236 |
237 | 
238 |
239 |
240 |
241 | Ad
242 |
243 | 
244 |
245 | 
246 |
247 |
248 |
249 | Email
250 |
251 | 
252 |
253 | 
254 |
255 |
256 |
257 | Information Extraction
258 |
259 | 
260 |
261 |
262 |
263 | Role Play
264 |
265 | 
266 |
267 |
268 |
269 | Comparison
270 |
271 | 
272 |
273 |
274 |
275 | Travel Guide
276 |
277 | 
278 |
279 |
280 |
281 | ## License
282 |
283 | This repository is licensed under the [Apache-2.0 License](LICENSE). The use of ChatGLM-6B model weights is subject to the [Model License](MODEL_LICENSE)。
284 |
285 | ## Citation
286 |
287 | If you find our work useful, please consider citing the following papers:
288 |
289 | ```
290 | @inproceedings{
291 | zeng2023glm-130b,
292 | title={{GLM}-130B: An Open Bilingual Pre-trained Model},
293 | author={Aohan Zeng and Xiao Liu and Zhengxiao Du and Zihan Wang and Hanyu Lai and Ming Ding and Zhuoyi Yang and Yifan Xu and Wendi Zheng and Xiao Xia and Weng Lam Tam and Zixuan Ma and Yufei Xue and Jidong Zhai and Wenguang Chen and Zhiyuan Liu and Peng Zhang and Yuxiao Dong and Jie Tang},
294 | booktitle={The Eleventh International Conference on Learning Representations (ICLR)},
295 | year={2023},
296 | url={https://openreview.net/forum?id=-Aw0rrrPUF}
297 | }
298 | ```
299 |
300 | ```
301 | @inproceedings{du2022glm,
302 | title={GLM: General Language Model Pretraining with Autoregressive Blank Infilling},
303 | author={Du, Zhengxiao and Qian, Yujie and Liu, Xiao and Ding, Ming and Qiu, Jiezhong and Yang, Zhilin and Tang, Jie},
304 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)},
305 | pages={320--335},
306 | year={2022}
307 | }
308 | ```
309 |
--------------------------------------------------------------------------------
/cli_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import signal
4 | from transformers import AutoTokenizer, AutoModel
5 |
6 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
7 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
8 | model = model.eval()
9 |
10 | os_name = platform.system()
11 | clear_command = 'cls' if os_name == 'Windows' else 'clear'
12 | stop_stream = False
13 |
14 |
15 | def build_prompt(history):
16 | prompt = "欢迎使用 audioConversation-ChatGLM 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
17 | for query, response in history:
18 | prompt += f"\n\n用户:{query}"
19 | prompt += f"\n\nChatGLM-6B:{response}"
20 | return prompt
21 |
22 |
23 | def signal_handler(signal, frame):
24 | global stop_stream
25 | stop_stream = True
26 |
27 |
28 | def main():
29 | history = []
30 | global stop_stream
31 | print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
32 | while True:
33 | query = input("\n用户:")
34 | if query.strip() == "stop":
35 | break
36 | if query.strip() == "clear":
37 | history = []
38 | os.system(clear_command)
39 | print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
40 | continue
41 | count = 0
42 | for response, history in model.stream_chat(tokenizer, query, history=history):
43 | if stop_stream:
44 | stop_stream = False
45 | break
46 | else:
47 | count += 1
48 | if count % 8 == 0:
49 | os.system(clear_command)
50 | print(build_prompt(history), flush=True)
51 | signal.signal(signal.SIGINT, signal_handler)
52 | os.system(clear_command)
53 | print(build_prompt(history), flush=True)
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/examples/ad-writing-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/ad-writing-2.png
--------------------------------------------------------------------------------
/examples/blog-outline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/blog-outline.png
--------------------------------------------------------------------------------
/examples/comments-writing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/comments-writing.png
--------------------------------------------------------------------------------
/examples/email-writing-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/email-writing-1.png
--------------------------------------------------------------------------------
/examples/email-writing-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/email-writing-2.png
--------------------------------------------------------------------------------
/examples/information-extraction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/information-extraction.png
--------------------------------------------------------------------------------
/examples/role-play.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/role-play.png
--------------------------------------------------------------------------------
/examples/self-introduction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/self-introduction.png
--------------------------------------------------------------------------------
/examples/sport.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/sport.png
--------------------------------------------------------------------------------
/examples/tour-guide.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/examples/tour-guide.png
--------------------------------------------------------------------------------
/improve/README.md:
--------------------------------------------------------------------------------
1 | # ChatGLM-6B Badcase 反馈计划
2 | ChatGLM-6B 自3月14号发布以来受到了广大开发者和用户的喜爱,截至4月22号 GitHub 的 star 数达到 2 万,各个渠道模型的累计下载量过 100 万,并连续 12 天居 Hugging Face 全球大模型下载榜第一名。 与此同时,有一批基于 ChatGLM-6B 的[优秀开源项目](https://github.com/THUDM/ChatGLM-6B)出现,在各个平台也引起了广泛好评和关注。此外,基于 GLM-130B 的千亿对话模型 ChatGLM 也自3月14号开始了第一阶段的邀请制内测,得到了内测用户的好评和支持。谢谢大家对 ChatGLM 及其 6B 开源版本的大力支持!
3 |
4 | 接下来,我们想邀请大家一起推动 ChatGLM-6B 的进一步提升,一起推动模型的发展。尽管ChatGLM-6B已初具符合人类偏好的问答对话能力,在相当多的指令和问题上,其回答仍存在不理解复杂指令和任务含义,缺乏领域概念理解,事实性错误,生成有害内容,对话上下文不一致等诸多问题。尽管我们提供的[微调代码](https://github.com/THUDM/ChatGLM-6B/tree/main/ptuning)能够让用户通过自主训练修复部分问题,但因为神经网络的[灾难性遗忘](https://picture.iczhiku.com/weixin/message1587593113355.html)问题,微调后的模型往往会失去在通用领域的对话能力或者因数据较少而缺乏泛化能力。为了解决这些问题,进一步提升 ChatGLM-6B 的能力,我们启动了 ChatGLM-6B Badcase 反馈计划。
5 |
6 | 具体来说,对于在使用 ChatGLM-6B 过程中遇到的表现不佳的Badcase对应的具体指令和提问,您可以修改或从头撰写您认为合适的正确答案,并反馈给我们改进 ChatGLM-6B。**请您确保提交的数据不包含任何个人信息、商业秘密或可能危害国家安全、侵害第三方知识产权的内容。** 我们会定期(每2-4周)对数据的有用性与正确性进行筛选,将筛选通过的数据,与通用域的对话数据一起加入到模型训练中,并**更新发布开源的模型参数**。**您提供的数据无论是否筛选通过,除非获得您的许可或根据国家法律规定和监管要求外,我们不会将您提供的数据对外公开。**
7 |
8 | 您提供的数据如被筛选通过,您将同时优先获得最新版本的 ChatGLM-6B 模型的体验资格。此外,如果您愿意,您的用户名还将出现在 ChatGLM-6B Github页面的数据贡献者名单中,以此感谢您对推进大模型开源事业发展的帮助。您的支持和建议将为我们优化大模型提供源源不断的动力,在此向您表达我们最深的敬意与感谢!
9 |
10 | 如果您希望参与反馈,请填写[问卷](https://www.wjx.cn/vm/rAoGx9X.aspx#)并按照具体要求上传。提交的数据为 jsonline 格式,每行的内容为
11 | ```json lines
12 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿x s裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"}
13 | ```
14 | 其中,`prompt` 部分为模型的输入,`response` 部分为期望的模型输出。为了保证模型的性能,请在输入中尽可能详细地表述任务的类型和期望的输出格式。针对某项具体的任务,为了使模型得到充分的训练,一般需要100条左右的训练数据。
15 |
--------------------------------------------------------------------------------
/improve/data_sample.jsonl:
--------------------------------------------------------------------------------
1 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*风格#性感*图案#线条*裤型#阔腿裤", "response": "宽松的阔腿裤这两年真的吸粉不少,明星时尚达人的心头爱。毕竟好穿时尚,谁都能穿出腿长2米的效果宽松的裤腿,当然是遮肉小能手啊。上身随性自然不拘束,面料亲肤舒适贴身体验感棒棒哒。系带部分增加设计看点,还让单品的设计感更强。腿部线条若隐若现的,性感撩人。颜色敲温柔的,与裤子本身所呈现的风格有点反差萌。"}
2 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*图案#条纹*图案#线条*图案#撞色*裙型#鱼尾裙*裙袖长#无袖", "response": "圆形领口修饰脖颈线条,适合各种脸型,耐看有气质。无袖设计,尤显清凉,简约横条纹装饰,使得整身人鱼造型更为生动立体。加之撞色的鱼尾下摆,深邃富有诗意。收腰包臀,修饰女性身体曲线,结合别出心裁的鱼尾裙摆设计,勾勒出自然流畅的身体轮廓,展现了婀娜多姿的迷人姿态。"}
3 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#粉红色*图案#字母*图案#文字*图案#线条*衣样式#卫衣*衣款式#不规则", "response": "宽松的卫衣版型包裹着整个身材,宽大的衣身与身材形成鲜明的对比描绘出纤瘦的身形。下摆与袖口的不规则剪裁设计,彰显出时尚前卫的形态。被剪裁过的样式呈现出布条状自然地垂坠下来,别具有一番设计感。线条分明的字母样式有着花式的外观,棱角分明加上具有少女元气的枣红色十分有年轻活力感。粉红色的衣身把肌肤衬托得很白嫩又健康。"}
4 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*材质#雪纺*风格#清新*裙型#a字*裙长#连衣裙", "response": "踩着轻盈的步伐享受在午后的和煦风中,让放松与惬意感为你免去一身的压力与束缚,仿佛要将灵魂也寄托在随风摇曳的雪纺连衣裙上,吐露出微妙而又浪漫的清新之意。宽松的a字版型除了能够带来足够的空间,也能以上窄下宽的方式强化立体层次,携带出自然优雅的曼妙体验。"}
5 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#棉*颜色#蓝色*风格#潮*衣样式#polo*衣领型#polo领*衣袖长#短袖*衣款式#拼接", "response": "想要在人群中脱颖而出吗?那么最适合您的莫过于这款polo衫短袖,采用了经典的polo领口和柔软纯棉面料,让您紧跟时尚潮流。再配合上潮流的蓝色拼接设计,使您的风格更加出众。就算单从选料上来说,这款polo衫的颜色沉稳经典,是这个季度十分受大众喜爱的风格了,而且兼具舒适感和时尚感。"}
6 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#h*材质#蚕丝*风格#复古*图案#条纹*图案#复古*图案#撞色*衣样式#衬衫*衣领型#小立领", "response": "小女人十足的条纹衬衣,缎面一点点的复古,还有蓝绿色这种高级气质复古色,真丝材质,撞色竖条纹特别的现代感味道,直h型的裁剪和特别的衣长款式,更加独立性格。双层小立领,更显脸型。"}
7 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#网纱*颜色#粉红色*图案#线条*图案#刺绣*裙腰型#高腰*裙长#连衣裙*裙袖长#短袖*裙领型#圆领", "response": "这款连衣裙,由上到下都透出一丝迷人诱惑的女性魅力,经典圆领型,开口度恰好,露出你的迷人修长的脖颈线条,很是优雅气质,短袖设计,在这款上竟是撩人美貌,高腰线,散开的裙摆,到小腿的长度,遮住了腿部粗的部分,对身材有很好的修饰作用,穿起来很女神;裙身粉红色花枝重工刺绣,让人一眼难忘!而且在这种网纱面料上做繁复图案的绣花,是很考验工艺的,对机器的要求会更高,更加凸显我们的高品质做工;"}
8 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#文字*图案#印花*衣样式#卫衣", "response": "一款非常简洁大方的纯色卫衣,设计点在于胸前的“”的中文字印花,新颖特别,让人眼前一亮。简单又吸睛的款式,而且不失时髦感,很适合个性年轻人。"}
9 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#黑色*颜色#灰色*颜色#姜黄色*风格#休闲*图案#线条*图案#撞色*衣样式#毛衣*衣袖型#落肩袖", "response": "看惯了灰色的冷淡和黑色的沉闷感,来一点醒目的彩色增添点活力吧。亮眼又吸睛的姜黄色色调,嫩肤显白非常的有设计感。趣味的撞色和宽松的版型相交辉映,修饰身形小缺点的同时,时尚又百搭。优雅的落肩袖,轻松修饰肩部线条,让毛衣上身凸显出一丝慵懒随性的休闲感,时尚魅力尽显。"}
10 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*风格#潮*图案#印花*图案#撞色*衣样式#衬衫*衣领型#圆领*衣长#中长款*衣长#常规*衣袖长#无袖", "response": "黑与白,两种最极端的颜色却轻松搭配成了经典,就像此款衬衣,无需过多装饰,仅色调就足够醒目个性,受潮所喜欢。做了无袖中长款的样式,走路带风的感觉着实不错,圆领的设计,不是常规的衬衫领,少了点正式反而有种休闲感觉,适合孩子们穿着。后背大面积撞色印花装点,是时尚潮流的象征,也让衣衣不至于单调,轻松就能穿出彩。"}
11 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#街头*风格#休闲*风格#朋克*图案#字母*图案#文字*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#对称", "response": "个性休闲风的连帽卫衣造型时髦大方,宽松的版型剪裁让肉肉的小宝贝也可以穿着,保暖的连帽设计时刻给予宝贝温柔的呵护,袖子和后背别致时髦的字母印花点缀,满满的街头元素融入,演绎休闲朋克风,对称的小口袋美观大方,方便放置更多的随身物品。"}
12 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#链条", "response": "简单大气的设计,不费吹灰之力就能搭配的时髦范儿。时尚的配色一点都不觉得平淡了,有种浑然天成的大气感。强调了整体的装饰,和谐又不失个性,搭配裤装帅气十足,搭配裙子精致优雅。链条和肩带的搭配让使用感更加舒服,单肩手提都好看。"}
13 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#显瘦*材质#牛仔布*颜色#深蓝色*风格#复古*图案#复古*图案#线条*裤腰型#高腰*裤口#微喇裤", "response": "深蓝色的高腰牛仔裤,修身的款式勾勒出纤细的美腿。牛仔裤的裤脚设计张开的喇叭型,巧妙地修饰了小腿的线条,洋溢着复古的年代感。"}
14 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#潮*风格#性感*图案#条纹*图案#蝴蝶结*衣样式#衬衫*衣领型#一字领*衣门襟#系带*衣款式#不对称", "response": "这是一件显得特别清新的衬衣,采用了条纹的设计,给予人一种甜美可人的气质。并且融合了别致的斜肩一字领设计,高调的展示出性感的锁骨,将迷人的香肩展现在外,性感中不失去清纯的气息。袖口处的蝴蝶结系带装饰,增添了俏皮的韵味,简洁大方。且在下摆处采用了不对称的设计,增强了视觉效果,更显潮流。"}
15 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#复古*图案#复古*裤型#直筒裤*裤款式#纽扣*裤腰型#高腰", "response": "作为基础款单品,牛仔裤也,想要呈现给大家的是——每次搭配都有新感觉。裤子经过复古做旧处理,风格鲜明,也很注重细节,连纽扣也做了统一的做旧处理,融入个性十足的磨破设计,高腰直筒basic裤型,修饰身材,穿出高挑长腿。"}
16 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*图案#线条*图案#刺绣*衣样式#针织衫*衣领型#v领", "response": "一款温暖柔软又富有弹性的针织衫,不仅可以抵御严寒侵袭,还能更好地进行搭配。v领的设计,能勾勒出迷人的天鹅颈以及衬托出娇小的脸型。宽松又别致的剪裁,能从视觉上显露纤长的下半身,起到显瘦的效果。直筒造型的袖子,修饰出优美的手臂线条,衣身上的方格刺绣,时尚又吸睛。"}
17 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#绿色*风格#清新*图案#线条*衣样式#衬衫*衣领型#翻领", "response": "绿色的衣身上镶嵌着,就是这款衬衫最大的迷人之处,“红花配绿叶”般的色调,将清新气息阐述的淋漓尽致。经典的翻领更是贴心,修饰颈部线条的同时,尽显精致干练的气质,出街轻松凹造型。"}
18 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*图案#字母*图案#文字*图案#印花*图案#撞色*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "这款外套采用了撞色拉链织带以及字母印花设计。这两种元素的融入使外套不会显得过于单调沉闷,吸睛而亮眼,充满年轻与朝气感,非常减龄。"}
19 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*版型#h*风格#复古*图案#复古*图案#刺绣*裙长#连衣裙*裙袖长#长袖*裙领型#翻领*裙衣门襟#单排扣", "response": "本款连衣裙整体采用h型的轮廓设计,藏肉显瘦,不挑身材,适合各种身形的人穿着。小翻领的领口设计,使得本款连衣裙穿在身上看起来十分的精神帅气,具有青春活力。单排扣的衣门襟设计,又给本款连衣裙带来了一丝的复古味道。裙身上的刺绣花朵装饰,使得本款连衣裙不显得单调,富有层次感,上身给人一种独特的时尚魅力。长袖的设计,更加的贴合手臂曲线,上身更加的舒适贴身。"}
20 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#粉色*风格#清新*衣样式#外套*衣样式#西装*衣门襟#双排扣", "response": "这款外套设计成西装的版型,彰显经典优雅的气质,结合了粉色又添清新气息,甜美百搭时尚感满满。利落的版型简洁流畅,亮色双排扣更添精致感。"}
21 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#线条*衣样式#风衣*衣样式#外套*衣门襟#拉链*衣款式#拉链*衣款式#松紧带*衣款式#连帽*衣款式#收腰", "response": "选自品牌江南布衣的一款女士长风衣外套,选用轻薄的,穿着灵活毫无压力。直筒版型简洁利落,长过膝盖的长度穿着个性十足,连帽宽大有型,富有活力,拉链开合,拉上拉链有一丝酷劲,敞开穿则更休闲,连帽领翻开修饰颈部线条。松紧带收腰设计,低调的分割上下比例,打造显高小心机。"}
22 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#街头*风格#简约*图案#刺绣*裤长#短裤*裤款式#钉珠*裤口#毛边", "response": "又到了光腿穿裙子和短裤的时候了,BRAND的这款短裤,采用柔软透气的纯棉牛仔面料,穿着舒适无束缚感。而简约的版型加入了精美的刺绣和钉珠装饰,提升了整体的品质感,显得精美而又立体饱满。搭配下摆的毛边装饰,散发出不羁的街头感。"}
23 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*颜色#黑色*图案#条纹*衣样式#衬衫*衣领型#翻领*衣门襟#系带*衣款式#拼接*衣款式#露肩", "response": "一款老鹰图案露肩衬衫,露肩系带的设计,少女感十足。老鹰图案的设计,更添几分趣味感。条纹面料和牛仔面料的拼接设计,给人一种风度的层次感。小翻领的设计十分的精致,搭配一件黑色打底裤也吸晴万分。"}
24 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*裙型#百褶*裙长#半身裙*裙款式#拼接*裙款式#腰带", "response": "一款颇有设计感的半身裙,单侧雪纺百褶的拼接设计,规整排列的层次感带来立体效果,增加了裙身的廓形,行走间更是带来柔美的灵动气息,轻而易举穿出优雅的轻熟风,呈现十足的女人味来。同面料延伸处理的半固定腰带,可以自然的垂落下来,也算是为整体打造造型亮点,彰显你独特的时尚品味,迎合早春对轻盈的追求。"}
25 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#性感*裙型#包臀裙*裙型#鱼尾裙", "response": "修身包臀版型结合性感鱼尾裙摆设计,彰显婉约优雅风情之余,为整体注入几分俏皮灵动气息。且下摆辅以律动感摺裥元素,更烘托出女性浪漫精致的一面。"}
26 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*颜色#绿色*图案#线条*裙长#连衣裙*裙领型#v领*裙款式#勾花镂空", "response": "连衣裙可以让你在旋转与跳跃间,都散发出万种风情,受到了万千女性的喜爱。这款连衣裙选用绿色调,既散发出活力气息,又增添了高雅的气质。而镂空的钩花设计,则为其增添了浪漫的风情,同时更显美观与时尚。再加上v领的设计,不仅映衬出精致的脸颊,还打造出优美的颈部线条。"}
27 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#淑女*衣样式#毛衣*衣领型#高领", "response": "高领毛衣一直是网红妹子,因为穿着它有一种淑女甜美气质。它最大的亮点在于它的高领设计和花边装饰。在淑女干练的气质基础上又增加了一些少女的甜美气息,穿着非常有型,最佳搭配小白鞋。"}
28 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*裤长#九分裤*裤型#阔腿裤*裤款式#拉链*裤腰型#高腰*裤口#开叉", "response": "九分裤长,把妹子的拉长了腿的比例,配合高腰设计,瞬间显得妹子的腿长了很多,一下子自信满满啦。采用侧面隐藏拉链设计,穿脱方便又舒适。设计感十足的开叉裤脚,身上的摩登浓了。这个春天妹子的腿型,就交给阔腿裤啦。"}
29 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣款式#纽扣", "response": "经典的套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"}
30 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#黑白*风格#复古*风格#文艺*图案#格子*图案#复古*衣样式#衬衫*衣领型#翻领*衣门襟#套头*衣款式#纽扣", "response": "套头翻领衬衫与黑白格纹元素组合,一直以来的气场经久不衰。而采用复古精致的纽扣装点的半门襟设计,简单的小细节处理,彰显出浓浓的文艺气息。"}
31 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#卡通*裙长#连衣裙", "response": "传奇而又经典的卡通形象,米老鼠似乎已经成为了孩童风格的一种标志,大小不一的头像以及奇趣的表情设计。满版的点缀风格让整个连衣裙洋溢着独特的天真气质,加之面料小口袋的点缀,小小的造型呈现出灵巧而又可爱的格调,让宝贝俏皮萌动。"}
32 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#条纹*图案#刺绣*裙型#背带裙*裙下摆#毛边", "response": "假两件版型的设计,给人一种错觉,大大增添自身时髦感。毛边裙摆的采用,看起来活力十足。设计师解决了以往穿脱不方便的问题,应用的可调节背带设计,非常的人性化。裙子上的花朵刺绣图案,看起来也栩栩如生,同时也展示出了精湛的做工手艺。为了与女人自身清纯的一面形成呼应,应用的条纹图案非常完美。"}
33 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#街头*风格#休闲*裤长#五分裤*裤腰型#松紧腰", "response": "这款休闲五分裤,采用亲肤软牛仔,洗水磨白形成深浅对比,更加个性。大弹力松紧腰,舒适贴合,一点都不紧勒。裤子门襟时尚的设计,为细节加分。立体双贴袋,腰间系带的点缀更吸睛。精致的裁剪,或是干练整洁的走线和宽松版型,是对街头的描写。"}
34 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*图案#条纹*图案#蕾丝*衣样式#衬衫*衣领型#立领*衣款式#荷叶边", "response": "条纹衬衫是引领时尚圈的常青树,尤其给人舒适感官享受的蓝白条纹,更是深得时尚icon的喜爱。加之搭载经典的立领秀出纤长的玉颈,更显气质优雅。肩膀上饰有薄薄的蕾丝,打破了条纹衬衫的干练,更添别样风情。荷叶边的蔓延更显气质甜美,自然吸睛无数。"}
35 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*风格#休闲*图案#线条*裙型#牛仔裙*裙长#半身裙", "response": "牛仔半身裙作为时尚宠儿,一直被很多潮人捧在手心,zui妙的莫过于它的时髦感以及百搭易驾驭的属性。裙身设计了自然的腰型,拉长腿部线条,让小仙女们感受到大长腿。以及两侧插袋的造型,显得比较随意,休闲的感觉。"}
36 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#刺绣*衣样式#卫衣*衣袖型#罗纹袖口", "response": "乍一看很平凡的纯色系卫衣,暗藏的刺绣玄机,就足够把时髦的张力表现得不凡。很有包容性的廓形,舒适的罗纹收口,宽大的样子却依旧既定的风格,让你的潇洒随性表现得收放自如。呆萌查理的袖间刺绣,极简的漫画笔触巧塑生动有趣的风格。"}
37 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*衣样式#衬衫*衣样式#风衣*衣款式#抽绳", "response": "风衣带有一种随性大气的感觉,在春风拂面的日子里能衬托出你的气质。草绿色的衣身配色,与与自身清纯干净的性格形成了呼应。具有一定实用性的下摆抽绳,可以让你任意的变换风格。抛弃了衬衫领的设计应用的设计,更能将你帅气的一面展示出来。"}
38 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"}
39 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "优质挺括的面料。包臀版型,长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚。"}
40 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#职场*图案#线条*裙型#包臀裙*裙下摆#开叉", "response": "长度在膝盖往下一点,显得利落而大方,与生俱来的气场感。适合职场女性,包臀裙的优势在于凸显腰身线条,侧边开叉的设计不仅让整体造型更具曾层次感,也增添一份恰到好处的妩媚"}
41 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*风格#复古*风格#清新*图案#条纹*图案#复古*衣样式#针织衫*衣样式#开衫*衣长#常规*衣款式#拼接*衣款式#纽扣*衣款式#罗纹", "response": "慵懒气质的针织开衫,充满了复古的情调,奶奶级的麻花编织手法,充满立体感的同时保暖效果也是满分。下摆的罗纹拼接,让针织衫回暖性更棒。活泼的条纹拼接,跳脱出常规配色,清新色调的选用,更加衬托出肌肤的雪白。精致的纽扣点缀,反光的质感让针织衫充满现代感。"}
42 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#复古*图案#蝴蝶结*图案#复古*图案#波点*衣样式#衬衫*衣领型#立领*衣门襟#系带*衣款式#木耳", "response": "【说】衬衫,大波点气质复古从立领上延伸的长系带,可轻松绑成蝴蝶结,甜美感加分采用打缆工艺的松紧袖口边边处的木耳很可爱"}
43 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#简约*风格#青春*图案#字母*图案#文字*裙型#网纱裙*裙袖长#无袖*裙领型#圆领", "response": "大气的圆领舒适贴合,彰显出女孩儿精神的气质。无袖的款式与圆领相迎合,简约的同时又不失时尚风采。前身由可爱蝴蝶图案点缀,亮丽的字母映衬其上,诉说着一丝精美感。橙色网纱裙摆造型优雅唯美,与上身的图案相呼应,十分富有青春的气息,伴随着步伐的行走间,带出一丝别致浪漫的风情。"}
44 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#丝绒*风格#复古*图案#复古*衣样式#雪纺衫*衣袖型#喇叭袖*衣款式#木耳边*衣款式#飘带*衣款式#荷叶边", "response": "这款雪纺衫,采用具有复古韵味的荷叶边元素,加上丝绒质感的加长飘带,洋溢着浪漫古典的韵味。两侧镶有包扣,和立体木耳边装饰,大大提升时髦指数。而流线型喇叭袖设计,充满灵动质感,为造型平添活力。"}
45 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*版型#显瘦*裙下摆#荷叶边*裙腰型#高腰*裙长#半身裙", "response": "很简洁百搭的一款半裙,裙身荷叶边设计,飘逸灵动,上身更显层次感丰富。高腰造型,版型优良,衬显修长双腿。裙子做的比较宽松,包容性敲好,遮肉效果棒棒的。非常的显瘦哦,选用精品梭织面料,垂感好,肌理细致,上身敲舒服哟。"}
46 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#青春*风格#性感*图案#线条*裙下摆#开叉*裙长#连衣裙*裙领型#翻领*裙款式#腰带*裙款式#衬衫式", "response": "设计师以衬衫式的创作灵感,巧妙地搬运到连衣裙身上,中性又不失性感;时尚小翻领设计,巧妙衬托颈部线条,彰显青春派的艺术时尚,小资派的精彩演绎。耳目一新的双腰带设计,既突出了腰线又感觉很前卫;下面走心的大开叉设计,更能激发人的好奇心,营造出无人超越的高级性感,只需一眼就令人。"}
47 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#宽松*风格#性感*图案#印花*裙下摆#荷叶边*裙长#连衣裙*裙袖型#灯笼袖", "response": "这款连衣裙走的是性感大方的风格路线,展现出你的大大咧咧的性情,非常的有趣。选用了宽松的版型,配合星空印花的图案,塑造出新颖有趣,不失活力四射的印象感。荷叶边的裙摆设计,突显出飘逸性感的一面。配合灯笼袖的袖型细节,体现出的一面。"}
48 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*材质#水洗*颜色#浅色*风格#休闲*风格#性感*图案#线条*裙型#牛仔裙*裙型#直筒裙*裙下摆#开叉*裙下摆#毛边*裙腰型#高腰", "response": "浅色水洗效果牛仔裙,高腰设计融合修身直筒廓形,凸显纤细腰部和迷人翘臀,美化勾勒性感身材曲线。正面开叉细节有效拉长腿部线条,灵动性感。磨毛边下摆设计,带来休闲随性气息。"}
49 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#休闲*图案#条纹*图案#印花*衣样式#卫衣*衣款式#连帽*衣款式#罗纹", "response": "这款连帽卫衣自带休闲魅力,将杜嘉班纳的品牌标志以印花的形式装饰在衣身前幅,展现出华丽不失看点的视觉效果,每时每刻都在彰显不凡品味。罗纹条纹袖口和下摆,不仅能使卫衣更帅气惹眼,还能为整体增加一股前卫之风。"}
50 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#简约*图案#卡通*图案#蝴蝶结*图案#印花*衣样式#衬衫", "response": "大面积的卡通兔子印花,童趣满满,再加上领口的蝴蝶结装饰织带。充满童趣的同时又不失小女生的甜美气息,相当减龄。这款衬衫选用真丝面料,真丝面料不仅轻薄,而且柔滑、亲肤,就好像人的第二层肌肤般带给你清凉舒适的穿着感觉。合身的版型,裁剪得干净利落,简约又不失时尚气息,打造干练的气场。这款衬衫日常十分百搭,不仅可以与其他服饰搭配,作为一件单品也十分出彩。"}
51 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*材质#水洗*风格#复古*风格#简约*图案#复古*图案#线条*裤长#九分裤*裤款式#不规则*裤口#毛边", "response": "misssixty的这款单品延续经典的九分牛仔裤版型,结合贴合身形的剪裁,展现出柔美修长的腿部线条;不同的位置做了不同程度的水洗复古工艺,使得裤身更加立体厚重;此外,裤脚处采用了微微不规则的毛边剪裁,为简约的整体注入一丝随性之感;再加上相互呼应的翅膀状图案点缀,瞬间带来一丝浪漫唯美的味道。"}
52 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#纯色*图案#纯色*图案#线条*衣样式#卫衣*衣领型#圆领*衣袖型#收口*衣门襟#套头*衣款式#螺纹", "response": "使用经典的螺纹圆领来展开设计,将衣型打造成套头卫衣的款式,穿着时轻松收口,将颈部线条修饰出挺拔优美的的效果,让穿着更加具有精气神。衣身以纯色作为主色调,配上经典的小企鹅logo,将正面点缀,它拥有一个俏皮的小蝴蝶领结,充满细节感,使得衣身吸睛耀眼。"}
53 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*裤长#九分裤*裤型#直筒裤", "response": "c小小的这样一条迷人的牛仔裤彰显出你的大气个性,它的别致直筒版型十分的高端迷人,让你吸睛十足。个性九分的版型展示出你的迷人小脚踝。它的大气牛仔材质,十分的舒适洒脱,迷人更有型。"}
54 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*颜色#纯色*风格#简约*图案#纯色*图案#线条*图案#蕾丝*衣样式#衬衫*衣领型#v领", "response": "一款简约的纯色衬衫,采用了个性的大v领,露出柔美的锁骨和颈部线条,散发出清爽迷人的气质;点缀精美的蕾丝花边装饰,波浪形的花边很有美感,增加了视觉亮点。"}
55 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#撞色*裙下摆#垂坠*裙长#连衣裙*裙袖长#无袖*裙袖型#收口*裙款式#拼接*裙款式#绑带*裙款式#波浪", "response": "来自奥芝国的推出的无袖连衣裙,精选弹力冰丝材质穿柔软垂坠性很好,适合春夏秋三季穿搭。腰部的撞色波浪纹弹力腰封拼接,并以交叉绑带式收口,修饰腰身轻松穿人大长腿。"}
56 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#针织*颜色#灰色*颜色#深蓝色*图案#线条*衣样式#毛衣", "response": "这是一款专为胖孩子设计的针织毛衣,加肥加大的立体版型,利落有型穿着合体不臃肿,穿着更加帅气显瘦;领口、袖口和下摆收紧处理使衣衣更加利落有型,久穿久洗也不易磨损和变形,颇具品质感;深蓝色的大身巧妙地加入一些灰色线条修饰活泼大方,孩子穿上它,洋溢着青春活力。"}
57 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#牛仔布*材质#网纱*风格#街头*衣样式#衬衫*衣款式#拼接*衣款式#勾花镂空*衣款式#钉珠", "response": "时髦又帅气的牛仔拼接裙,利用多材质拼接演绎刚柔并济的设计。硬朗的牛仔衬衫以镂空拼接,构造出深浅的色系变化,加上钉珠铆钉的装饰,更是玩味出十足的街头帅气。下身拼接的网纱半裙,层次细腻又丰富,两侧加入牛仔插袋呼应上身面料,带来一体感设计。"}
58 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#牛仔布*颜色#蓝色*颜色#浅蓝色*风格#性感*裙型#牛仔裙*裙型#包臀裙*裙下摆#开叉*裙款式#拼接*裙款式#纽扣", "response": "mm们拼接风呢?这款牛仔裤是非常有趣的拼接风,浅蓝色和原蓝色的牛仔拼接在一起,非常吸引眼球。在左侧的裙摆处还做了开叉设计,微微露出腿部皮肤,展现性感姿态。包臀的设计,凸显圆润的臀部。前幅一排金属纽扣,增添细节感和精致度。喜欢的mm千万不要错过~"}
59 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#蕾丝*颜色#粉色*风格#清新*图案#碎花*图案#线条*图案#蕾丝*裙型#a字*裙下摆#花边*裙领型#圆领*裙款式#飘带", "response": "清新的小碎花缀满衣身,以淡雅的粉色调为底色,焕发出甜美温婉的少女气息。简洁的圆领设计,柔化脸部线条,加上蕾丝飘带点缀,更显娇俏减龄。下摆蕾丝花边分割裙裾,转身间将浪漫挥洒。散开的a字裙摆,恰到好处遮住了臀部和腿部粗的部分,有很好的修饰作用。"}
60 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#淑女*图案#植物*图案#印花*裙型#百褶*裙长#连衣裙*裙领型#娃娃领*裙款式#拼接*裙款式#腰带", "response": "法式浪漫情怀,由这款印花连衣裙为你抒写。蝴蝶花卉印花铺陈裙身,蝴蝶翩跹BRAND花丛,浪漫迷人美如画,法式风情呼之欲出。娃娃领的设计,凸显一身柔美的淑女气质。裙摆百褶的设计,飞舞更添灵动飘逸的美。腰带拼接的设计,完美打造显瘦显高的身材比例。"}
61 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*风格#复古*风格#简约*风格#休闲*图案#复古*图案#线条*图案#印花*裙长#连衣裙", "response": "这一款雪纺连衣裙复古的小立领带来不一样的惊喜,不仅拉伸了脖颈的线条,同时衬托出娇小的脸型。衣身大大的印花很有质感,简约休闲中透露着复古精致的美丽。"}
62 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#文艺*风格#简约*风格#清新*衣样式#外套*衣门襟#拉链*衣款式#拉链", "response": "飘飘落落,暖色的布料上纷纷落落的铺着羽毛,灰常有意境的一款连衣裙。羽毛是这款连衣裙最大的亮点,色彩也丰富饱满,凸显的文艺感也灰常强烈,满满的文艺清新气息;简约大方的设计,有种不喧嚣的热烈感;凸显内敛的气质。搭大衣、棉服外套不仅保暖又灰常的有韵味,而且这款不仅做了开扣的设计,还做了隐形的小拉链!是可哺乳的款式,方便孕后哺乳穿,墙裂推荐!"}
63 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#简约*风格#青春*风格#清新*风格#性感*图案#条纹*图案#撞色*裙下摆#开叉*裙长#连衣裙*裙款式#拼接*裙款式#吊带", "response": "这款针织吊带连衣裙展现青春时尚的格调,双侧撞色条纹的拼接简约经典,散发出清新爽朗的气息,显得格外惹眼,营造出明媚动人的视觉吸引力。赋予简约的吊带裙满满的活力,开叉的剪裁性感别致,充满小女人的韵味。"}
64 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#街头*风格#青春*衣样式#t恤*衣领型#圆领", "response": "三叶草的这款体恤面料比较舒适,穿起来也能很好的透气排汗。整体的设计风格就是经典的款式,所以说是街头常年流行的必备。圆领的领口设计在穿脱时起到了方便。同时修饰脸部轮廓,更显小脸。三叶草的标志也是最为独特的品牌标识,穿出了个人的品味。"}
65 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#白色*风格#休闲*风格#清新*衣样式#外套*衣款式#连帽", "response": "春天家以清新白色为主基调打造的这款外套,整体采用了直筒的极简剪裁配合休闲感的连帽设计,穿着在身上的舒适度较高。设计师为这款上衣做了口的袖子和下摆的处理,穿着后对于身形的修饰效果会更为出众,显得较为得体、大方。"}
66 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#军绿色*风格#复古*风格#文艺*风格#知性*风格#休闲*风格#潮*图案#复古*图案#撞色*衣样式#外套*衣样式#西装*衣领型#西装领*衣长#短款*衣袖型#插肩袖", "response": "短款西装小外套,结合了知性和休闲两种风格,在现代的潮流款式中又融入了淡淡的复古韵味。端庄典雅的军绿色衣身,带着自由舒畅的旅行感,款式上选用利落率性色西装领,宽松闲适的插肩袖,门襟选用撞色的两粒扣设计,复古文艺又简洁随性。"}
67 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙款式#收腰", "response": "亮眼的橙红色展示出迎面而来的热情感,衬托肤色白皙红润,在宴会上气场十足。方形的镂空点缀着衣领下方,增加看点散发出小女人的妩媚感。独特的衣袖造型倾斜而下,修饰手臂线条非常修长,在举手投足间优雅又大气。收腰的版型设计修饰腰部线条更纤细,打褶的裙摆在行走时灵动十足,仿佛的精灵一般。"}
68 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*图案#线条*裙款式#勾花镂空*裙款式#收腰", "response": "裙子表面的镂空花网就使其充满了很强的设计美感,首先是肩部将落肩袖和镂空图案相结合,白皙的肌肤隐隐约约,而且能够很好的缩小肩宽比例。v型领口修饰拉长颈部线条和显得脸小。裙子做了收腰裁剪,并将腰线提高,轻松拉长下半身身材比例,裙摆也更加挺括,从而能够解决胯宽等身材烦恼。"}
69 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*材质#涤纶*衣样式#风衣*衣袖型#灯笼袖*衣款式#纽扣*衣款式#飘带", "response": "风衣在摒弃了传统的版型样式,将袖子设计成花苞型的灯笼袖,与春天搭配得恰到好处。并在袖子处装饰了四颗纽扣,采用飘带作为松紧调节,增添层次感更显个性别致。除此之外,风衣采用涤纶材质制成,垂顺感好挺括修身,结合小a字形轮廓,更显身形高挑秀美,并且让矮小个的女性也能撑起风衣的气场。"}
70 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#英伦*风格#简约*风格#休闲*图案#格子*图案#线条*衣样式#西装*衣领型#翻领*衣门襟#双排扣", "response": "这一款休闲西装简约利落的翻领,可以很好地修饰脸型和颈部线条,显脸小的同时又让脖子看上去更纤细。加上精致的格纹装饰,视觉美丽凸显英伦风。而且双排扣设计,时尚大气美观实用。"}
71 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#针织*风格#文艺*风格#休闲*风格#性感*裙长#半身裙*裙长#连衣裙*裙款式#拼接", "response": "连衣裙的灵感来自于匠人穿着围裙的状态,设计师将针织上衣与半裙结合,整体松软舒适,且不失休闲随性感。裙摆不同材质的拼接,带来丰富的层次细节,让时髦度倍增。偏暗调的配色融入文艺田园气息,显随性姿态。"}
72 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#复古*图案#复古*裙下摆#荷叶边*裙长#连衣裙", "response": "对于女孩子来说,喜爱连衣裙是与生俱来的!几乎没有问题是一条裙纸的~BRAND这款裙子整体的设计有点小复古的感觉,而且艳丽的枣红色也是复古色的代表,上身穿着十分衬肤显白哦。个性而时髦的挂脖式领口露出锁骨很是撩人,另外领口至腰间的衣身前片还加入了很有灵动感的荷叶边作为点缀,瞬间点亮了整体的造型感,由内而外散发的优雅而温柔的气质无人能挡。"}
73 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*风格#日系*风格#简约*图案#线条", "response": "很喜欢这款简单却不简约的时尚牛仔裤,在夏天可以穿出个性与时尚。整个风格比较偏向于日系的身体,任何妹子都能够轻松驾驭,最重要的是版型。穿上特别修饰腿部的线条,打造出了高挑的身材,让你看起来非常有自信的呢,这手工的工艺凸显出了无限的高级质感。"}
74 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#街头*风格#潮*裙型#a字", "response": "孕期就一定要穿的沉闷单调吗?热爱潮流的怎能束缚自己个性的心呢,这款裙子采用a字型设计,让你搭配更为轻松随意,飘逸的撞色织带设计,即刻将原本沉闷的空气也带动的活跃起来。从街头到,尽显潮流个性时尚。"}
75 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#牛仔布*颜色#浅蓝色*风格#街头*风格#休闲*裤型#直筒裤*裤款式#破洞", "response": "破洞元素已变成彰显个性的元素,这款浅蓝色牛仔裤糅合磨白磨破设计,弥漫摩登个性格调,而且破洞设计,打破裤装闷热形象,休闲时髦;直筒款巧妙糅合酷帅感与时髦感,塑造街头潮人印象。"}
76 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#雪纺*风格#知性*风格#性感*图案#线条*裤长#连体裤*裤款式#木耳边", "response": "雪纺面料的一袭连体裤,舒适的手感,轻盈的穿着,宽松的版型,让上身穿着没有束缚感。一字肩的设计,木耳的花边,显露颈部柔美的线条,与性感的锁骨,展现女性知性的一面,木耳花边的设计,显露穿着的甜美感,与少女味。高收腰的设计,拉伸腰部的曲线,提高腰线,显露穿着高挑的身姿。"}
77 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#简约*图案#线条*裤款式#口袋*裤款式#拉链", "response": "侧缝处添置有立体拉链口袋作为装饰,实用性强且兼备美观性。净色的大体外观,简约低调,大方得体,易于搭配。裤腰处植入张弛有度的弹性带,贴合腰部,适合于大多数人穿着。衣身剪裁干净利落,线条流畅。"}
78 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*颜色#白色*图案#条纹*图案#线条*衣样式#衬衫", "response": "白色的衬衫采用了百褶的袖子设计,既修饰了手臂线条,又为整体增强了设计感。背带裤是永不过时的条纹款式,加上阔腿裤的设计,更显女性身材。"}
79 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#棉*材质#牛仔布*风格#简约*风格#休闲*裤长#短裤*裤款式#破洞", "response": "选用优质的纯棉面料打造出舒适的质感,而且上身不会扎身。同时,个性破洞细节设计,增加了牛仔短裤的细节感和吸睛度。此外,简约好搭的配色,柔和你的棱角,让你看起来温柔又平易近人。适合约会等休闲场景,是你衣柜里不可或缺的时髦单品之一。"}
80 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#水洗*风格#潮*裤款式#不规则*裤口#毛边", "response": "年轻潮流的设计品味,洋气又好穿。细节相当丰富有看点,融入水洗磨白,使其充满时尚不羁的气息。裤脚前后毛边处理,配上不规则脚口,更添青春活力。"}
81 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*颜色#蓝色*风格#简约*裤型#背带裤*裤款式#纽扣", "response": "背带裤的选用天蓝色的主题,远远看上去就像是蓝色悬挂在活跃孩子的身上。简约的背带设计,可随时拆开的纽扣,让稚嫩孩子穿衣时不费吹灰之力。腰部更是搭配弹性材料缝制的腰带,不仅方便穿戴而且完美的起到了修饰作用。后背交叉背带,更是独特新颖的处理,更好更牢固的穿搭,不易滑落。"}
82 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*材质#涤纶*裤款式#拼接*裤款式#口袋", "response": "前置的口袋盖拼接,为本来单调的设计布局增加了亮点,更突出了裤子的个性化特点。加上精致的涤纶梭织面料制作,具备更加亲肤不刺激的丝质般触感,给你带来更加柔软舒适的穿着体验。其良好的透气性,有效提升了裤子的吸湿排汗性能,为你提供更加清爽舒适的体感。"}
83 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙下摆#荷叶边*裙领型#圆领*裙袖型#收口*裙款式#螺纹", "response": "此款上衣采用了经典的圆领款式打造,贴合舒适并能修饰出完美的脸型。同时螺纹的收口贴合身材更完美,在前身处采用了可爱的小狮子造型,带真的感觉,而狮子的毛发更是立体精致,显得真实又有丰富的层次。裙身的下摆处采用了荷叶边的设计,俏皮活泼更可爱。"}
84 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*版型#显瘦*材质#网纱*风格#青春*图案#印花*衣样式#衬衫*衣领型#v领*衣款式#拼接", "response": "这一款衬衫交叠v领的设计,修饰脖颈尽显女人味,宽松的廓形,穿上非常轻松有范毫不拘束,并很好的遮盖身材,非常显瘦。时尚的网纱拼接,自然美感特别出彩。精致印花,青春减龄特别活力。"}
85 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#运动*风格#性感*衣样式#西装*衣领型#一字领*衣款式#荷叶边", "response": "荷叶边能够表达出女性的优雅,BRAND的这款上衣,将荷叶边很好地运动起来。性感的一字肩设计,荷叶边从一侧手臂的手肘从前胸绕到另一侧,有着前短后长的感觉,自然垂坠很有层次感,举手投足之间,灵动而优雅。西装袖很好地融合,优雅之中透着小帅气。"}
86 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#运动*风格#休闲*风格#青春*图案#字母*图案#形状*图案#文字*图案#刺绣*图案#撞色*衣样式#卫衣*衣袖型#收口*衣款式#抽绳*衣款式#连帽", "response": "这款dolce&gabbana的连帽卫衣,撞色的字母加上桃心形状的刺绣图案令人耳目一新,举手投足间散发阳光活力少女的青春气息;连帽款式尽显帅气利落风范,细节上采用抽绳处理实用又美观,洋溢满满的运动休闲范儿;加之袖口处的收口设计别出心裁,宽松的衣身烘托出慵懒率性的格调。"}
87 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*版型#宽松*材质#牛仔布*风格#休闲*图案#字母*图案#文字*图案#线条*图案#印花*图案#撞色*裤款式#拼接*裤口#小脚", "response": "上下拼接撞色设计,吸睛十足,轻松聚焦视线,个性前卫。字母印花设计,巧添时尚细节看点,以鲜明撞色渲染,展现年轻活力气息。长袖套头轮廓,线条处理恰到好处,呼应休闲基调。宽松的版型,不挑身材,上身好看。连帽的设计美观实用,防风保暖。时尚百搭,可以搭配牛仔裤、紧身裤、休闲裤、束脚裤等。"}
88 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#灰色*风格#复古*风格#文艺*风格#青春*图案#卡通*图案#复古*衣样式#风衣*衣长#中长款", "response": "一款好看的风衣大概能为这个姹紫嫣红的春天多一份色彩,沉静的灰色上身具有非常好的效果,显得热更加内敛沉稳,有一股淡淡的复古文艺风格。而中长的版型自然下垂,修身显高又瞬间提升气场。后背的卡通图案别致可爱,更添青春气息。"}
89 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*风格#性感*衣样式#针织衫*衣款式#露肩", "response": "这一款针织衫露肩设计,风情浪漫性感迷人。略微宽松的版型舒适随意,很好的掩饰身材小小的缺陷,看起来精致高挑。加上时尚的花边下摆,错落有致视觉美丽。精致袖口,修饰手臂特别出彩。"}
90 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*裙款式#松紧带*裙款式#飘带", "response": "冷风的气质感,干净利落的feel,小露香肩有一种含蓄撩拨的趣味,袖口领口的飘带设计很是巧妙,让整个小衫更加优美,领子部分的两边肩部松紧带设计,大胆随意的穿出多种效果。让仙女们走在时尚的道路上更加自信。"}
91 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#针织*衣样式#卫衣*衣领型#圆领", "response": "针织卫衣采用了简洁的圆领设计,非常百搭,免去了你找不到搭配的烦恼。合体的剪裁设计,让你在跑步健身时轻巧灵便,活动自如,达到更好的锻炼效果。"}
92 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#宽松*颜色#白色*风格#简约*图案#蝴蝶结*图案#刺绣*衣样式#衬衫*衣袖型#喇叭袖", "response": "这一款很好穿的白色衬衫,利落的宽松版型几乎是不挑身材的,无门襟的设计也符合整体的气息。胸前做了绣花的点缀,为简约的衬衫增添了几分柔美的气质。七分的喇叭袖露出小臂,蝴蝶结的点缀显得气质更加的浪漫。"}
93 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*材质#雪纺*颜色#灰色*风格#英伦*风格#复古*图案#格子*图案#复古*裙型#百褶*裙长#半身裙*裙款式#波浪*裙款式#收腰", "response": "BRAND这款半身裙,用复古的灰色格纹,打造出十足英伦范儿。搭配百褶裙身,为整体增添层次感,穿出减龄风。同时,波浪边的收腰设计,不仅更好的修饰腰部曲线,还为整体气质增添了优雅美感。而雪纺面料,使你在夏日也能穿出清爽感。"}
94 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*版型#显瘦*风格#复古*风格#文艺*风格#中国风*风格#性感*图案#复古*图案#刺绣*裙型#a字*裙领型#v领", "response": "超级具有中国风气息的一款裙子,带着古典的柔婉。花朵刺绣的运用,色彩缤纷靓丽,冲击视觉,演绎复古文艺范儿。经典的气质v领,既凸显了小性感与时尚,又起到点睛的效果。腰部系的设计,配上a字版型,显瘦又遮肚子。"}
95 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*风格#清新*风格#性感*图案#线条*衣样式#马甲*衣领型#翻领*衣款式#露背*衣款式#绑带*衣款式#吊带*衣款式#收腰", "response": "小吊带马甲叠穿造型,年轻而不失时尚格调,有着绑带收腰设计,强调出纤细的腰肢,摩登帅气;小翻领露出纤细修长的脖颈线条,散发清爽利落的小清新气息;性感交叉露背设计,别致吸睛,女人味十足;高腰伞形裙摆自然撑开,上身塑造黄金比例,突显得腰更细,巧妙地修饰身型。"}
96 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*版型#显瘦*颜色#黑白*风格#英伦*风格#简约*图案#格子*图案#线条*衣样式#外套*衣样式#西装*衣门襟#一粒扣", "response": "这款西装外套,版型加长修身,能更好凸显成熟与稳重。细细密密的黑白图案,远远看形成自然的格纹,时髦英伦范儿。平整肩线将线条感拉伸,让身姿显得更挺拔有型。一粒扣设计,简约大气。"}
97 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裙*风格#潮*图案#线条*图案#撞色*裙领型#圆领", "response": "采用经典的圆领设计,修饰颈部线条的同时,且上身穿着舒适不易变形,轻松演绎时髦造型。大面积撞色贴花装饰,无疑是点睛之笔,为简洁的款式轮廓带来了更多的视觉层次感与潮流气息。与众不同的你,不在畏惧撞衫的尴尬。"}
98 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#雪纺*颜色#纯色*风格#清新*图案#纯色*图案#碎花*衣样式#衬衫*衣款式#荷叶边", "response": "这件荷叶边雪纺碎花衬衫和其他的碎花衬衫相比整体的风格会更优雅柔美一些。颜色上也是比较清新的花型配色和纯色的大身相结合,会让人看着很舒服,而且每个碎花之间都会限视觉上不会觉得太紧密,更有法式的浪漫优雅。"}
99 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#上衣*材质#蕾丝*风格#简约*风格#青春*风格#潮*风格#性感*图案#线条*图案#蕾丝*衣样式#雪纺衫*衣领型#圆领*衣款式#勾花镂空", "response": "这款时尚镂空雪纺衫,带有性感蕾丝工艺,精致百搭的圆领设计,彰显显独特的质感。背部线条流畅,笔挺而有型干练,彰显潮流时尚之风。走线十分笔直,针脚均匀,尽显裁缝之细致。简约时尚的透视蕾丝袖口,彰显成熟又不乏活力的青春气质。青春优雅的独特风格,流露出满满的潮流感。"}
100 | {"prompt": "请根据以下标签为商品编写一段广告\n类型#裤*风格#休闲*裤长#短裤", "response": "来自英国的这款儿童休闲短裤,以趣味的小恐龙图案满印裤身,可爱童真,彰显出宝宝的活泼天真范儿。柔软的全棉布料质地,手感细腻顺滑,亲和宝宝的肌肤,带来舒适自在的穿着体验。"}
--------------------------------------------------------------------------------
/limitations/factual_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/limitations/factual_error.png
--------------------------------------------------------------------------------
/limitations/math_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/limitations/math_error.png
--------------------------------------------------------------------------------
/limitations/self-confusion_google.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/limitations/self-confusion_google.jpg
--------------------------------------------------------------------------------
/limitations/self-confusion_openai.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/limitations/self-confusion_openai.jpg
--------------------------------------------------------------------------------
/limitations/self-confusion_tencent.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/limitations/self-confusion_tencent.jpg
--------------------------------------------------------------------------------
/model_api.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Request
2 | from transformers import AutoTokenizer, AutoModel
3 | import uvicorn, json, datetime
4 | import torch
5 |
6 | DEVICE = "cuda"
7 | DEVICE_ID = "0"
8 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
9 |
10 |
11 | def torch_gc():
12 | if torch.cuda.is_available():
13 | with torch.cuda.device(CUDA_DEVICE):
14 | torch.cuda.empty_cache()
15 | torch.cuda.ipc_collect()
16 |
17 |
18 | app = FastAPI()
19 |
20 |
21 | @app.post("/")
22 | async def create_item(request: Request):
23 | global model, tokenizer
24 | json_post_raw = await request.json()
25 | json_post = json.dumps(json_post_raw)
26 | json_post_list = json.loads(json_post)
27 | prompt = json_post_list.get('prompt')
28 | history = json_post_list.get('history')
29 | max_length = json_post_list.get('max_length')
30 | top_p = json_post_list.get('top_p')
31 | temperature = json_post_list.get('temperature')
32 | response, history = model.chat(tokenizer,
33 | prompt,
34 | history=history,
35 | max_length=max_length if max_length else 2048,
36 | top_p=top_p if top_p else 0.7,
37 | temperature=temperature if temperature else 0.95)
38 | now = datetime.datetime.now()
39 | time = now.strftime("%Y-%m-%d %H:%M:%S")
40 | answer = {
41 | "response": response,
42 | "history": history,
43 | "status": 200,
44 | "time": time
45 | }
46 | log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"'
47 | print(log)
48 | torch_gc()
49 | return answer
50 |
51 |
52 | if __name__ == '__main__':
53 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
54 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
55 | model.eval()
56 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
57 |
--------------------------------------------------------------------------------
/ptuning/README.md:
--------------------------------------------------------------------------------
1 | # ChatGLM-6B-PT
2 | 本仓库实现了对于 ChatGLM-6B 模型基于 [P-Tuning v2](https://github.com/THUDM/P-tuning-v2) 的微调。P-Tuning v2 将需要微调的参数量减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 显存即可运行。
3 |
4 | 下面以 [ADGEN](https://aclanthology.org/D19-1321.pdf) (广告生成) 数据集为例介绍代码的使用方法。
5 |
6 | *Read this in [English](README_en.md).
7 |
8 | ## 软件依赖
9 | 运行微调需要4.27.1版本的`transformers`。除 ChatGLM-6B 的依赖之外,还需要安装以下依赖
10 | ```
11 | pip install rouge_chinese nltk jieba datasets
12 | ```
13 | ## 使用方法
14 |
15 | ### 下载数据集
16 | ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。
17 |
18 | ```json
19 | {
20 | "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
21 | "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
22 | }
23 | ```
24 |
25 | 从 [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) 或者 [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) 下载处理好的 ADGEN 数据集,将解压后的 `AdvertiseGen` 目录放到本目录下。
26 |
27 | ### 训练
28 |
29 | #### P-Tuning v2
30 |
31 | 运行以下指令进行训练:
32 | ```shell
33 | bash train.sh
34 | ```
35 | `train.sh` 中的 `PRE_SEQ_LEN` 和 `LR` 分别是 soft prompt 长度和训练的学习率,可以进行调节以取得最佳的效果。P-Tuning-v2 方法会冻结全部的模型参数,可通过调整 `quantization_bit` 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。
36 |
37 | 在默认配置 `quantization_bit=4`、`per_device_train_batch_size=1`、`gradient_accumulation_steps=16` 下,INT4 的模型参数被冻结,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 `per_device_train_batch_size` 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。
38 |
39 | 如果你想要[从本地加载模型](../README_en.md#load-the-model-locally),可以将 `train.sh` 中的 `THUDM/chatglm-6b` 改为你本地的模型路径。
40 |
41 | #### Finetune
42 |
43 | 如果需要进行全参数的 Finetune,需要安装 [Deepspeed](https://github.com/microsoft/DeepSpeed),然后运行以下指令:
44 |
45 | ```shell
46 | bash ds_train_finetune.sh
47 | ```
48 |
49 | ### 推理
50 |
51 | 在 P-tuning v2 训练时模型只保存 PrefixEncoder 部分的参数,所以在推理时需要同时加载原 ChatGLM-6B 模型以及 PrefixEncoder 的权重,因此需要指定 `evaluate.sh` 中的参数:
52 |
53 | ```shell
54 | --model_name_or_path THUDM/chatglm-6b
55 | --ptuning_checkpoint $CHECKPOINT_PATH
56 | ```
57 |
58 | 仍然兼容旧版全参保存的 Checkpoint,只需要跟之前一样设定 `model_name_or_path`:
59 |
60 | ```shell
61 | --model_name_or_path $CHECKPOINT_PATH
62 | ```
63 |
64 | 评测指标为中文 Rouge score 和 BLEU-4。生成的结果保存在
65 | `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`。
66 |
67 | ### 例子
68 | #### 示例1
69 | * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
70 | * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
71 | * Output[微调前]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
72 | * Output[微调后]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
73 |
74 | #### 示例2
75 |
76 | * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
77 | * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
78 | * Output[微调前]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
79 | * Output[微调后]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
80 |
81 | ### 评估结果
82 |
83 | | | Finetune | P-tuning v2 | LoRA |
84 | | ------------- | ----------- | ----- | ------------- |
85 | | BLEU-4 | 8.01 | 8.10 | 7.62 |
86 | | Rouge-1 | 31.23 | 31.12 | 30.60 |
87 | | Rouge-2 | 7.36 | 7.11 | 6.96 |
88 | | Rouge-l | 25.08 | 24.97 | 24.80 |
89 | | Training Loss | 3.00 | 3.74 | 3.32 |
90 |
91 |
92 |
93 | #### 实验设置
94 |
95 | ```
96 | max_source_length=64
97 | max_target_length=64
98 | max_steps=3000
99 | ```
100 |
101 | ##### P-tuning v2
102 |
103 | ```
104 | pre_seq_len=128
105 | learning_rate=2e-2
106 | quantization_bit=4
107 | per_device_train_batch_size=16
108 | gradient_accumulation_steps=1
109 | ```
110 |
111 | ##### Finetune
112 |
113 | ```
114 | learning_rate=1e-4
115 | fp16
116 | num_gpus=4
117 | per_device_train_batch_size=4
118 | gradient_accumulation_steps=1
119 | ```
120 |
121 | ##### LoRA
122 |
123 | 实现采用的是 [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
124 |
125 | ```
126 | learning_rate=5e-4
127 | per_device_train_batch_size=16
128 | gradient_accumulation_steps=1
129 | ```
130 |
131 | ## 模型部署
132 | 首先载入Tokenizer:
133 |
134 | ```python
135 | from transformers import AutoConfig, AutoModel, AutoTokenizer
136 |
137 | # 载入Tokenizer
138 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
139 | ```
140 |
141 | 1. 如果需要加载的是新 Checkpoint(只包含 PrefixEncoder 参数):
142 |
143 | ```python
144 | config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
145 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
146 | prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
147 | new_prefix_state_dict = {}
148 | for k, v in prefix_state_dict.items():
149 | if k.startswith("transformer.prefix_encoder."):
150 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
151 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
152 | ```
153 | 注意你可能需要将 `pre_seq_len` 改成你训练时的实际值。如果你是[从本地加载模型](https://github.com/THUDM/ChatGLM-6B#%E4%BB%8E%E6%9C%AC%E5%9C%B0%E5%8A%A0%E8%BD%BD%E6%A8%A1%E5%9E%8B)的话,需要将 `THUDM/chatglm-6b` 改成本地的模型路径(注意不是checkpoint路径)。
154 |
155 | 2. 如果需要加载的是旧 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 参数),或者进行的是全参数微调,则直接加载整个 Checkpoint:
156 |
157 | ```python
158 | model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
159 | ```
160 |
161 | 之后根据需求可以进行量化,也可以直接使用:
162 |
163 | ```python
164 | # Comment out the following line if you don't use quantization
165 | model = model.quantize(4)
166 | model = model.half().cuda()
167 | model.transformer.prefix_encoder.float()
168 | model = model.eval()
169 |
170 | response, history = model.chat(tokenizer, "你好", history=[])
171 | ```
172 |
173 | **[23/04/19]** 你也可以直接运行支持加载 P-Tuning v2 checkpoint 的 [web demo](./web_demo.py)
174 | ```shell
175 | bash web_demo.sh
176 | ```
177 | 可能需要修改 [web_demo.sh](./web_demo.sh) 的内容以符合你实际的 checkpoint 情况。
178 |
179 | ## 使用自己的数据集
180 | 修改 `train.sh` 和 `evaluate.sh` 中的 `train_file`、`validation_file`和`test_file`为你自己的 JSON 格式数据集路径,并将 `prompt_column` 和 `response_column` 改为 JSON 文件中输入文本和输出文本对应的 KEY。可能还需要增大 `max_source_length` 和 `max_target_length` 来匹配你自己的数据集中的最大输入输出长度。
181 |
182 | ## 对话数据集
183 |
184 | 如需要使用多轮对话数据对模型进行微调,可以提供聊天历史,例如以下是一个三轮对话的训练数据:
185 |
186 | ```json lines
187 | {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
188 | {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
189 | {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
190 | ```
191 |
192 | 训练时需要指定 `--history_column` 为数据中聊天历史的 key(在此例子中是 `history`),将自动把聊天历史拼接。要注意超过输入长度 `max_source_length` 的内容会被截断。
193 |
194 | 可以参考以下指令:
195 |
196 | ```shell
197 | bash train_chat.sh
198 | ```
199 |
200 | ## 引用
201 |
202 | ```
203 | @inproceedings{liu2022p,
204 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
205 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
206 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
207 | pages={61--68},
208 | year={2022}
209 | }
210 | ```
211 |
212 |
213 |
214 |
--------------------------------------------------------------------------------
/ptuning/README_en.md:
--------------------------------------------------------------------------------
1 | # ChatGLM-6B-PT
2 | This repository implements tuning of the ChatGLM-6B model based on [P-Tuning v2](https://github.com/THUDM/P-tuning-v2). P-Tuning v2 reduces the amount of parameters that need to be optimized to 0.1% of the full fine-tuning, and then through model quantization, Gradient Checkpoint and other methods, it only needs a minimum of 7GB of video memory to run.
3 |
4 | The following uses the [ADGEN](https://aclanthology.org/D19-1321.pdf) (advertising generation) dataset as an example to introduce how to use the code.
5 |
6 | ## Software dependencies
7 | Running p-tuning requires version 4.27.1 of `transformers`. In addition to the dependencies of ChatGLM-6B, the following dependencies are required
8 | ```
9 | pip install rouge_chinese nltk jieba datasets
10 | ```
11 | ## Instructions
12 |
13 | ### Download the dataset
14 | The task of the ADGEN dataset is to generate an advertisement word (summary) based on the input (content).
15 |
16 | ```json
17 | {
18 | "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
19 | "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
20 | }
21 | ```
22 |
23 | From [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1) download the processed ADGEN dataset, and put the decompressed `AdvertiseGen` directory into this directory.
24 |
25 | ### Training
26 |
27 | #### P-Tuning v2
28 |
29 | Run the following commands for training:
30 | ```shell
31 | bash train.sh
32 | ```
33 | `PRE_SEQ_LEN` and `LR` in `train.sh` are soft prompt length and training learning rate respectively, which can be adjusted to achieve the best results. The P-Tuning-v2 method will freeze all model parameters, and the quantization level of the original model can be adjusted by adjusting `quantization_bit`. If this option is not added, it will be loaded with FP16 precision.
34 |
35 | Under the default configuration of `per_device_train_batch_size=1`, `gradient_accumulation_steps=16`, the model parameters of INT4 are frozen, and a training iteration will perform 16 cumulative forward and backward propagations with a batch size of 1, which is equivalent to the total batch size of 16, and only 6.7G GPU memory is required at this time with `quantization_bit=4`. If you want to improve the training efficiency under the same batch size, you can increase the value of `per_device_train_batch_size` while keeping the product of the two unchanged, but it will also bring more GPU memory consumption, please adjust it according to the actual situation.
36 |
37 | If you want to [load the model locally](../README_en.md#load-the-model-locally), you can change `THUDM/chatglm-6b` in `train.sh` to your local model path.
38 |
39 | #### Finetune
40 | To finetune the full parameters, you need to install [Deepspeed](https://github.com/microsoft/DeepSpeed), and then run the following command:
41 |
42 | ```shell
43 | bash ds_train_finetune.sh
44 | ```
45 |
46 | ### Inference
47 |
48 | During P-tuning v2 training, the model only saves the parameters of the PrefixEncoder part, so the original ChatGLM-6B model and the weight of the PrefixEncoder need to be loaded at the same time during inference, and the arguments need to be specified in `evaluate.sh`:
49 |
50 | ```shell
51 | --model_name_or_path THUDM/chatglm-6b
52 | --ptuning_checkpoint $CHECKPOINT_PATH
53 | ```
54 |
55 | It is still compatible with the old version of Checkpoint saved with full parameters, just set `model_name_or_path` as before:
56 |
57 | ```shell
58 | --model_name_or_path $CHECKPOINT_PATH
59 | ```
60 |
61 | The evaluation indicators are Chinese Rouge score and BLEU-4. The generated results are saved in
62 | `./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt`.
63 |
64 | ### Example
65 | #### Example 1
66 | * Input: 类型#上衣\*材质#牛仔布\*颜色#白色\*风格#简约\*图案#刺绣\*衣样式#外套\*衣款式#破洞
67 | * Label: 简约而不简单的牛仔外套,白色的衣身十分百搭。衣身多处有做旧破洞设计,打破单调乏味,增加一丝造型看点。衣身后背处有趣味刺绣装饰,丰富层次感,彰显别样时尚。
68 | * Output[before tuning]: 这件上衣的材质是牛仔布,颜色是白色,风格是简约,图案是刺绣,衣样式是外套,衣款式是破洞。
69 | * Output[after tuning]: 这是一款简约的牛仔外套,破洞设计,将牛仔布破洞,带来一种随意与个性。破洞的牛仔外套,展现出时尚气息,带来一种休闲感。同时,刺绣图案,让整件外套更加立体。
70 |
71 | #### Example 2
72 |
73 | * Input: 类型#裙\*版型#显瘦\*风格#文艺\*风格#简约\*图案#印花\*图案#撞色\*裙下摆#压褶\*裙长#连衣裙\*裙领型#圆领
74 | * Label: 文艺个性的印花连衣裙,藏青色底蕴,低调又大气,撞色太阳花分布整个裙身,绚丽而美好,带来时尚减龄的气质。基础款的舒适圆领,简约不失大方,勾勒精致脸庞。领后是一粒包布扣固定,穿脱十分方便。前片立体的打褶设计,搭配后片压褶的做工,增添层次和空间感,显瘦又有型。
75 | * Output[before tuning]: 类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领 1\. 连衣裙:简约风格,裙长为膝盖以上,裙领型为圆领。2\. 裙下摆:压褶设计,使裙摆呈现出流畅的褶皱效果。3\. 裙领型:裙领型为圆领,使穿上连衣裙后更加有型。4\. 版型:采用显瘦设计,让连衣裙看起来更加苗条。5\. 风格:文艺风格,让连衣裙更加有内涵和品味。6\. 图案:印花设计,在连衣裙上印有独特的图案。7\. 撞色:采用撞色设计,让连衣裙在色彩上更加鲜明、富有层次感。
76 | * Output[after tuning]: 这是一款文艺范的连衣裙,以印花为元素,采用简约的印花,既能够突出文艺气质,又能够展现简约风。在印花的同时又有领子和裙摆的压褶设计,更加凸显文艺气质。简约而不会过于单调,搭配出街,穿着十分舒适。
77 |
78 | ### evaluation result
79 |
80 | | | Finetune | P-tuning v2 | LoRA |
81 | | ------------- | ----------- | ----- | ------------- |
82 | | BLEU-4 | 8.01 | 8.10 | 7.62 |
83 | | Rouge-1 | 31.23 | 31.12 | 30.60 |
84 | | Rouge-2 | 7.36 | 7.11 | 6.96 |
85 | | Rouge-l | 25.08 | 24.97 | 24.80 |
86 | | Training Loss | 3.00 | 3.74 | 3.32 |
87 |
88 | #### Experiment Settings
89 |
90 | ```
91 | max_source_length=64
92 | max_target_length=64
93 | max_steps=3000
94 | ```
95 |
96 | ##### P-tuning v2
97 |
98 | ```
99 | pre_seq_len=128
100 | learning_rate=2e-2
101 | quantization_bit=4
102 | per_device_train_batch_size=16
103 | gradient_accumulation_steps=1
104 | ```
105 |
106 | ##### Finetune
107 |
108 | ```
109 | learning_rate=1e-4
110 | fp16
111 | num_gpus=4
112 | per_device_train_batch_size=4
113 | gradient_accumulation_steps=1
114 | ```
115 |
116 | ##### LoRA
117 |
118 | The implementation uses [simple_thu_chatglm6b](https://github.com/yuanzhoulvpi2017/zero_nlp/tree/main/simple_thu_chatglm6b)
119 |
120 | ```
121 | learning_rate=5e-4
122 | per_device_train_batch_size=16
123 | gradient_accumulation_steps=1
124 | ```
125 |
126 | ## Model Deployment
127 | First load the tokenizer:
128 |
129 | ```python
130 | from transformers import AutoConfig, AutoModel, AutoTokenizer
131 |
132 | # Load Tokenizer
133 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
134 | ```
135 |
136 | 1. If a new Checkpoint needs to be loaded (only contains the PrefixEncoder parameter):
137 |
138 | ```python
139 | config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128)
140 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True)
141 | prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
142 | new_prefix_state_dict = {}
143 | for k, v in prefix_state_dict.items():
144 | if k.startswith("transformer.prefix_encoder."):
145 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
146 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
147 | ```
148 | Note that you may need to change `pre_seq_len` to the actual value of your training. If you [load model from local](../README_en.md#load-the-model-locally), you need to change `THUDM/chatglm-6b` to the local model path (not the checkpoint path).
149 |
150 | 2. If you need to load the old checkpoint (including both ChatGLM-6B and PrefixEncoder parameters), or perform full parameter fine-tuning, then directly load the entire checkpoint:
151 |
152 | ```python
153 | model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
154 | ```
155 |
156 | Then it can be quantified according to the needs, or it can be used directly:
157 |
158 | ```python
159 | # Comment out the following line if you don't use quantization
160 | model = model. quantize(4)
161 | model = model.half().cuda()
162 | model.transformer.prefix_encoder.float()
163 | model = model.eval()
164 |
165 | response, history = model.chat(tokenizer, "Hello", history=[])
166 | ```
167 |
168 | **[23/04/19]** You can also directly run [web demo](./web_demo.py) which supports loading P-Tuning v2 checkpoint
169 | ```shell
170 | bash web_demo.sh
171 | ```
172 | It may be necessary to modify the content of [web_demo.sh](./web_demo.sh) to match your actual checkpoint situation.
173 |
174 | ## Use your own dataset
175 | Modify `train_file`, `validation_file` and `test_file` in `train.sh` and `evaluate.sh` to your own JSON format dataset paths, and change `prompt_column` and `response_column` to the keys in the JSON file corresponding to input text and output text.
176 | You may also need to increase `max_source_length` and `max_target_length` to match the maximum input and output lengths in your own dataset.
177 |
178 | ## Dialog Dataset
179 |
180 | If you need to use multiple rounds of dialogue data to train the model, you can provide chat history. For example, the following is the training data for a three-round dialogue:
181 |
182 | ```json lines
183 | {"prompt": "长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "response": "用电脑能读数据流吗?水温多少", "history": []}
184 | {"prompt": "95", "response": "上下水管温差怎么样啊?空气是不是都排干净了呢?", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"]]}
185 | {"prompt": "是的。上下水管都好的", "response": "那就要检查线路了,一般风扇继电器是由电脑控制吸合的,如果电路存在断路,或者电脑坏了的话会出现继电器不吸合的情况!", "history": [["长城h3风扇不转。继电器好的。保险丝好的传感器新的风扇也新的这是为什么。就是继电器缺一个信号线", "用电脑能读数据流吗?水温多少"], ["95", "上下水管温差怎么样啊?空气是不是都排干净了呢?"]]}
186 | ```
187 |
188 | During training, you need to specify `--history_column` as the key of the chat history in the data (`history` in this example), and the chat history will be stitched automatically. Note that content exceeding the input length `max_source_length` will be truncated.
189 |
190 | You can refer to the following instructions:
191 |
192 | ```shell
193 | bash train_chat.sh
194 | ```
195 |
196 | ## Citation
197 |
198 | ```
199 | @inproceedings{liu2022p,
200 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
201 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
202 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
203 | pages={61--68},
204 | year={2022}
205 | }
206 | ```
--------------------------------------------------------------------------------
/ptuning/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 |
11 | model_name_or_path: str = field(
12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 | )
14 | ptuning_checkpoint: str = field(
15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16 | )
17 | config_name: Optional[str] = field(
18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19 | )
20 | tokenizer_name: Optional[str] = field(
21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22 | )
23 | cache_dir: Optional[str] = field(
24 | default=None,
25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26 | )
27 | use_fast_tokenizer: bool = field(
28 | default=True,
29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30 | )
31 | model_revision: str = field(
32 | default="main",
33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34 | )
35 | use_auth_token: bool = field(
36 | default=False,
37 | metadata={
38 | "help": (
39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40 | "with private models)."
41 | )
42 | },
43 | )
44 | resize_position_embeddings: Optional[bool] = field(
45 | default=None,
46 | metadata={
47 | "help": (
48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49 | "the model's position embeddings."
50 | )
51 | },
52 | )
53 | quantization_bit: Optional[int] = field(
54 | default=None
55 | )
56 | pre_seq_len: Optional[int] = field(
57 | default=None
58 | )
59 | prefix_projection: bool = field(
60 | default=False
61 | )
62 |
63 |
64 | @dataclass
65 | class DataTrainingArguments:
66 | """
67 | Arguments pertaining to what data we are going to input our model for training and eval.
68 | """
69 |
70 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
71 |
72 | dataset_name: Optional[str] = field(
73 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
74 | )
75 | dataset_config_name: Optional[str] = field(
76 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
77 | )
78 | prompt_column: Optional[str] = field(
79 | default=None,
80 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
81 | )
82 | response_column: Optional[str] = field(
83 | default=None,
84 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
85 | )
86 | history_column: Optional[str] = field(
87 | default=None,
88 | metadata={"help": "The name of the column in the datasets containing the history of chat."},
89 | )
90 | train_file: Optional[str] = field(
91 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
92 | )
93 | validation_file: Optional[str] = field(
94 | default=None,
95 | metadata={
96 | "help": (
97 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
98 | )
99 | },
100 | )
101 | test_file: Optional[str] = field(
102 | default=None,
103 | metadata={
104 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
105 | },
106 | )
107 | overwrite_cache: bool = field(
108 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
109 | )
110 | preprocessing_num_workers: Optional[int] = field(
111 | default=None,
112 | metadata={"help": "The number of processes to use for the preprocessing."},
113 | )
114 | max_source_length: Optional[int] = field(
115 | default=1024,
116 | metadata={
117 | "help": (
118 | "The maximum total input sequence length after tokenization. Sequences longer "
119 | "than this will be truncated, sequences shorter will be padded."
120 | )
121 | },
122 | )
123 | max_target_length: Optional[int] = field(
124 | default=128,
125 | metadata={
126 | "help": (
127 | "The maximum total sequence length for target text after tokenization. Sequences longer "
128 | "than this will be truncated, sequences shorter will be padded."
129 | )
130 | },
131 | )
132 | val_max_target_length: Optional[int] = field(
133 | default=None,
134 | metadata={
135 | "help": (
136 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
137 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
138 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
139 | "during ``evaluate`` and ``predict``."
140 | )
141 | },
142 | )
143 | pad_to_max_length: bool = field(
144 | default=False,
145 | metadata={
146 | "help": (
147 | "Whether to pad all samples to model maximum sentence length. "
148 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
149 | "efficient on GPU but very bad for TPU."
150 | )
151 | },
152 | )
153 | max_train_samples: Optional[int] = field(
154 | default=None,
155 | metadata={
156 | "help": (
157 | "For debugging purposes or quicker training, truncate the number of training examples to this "
158 | "value if set."
159 | )
160 | },
161 | )
162 | max_eval_samples: Optional[int] = field(
163 | default=None,
164 | metadata={
165 | "help": (
166 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
167 | "value if set."
168 | )
169 | },
170 | )
171 | max_predict_samples: Optional[int] = field(
172 | default=None,
173 | metadata={
174 | "help": (
175 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
176 | "value if set."
177 | )
178 | },
179 | )
180 | num_beams: Optional[int] = field(
181 | default=None,
182 | metadata={
183 | "help": (
184 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
185 | "which is used during ``evaluate`` and ``predict``."
186 | )
187 | },
188 | )
189 | ignore_pad_token_for_loss: bool = field(
190 | default=True,
191 | metadata={
192 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
193 | },
194 | )
195 | source_prefix: Optional[str] = field(
196 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
197 | )
198 |
199 | forced_bos_token: Optional[str] = field(
200 | default=None,
201 | metadata={
202 | "help": (
203 | "The token to force as the first generated token after the decoder_start_token_id."
204 | "Useful for multilingual models like mBART where the first generated token"
205 | "needs to be the target language token (Usually it is the target language token)"
206 | )
207 | },
208 | )
209 |
210 |
211 |
212 | def __post_init__(self):
213 | if self.dataset_name is None and self.train_file is None and self.validation_file is None and self.test_file is None:
214 | raise ValueError("Need either a dataset name or a training/validation/test file.")
215 | else:
216 | if self.train_file is not None:
217 | extension = self.train_file.split(".")[-1]
218 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
219 | if self.validation_file is not None:
220 | extension = self.validation_file.split(".")[-1]
221 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
222 | if self.val_max_target_length is None:
223 | self.val_max_target_length = self.max_target_length
224 |
225 |
--------------------------------------------------------------------------------
/ptuning/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "zero_allow_untested_optimizer": true,
4 | "fp16": {
5 | "enabled": "auto",
6 | "loss_scale": 0,
7 | "initial_scale_power": 16,
8 | "loss_scale_window": 1000,
9 | "hysteresis": 2,
10 | "min_loss_scale": 1
11 | },
12 | "zero_optimization": {
13 | "stage": 2,
14 | "allgather_partitions": true,
15 | "allgather_bucket_size": 5e8,
16 | "overlap_comm": false,
17 | "reduce_scatter": true,
18 | "reduce_bucket_size": 5e8,
19 | "contiguous_gradients" : true
20 | }
21 | }
--------------------------------------------------------------------------------
/ptuning/ds_train_finetune.sh:
--------------------------------------------------------------------------------
1 |
2 | LR=1e-4
3 |
4 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
5 |
6 | deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
7 | --deepspeed deepspeed.json \
8 | --do_train \
9 | --train_file AdvertiseGen/train.json \
10 | --test_file AdvertiseGen/dev.json \
11 | --prompt_column content \
12 | --response_column summary \
13 | --overwrite_cache \
14 | --model_name_or_path THUDM/chatglm-6b \
15 | --output_dir ./output/adgen-chatglm-6b-ft-$LR \
16 | --overwrite_output_dir \
17 | --max_source_length 64 \
18 | --max_target_length 64 \
19 | --per_device_train_batch_size 4 \
20 | --per_device_eval_batch_size 1 \
21 | --gradient_accumulation_steps 1 \
22 | --predict_with_generate \
23 | --max_steps 5000 \
24 | --logging_steps 10 \
25 | --save_steps 1000 \
26 | --learning_rate $LR \
27 | --fp16
28 |
29 |
--------------------------------------------------------------------------------
/ptuning/evaluate.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | CHECKPOINT=adgen-chatglm-6b-pt-128-2e-2
3 | STEP=3000
4 |
5 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
6 | --do_predict \
7 | --validation_file AdvertiseGen/dev.json \
8 | --test_file AdvertiseGen/dev.json \
9 | --overwrite_cache \
10 | --prompt_column content \
11 | --response_column summary \
12 | --model_name_or_path THUDM/chatglm-6b \
13 | --ptuning_checkpoint ./output/$CHECKPOINT/checkpoint-$STEP \
14 | --output_dir ./output/$CHECKPOINT \
15 | --overwrite_output_dir \
16 | --max_source_length 64 \
17 | --max_target_length 64 \
18 | --per_device_eval_batch_size 1 \
19 | --predict_with_generate \
20 | --pre_seq_len $PRE_SEQ_LEN \
21 | --quantization_bit 4
22 |
--------------------------------------------------------------------------------
/ptuning/evaluate_finetune.sh:
--------------------------------------------------------------------------------
1 | CHECKPOINT=adgen-chatglm-6b-ft-1e-4
2 | STEP=3000
3 |
4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
5 | --do_predict \
6 | --validation_file AdvertiseGen/dev.json \
7 | --test_file AdvertiseGen/dev.json \
8 | --overwrite_cache \
9 | --prompt_column content \
10 | --response_column summary \
11 | --model_name_or_path ./output/$CHECKPOINT/checkpoint-$STEP \
12 | --output_dir ./output/$CHECKPOINT \
13 | --overwrite_output_dir \
14 | --max_source_length 256 \
15 | --max_target_length 256 \
16 | --per_device_eval_batch_size 1 \
17 | --predict_with_generate \
18 | --fp16_full_eval
19 |
--------------------------------------------------------------------------------
/ptuning/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | import json
25 |
26 | import numpy as np
27 | from datasets import load_dataset
28 | import jieba
29 | from rouge_chinese import Rouge
30 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
31 | import torch
32 |
33 | import transformers
34 | from transformers import (
35 | AutoConfig,
36 | AutoModel,
37 | AutoTokenizer,
38 | AutoTokenizer,
39 | DataCollatorForSeq2Seq,
40 | HfArgumentParser,
41 | Seq2SeqTrainingArguments,
42 | set_seed,
43 | )
44 | from trainer_seq2seq import Seq2SeqTrainer
45 |
46 | from arguments import ModelArguments, DataTrainingArguments
47 |
48 | logger = logging.getLogger(__name__)
49 |
50 | def main():
51 |
52 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
53 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
54 | # If we pass only one argument to the script and it's the path to a json file,
55 | # let's parse it to get our arguments.
56 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
57 | else:
58 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
59 |
60 | # Setup logging
61 | logging.basicConfig(
62 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
63 | datefmt="%m/%d/%Y %H:%M:%S",
64 | handlers=[logging.StreamHandler(sys.stdout)],
65 | )
66 |
67 | if training_args.should_log:
68 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
69 | transformers.utils.logging.set_verbosity_info()
70 |
71 | log_level = training_args.get_process_log_level()
72 | logger.setLevel(log_level)
73 | # datasets.utils.logging.set_verbosity(log_level)
74 | transformers.utils.logging.set_verbosity(log_level)
75 | transformers.utils.logging.enable_default_handler()
76 | transformers.utils.logging.enable_explicit_format()
77 |
78 | # Log on each process the small summary:
79 | logger.warning(
80 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
81 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
82 | )
83 | logger.info(f"Training/evaluation parameters {training_args}")
84 |
85 | # Set seed before initializing model.
86 | set_seed(training_args.seed)
87 |
88 | # Load dataset
89 | data_files = {}
90 | if data_args.train_file is not None:
91 | data_files["train"] = data_args.train_file
92 | extension = data_args.train_file.split(".")[-1]
93 | if data_args.validation_file is not None:
94 | data_files["validation"] = data_args.validation_file
95 | extension = data_args.validation_file.split(".")[-1]
96 | if data_args.test_file is not None:
97 | data_files["test"] = data_args.test_file
98 | extension = data_args.test_file.split(".")[-1]
99 |
100 | raw_datasets = load_dataset(
101 | extension,
102 | data_files=data_files,
103 | cache_dir=model_args.cache_dir,
104 | use_auth_token=True if model_args.use_auth_token else None,
105 | )
106 |
107 | # Load pretrained model and tokenizer
108 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
109 | config.pre_seq_len = model_args.pre_seq_len
110 | config.prefix_projection = model_args.prefix_projection
111 |
112 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
113 |
114 | if model_args.ptuning_checkpoint is not None:
115 | # Evaluation
116 | # Loading extra state dict of prefix encoder
117 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
118 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
119 | new_prefix_state_dict = {}
120 | for k, v in prefix_state_dict.items():
121 | if k.startswith("transformer.prefix_encoder."):
122 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
123 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
124 | else:
125 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
126 |
127 | if model_args.quantization_bit is not None:
128 | print(f"Quantized to {model_args.quantization_bit} bit")
129 | model = model.quantize(model_args.quantization_bit)
130 | if model_args.pre_seq_len is not None:
131 | # P-tuning v2
132 | model = model.half()
133 | model.transformer.prefix_encoder.float()
134 | else:
135 | # Finetune
136 | model = model.float()
137 |
138 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
139 |
140 | # Preprocessing the datasets.
141 | # We need to tokenize inputs and targets.
142 | if training_args.do_train:
143 | column_names = raw_datasets["train"].column_names
144 | elif training_args.do_eval:
145 | column_names = raw_datasets["validation"].column_names
146 | elif training_args.do_predict:
147 | column_names = raw_datasets["test"].column_names
148 | else:
149 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
150 | return
151 |
152 | # Get the column names for input/target.
153 | prompt_column = data_args.prompt_column
154 | response_column = data_args.response_column
155 | history_column = data_args.history_column
156 |
157 | # Temporarily set max_target_length for training.
158 | max_target_length = data_args.max_target_length
159 |
160 | def preprocess_function_eval(examples):
161 | inputs, targets = [], []
162 | for i in range(len(examples[prompt_column])):
163 | if examples[prompt_column][i] and examples[response_column][i]:
164 | query = examples[prompt_column][i]
165 | if history_column is None or len(examples[history_column][i]) == 0:
166 | prompt = query
167 | else:
168 | prompt = ""
169 | history = examples[history_column][i]
170 | for turn_idx, (old_query, response) in enumerate(history):
171 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
172 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
173 | inputs.append(prompt)
174 | targets.append(examples[response_column][i])
175 |
176 | inputs = [prefix + inp for inp in inputs]
177 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
178 | labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
179 |
180 | if data_args.ignore_pad_token_for_loss:
181 | labels["input_ids"] = [
182 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
183 | ]
184 | model_inputs["labels"] = labels["input_ids"]
185 |
186 | return model_inputs
187 |
188 | def preprocess_function_train(examples):
189 | max_seq_length = data_args.max_source_length + data_args.max_target_length
190 |
191 | model_inputs = {
192 | "input_ids": [],
193 | "labels": [],
194 | }
195 | for i in range(len(examples[prompt_column])):
196 | if examples[prompt_column][i] and examples[response_column][i]:
197 | query, answer = examples[prompt_column][i], examples[response_column][i]
198 |
199 | if history_column is None:
200 | prompt = query
201 | else:
202 | prompt = ""
203 | history = examples[history_column][i]
204 | for turn_idx, (old_query, response) in enumerate(history):
205 | prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
206 | prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
207 |
208 | prompt = prefix + prompt
209 | a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
210 | b_ids = tokenizer.encode(text=answer, add_special_tokens=False)
211 |
212 | if len(a_ids) > data_args.max_source_length - 1:
213 | a_ids = a_ids[: data_args.max_source_length - 1]
214 |
215 | if len(b_ids) > data_args.max_target_length - 2:
216 | b_ids = b_ids[: data_args.max_target_length - 2]
217 |
218 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)
219 |
220 | context_length = input_ids.index(tokenizer.bos_token_id)
221 | mask_position = context_length - 1
222 | labels = [-100] * context_length + input_ids[mask_position+1:]
223 |
224 | pad_len = max_seq_length - len(input_ids)
225 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
226 | labels = labels + [tokenizer.pad_token_id] * pad_len
227 | if data_args.ignore_pad_token_for_loss:
228 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
229 |
230 | model_inputs["input_ids"].append(input_ids)
231 | model_inputs["labels"].append(labels)
232 |
233 | return model_inputs
234 |
235 | def print_dataset_example(example):
236 | print("input_ids",example["input_ids"])
237 | print("inputs", tokenizer.decode(example["input_ids"]))
238 | print("label_ids", example["labels"])
239 | print("labels", tokenizer.decode(example["labels"]))
240 |
241 | if training_args.do_train:
242 | if "train" not in raw_datasets:
243 | raise ValueError("--do_train requires a train dataset")
244 | train_dataset = raw_datasets["train"]
245 | if data_args.max_train_samples is not None:
246 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
247 | train_dataset = train_dataset.select(range(max_train_samples))
248 | with training_args.main_process_first(desc="train dataset map pre-processing"):
249 | train_dataset = train_dataset.map(
250 | preprocess_function_train,
251 | batched=True,
252 | num_proc=data_args.preprocessing_num_workers,
253 | remove_columns=column_names,
254 | load_from_cache_file=not data_args.overwrite_cache,
255 | desc="Running tokenizer on train dataset",
256 | )
257 | print_dataset_example(train_dataset[0])
258 |
259 | if training_args.do_eval:
260 | max_target_length = data_args.val_max_target_length
261 | if "validation" not in raw_datasets:
262 | raise ValueError("--do_eval requires a validation dataset")
263 | eval_dataset = raw_datasets["validation"]
264 | if data_args.max_eval_samples is not None:
265 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
266 | eval_dataset = eval_dataset.select(range(max_eval_samples))
267 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
268 | eval_dataset = eval_dataset.map(
269 | preprocess_function_eval,
270 | batched=True,
271 | num_proc=data_args.preprocessing_num_workers,
272 | remove_columns=column_names,
273 | load_from_cache_file=not data_args.overwrite_cache,
274 | desc="Running tokenizer on validation dataset",
275 | )
276 | print_dataset_example(eval_dataset[0])
277 |
278 | if training_args.do_predict:
279 | max_target_length = data_args.val_max_target_length
280 | if "test" not in raw_datasets:
281 | raise ValueError("--do_predict requires a test dataset")
282 | predict_dataset = raw_datasets["test"]
283 | if data_args.max_predict_samples is not None:
284 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
285 | predict_dataset = predict_dataset.select(range(max_predict_samples))
286 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
287 | predict_dataset = predict_dataset.map(
288 | preprocess_function_eval,
289 | batched=True,
290 | num_proc=data_args.preprocessing_num_workers,
291 | remove_columns=column_names,
292 | load_from_cache_file=not data_args.overwrite_cache,
293 | desc="Running tokenizer on prediction dataset",
294 | )
295 | print_dataset_example(predict_dataset[0])
296 |
297 | # Data collator
298 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
299 | data_collator = DataCollatorForSeq2Seq(
300 | tokenizer,
301 | model=model,
302 | label_pad_token_id=label_pad_token_id,
303 | pad_to_multiple_of=None,
304 | padding=False
305 | )
306 |
307 | # Metric
308 | def compute_metrics(eval_preds):
309 | preds, labels = eval_preds
310 | if isinstance(preds, tuple):
311 | preds = preds[0]
312 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
313 | if data_args.ignore_pad_token_for_loss:
314 | # Replace -100 in the labels as we can't decode them.
315 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
316 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
317 |
318 | score_dict = {
319 | "rouge-1": [],
320 | "rouge-2": [],
321 | "rouge-l": [],
322 | "bleu-4": []
323 | }
324 | for pred, label in zip(decoded_preds, decoded_labels):
325 | hypothesis = list(jieba.cut(pred))
326 | reference = list(jieba.cut(label))
327 | rouge = Rouge()
328 | scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference))
329 | result = scores[0]
330 |
331 | for k, v in result.items():
332 | score_dict[k].append(round(v["f"] * 100, 4))
333 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
334 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
335 |
336 | for k, v in score_dict.items():
337 | score_dict[k] = float(np.mean(v))
338 | return score_dict
339 |
340 | # Override the decoding parameters of Seq2SeqTrainer
341 | training_args.generation_max_length = (
342 | training_args.generation_max_length
343 | if training_args.generation_max_length is not None
344 | else data_args.val_max_target_length
345 | )
346 | training_args.generation_num_beams = (
347 | data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
348 | )
349 | # Initialize our Trainer
350 | trainer = Seq2SeqTrainer(
351 | model=model,
352 | args=training_args,
353 | train_dataset=train_dataset if training_args.do_train else None,
354 | eval_dataset=eval_dataset if training_args.do_eval else None,
355 | tokenizer=tokenizer,
356 | data_collator=data_collator,
357 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
358 | save_prefixencoder=model_args.pre_seq_len is not None
359 | )
360 |
361 | # Training
362 | if training_args.do_train:
363 | checkpoint = None
364 | if training_args.resume_from_checkpoint is not None:
365 | checkpoint = training_args.resume_from_checkpoint
366 | # elif last_checkpoint is not None:
367 | # checkpoint = last_checkpoint
368 | model.gradient_checkpointing_enable()
369 | model.enable_input_require_grads()
370 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
371 | # trainer.save_model() # Saves the tokenizer too for easy upload
372 |
373 | metrics = train_result.metrics
374 | max_train_samples = (
375 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
376 | )
377 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
378 |
379 | trainer.log_metrics("train", metrics)
380 | trainer.save_metrics("train", metrics)
381 | trainer.save_state()
382 |
383 | # Evaluation
384 | results = {}
385 | if training_args.do_eval:
386 | logger.info("*** Evaluate ***")
387 | metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=512, temperature=0.95)
388 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
389 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
390 |
391 | trainer.log_metrics("eval", metrics)
392 | trainer.save_metrics("eval", metrics)
393 |
394 | if training_args.do_predict:
395 | logger.info("*** Predict ***")
396 |
397 | predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=512, do_sample=True, top_p=0.7, temperature=0.95)
398 | metrics = predict_results.metrics
399 | max_predict_samples = (
400 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
401 | )
402 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
403 |
404 | trainer.log_metrics("predict", metrics)
405 | trainer.save_metrics("predict", metrics)
406 |
407 | if trainer.is_world_process_zero():
408 | if training_args.predict_with_generate:
409 | predictions = tokenizer.batch_decode(
410 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
411 | )
412 | predictions = [pred.strip() for pred in predictions]
413 | labels = tokenizer.batch_decode(
414 | predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
415 | )
416 | labels = [label.strip() for label in labels]
417 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
418 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
419 | for p, l in zip(predictions, labels):
420 | res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
421 | writer.write(f"{res}\n")
422 | return results
423 |
424 |
425 | def _mp_fn(index):
426 | # For xla_spawn (TPUs)
427 | main()
428 |
429 |
430 | if __name__ == "__main__":
431 | main()
432 |
--------------------------------------------------------------------------------
/ptuning/train.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | LR=2e-2
3 |
4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
5 | --do_train \
6 | --train_file AdvertiseGen/train.json \
7 | --validation_file AdvertiseGen/dev.json \
8 | --prompt_column content \
9 | --response_column summary \
10 | --overwrite_cache \
11 | --model_name_or_path THUDM/chatglm-6b \
12 | --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
13 | --overwrite_output_dir \
14 | --max_source_length 64 \
15 | --max_target_length 64 \
16 | --per_device_train_batch_size 1 \
17 | --per_device_eval_batch_size 1 \
18 | --gradient_accumulation_steps 16 \
19 | --predict_with_generate \
20 | --max_steps 3000 \
21 | --logging_steps 10 \
22 | --save_steps 1000 \
23 | --learning_rate $LR \
24 | --pre_seq_len $PRE_SEQ_LEN \
25 | --quantization_bit 4
26 |
27 |
--------------------------------------------------------------------------------
/ptuning/train_chat.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 | LR=1e-2
3 |
4 | CUDA_VISIBLE_DEVICES=0 python3 main.py \
5 | --do_train \
6 | --train_file $CHAT_TRAIN_DATA \
7 | --validation_file $CHAT_VAL_DATA \
8 | --prompt_column prompt \
9 | --response_column response \
10 | --history_column history \
11 | --overwrite_cache \
12 | --model_name_or_path THUDM/chatglm-6b \
13 | --output_dir $CHECKPOINT_NAME \
14 | --overwrite_output_dir \
15 | --max_source_length 256 \
16 | --max_target_length 256 \
17 | --per_device_train_batch_size 1 \
18 | --per_device_eval_batch_size 1 \
19 | --gradient_accumulation_steps 16 \
20 | --predict_with_generate \
21 | --max_steps 3000 \
22 | --logging_steps 10 \
23 | --save_steps 1000 \
24 | --learning_rate $LR \
25 | --pre_seq_len $PRE_SEQ_LEN \
26 | --quantization_bit 4
27 |
28 |
--------------------------------------------------------------------------------
/ptuning/trainer_seq2seq.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | from torch import nn
19 | from torch.utils.data import Dataset
20 |
21 | from transformers.deepspeed import is_deepspeed_zero3_enabled
22 | from trainer import Trainer
23 | from transformers.trainer_utils import PredictionOutput
24 | from transformers.utils import logging
25 |
26 |
27 | logger = logging.get_logger(__name__)
28 |
29 |
30 | class Seq2SeqTrainer(Trainer):
31 | def evaluate(
32 | self,
33 | eval_dataset: Optional[Dataset] = None,
34 | ignore_keys: Optional[List[str]] = None,
35 | metric_key_prefix: str = "eval",
36 | **gen_kwargs
37 | ) -> Dict[str, float]:
38 | """
39 | Run evaluation and returns metrics.
40 |
41 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent
42 | (pass it to the init `compute_metrics` argument).
43 |
44 | You can also subclass and override this method to inject custom behavior.
45 |
46 | Args:
47 | eval_dataset (`Dataset`, *optional*):
48 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns
49 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__`
50 | method.
51 | ignore_keys (`List[str]`, *optional*):
52 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
53 | gathering predictions.
54 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
55 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
56 | "eval_bleu" if the prefix is `"eval"` (default)
57 | max_length (`int`, *optional*):
58 | The maximum target length to use when predicting with the generate method.
59 | num_beams (`int`, *optional*):
60 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
61 | beam search.
62 | gen_kwargs:
63 | Additional `generate` specific kwargs.
64 |
65 | Returns:
66 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
67 | dictionary also contains the epoch number which comes from the training state.
68 | """
69 |
70 | gen_kwargs = gen_kwargs.copy()
71 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
72 | gen_kwargs["max_length"] = self.args.generation_max_length
73 | gen_kwargs["num_beams"] = (
74 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
75 | )
76 | self._gen_kwargs = gen_kwargs
77 |
78 | return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
79 |
80 | def predict(
81 | self,
82 | test_dataset: Dataset,
83 | ignore_keys: Optional[List[str]] = None,
84 | metric_key_prefix: str = "test",
85 | **gen_kwargs
86 | ) -> PredictionOutput:
87 | """
88 | Run prediction and returns predictions and potential metrics.
89 |
90 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
91 | will also return metrics, like in `evaluate()`.
92 |
93 | Args:
94 | test_dataset (`Dataset`):
95 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the
96 | `model.forward()` method are automatically removed. Has to implement the method `__len__`
97 | ignore_keys (`List[str]`, *optional*):
98 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when
99 | gathering predictions.
100 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`):
101 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
102 | "eval_bleu" if the prefix is `"eval"` (default)
103 | max_length (`int`, *optional*):
104 | The maximum target length to use when predicting with the generate method.
105 | num_beams (`int`, *optional*):
106 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no
107 | beam search.
108 | gen_kwargs:
109 | Additional `generate` specific kwargs.
110 |
111 |
112 |
113 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic
114 | padding in a token classification task) the predictions will be padded (on the right) to allow for
115 | concatenation into one array. The padding index is -100.
116 |
117 |
118 |
119 | Returns: *NamedTuple* A namedtuple with the following keys:
120 |
121 | - predictions (`np.ndarray`): The predictions on `test_dataset`.
122 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some).
123 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained
124 | labels).
125 | """
126 |
127 | gen_kwargs = gen_kwargs.copy()
128 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
129 | gen_kwargs["max_length"] = self.args.generation_max_length
130 | gen_kwargs["num_beams"] = (
131 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.args.generation_num_beams
132 | )
133 | self._gen_kwargs = gen_kwargs
134 |
135 |
136 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
137 |
138 | def prediction_step(
139 | self,
140 | model: nn.Module,
141 | inputs: Dict[str, Union[torch.Tensor, Any]],
142 | prediction_loss_only: bool,
143 | ignore_keys: Optional[List[str]] = None,
144 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
145 | """
146 | Perform an evaluation step on `model` using `inputs`.
147 |
148 | Subclass and override to inject custom behavior.
149 |
150 | Args:
151 | model (`nn.Module`):
152 | The model to evaluate.
153 | inputs (`Dict[str, Union[torch.Tensor, Any]]`):
154 | The inputs and targets of the model.
155 |
156 | The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
157 | argument `labels`. Check your model's documentation for all accepted arguments.
158 | prediction_loss_only (`bool`):
159 | Whether or not to return the loss only.
160 |
161 | Return:
162 | Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
163 | labels (each being optional).
164 | """
165 |
166 | if not self.args.predict_with_generate or prediction_loss_only:
167 | return super().prediction_step(
168 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
169 | )
170 |
171 | has_labels = "labels" in inputs
172 | inputs = self._prepare_inputs(inputs)
173 |
174 | # XXX: adapt synced_gpus for fairscale as well
175 | gen_kwargs = self._gen_kwargs.copy()
176 | if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
177 | gen_kwargs["max_length"] = self.model.config.max_length
178 | gen_kwargs["num_beams"] = (
179 | gen_kwargs["num_beams"] if gen_kwargs.get("num_beams") is not None else self.model.config.num_beams
180 | )
181 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
182 | gen_kwargs["synced_gpus"] = (
183 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
184 | )
185 |
186 | if "attention_mask" in inputs:
187 | gen_kwargs["attention_mask"] = inputs.get("attention_mask", None)
188 | if "position_ids" in inputs:
189 | gen_kwargs["position_ids"] = inputs.get("position_ids", None)
190 | if "global_attention_mask" in inputs:
191 | gen_kwargs["global_attention_mask"] = inputs.get("global_attention_mask", None)
192 |
193 | # prepare generation inputs
194 | # some encoder-decoder models can have varying encoder's and thus
195 | # varying model input names
196 | if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
197 | generation_inputs = inputs[self.model.encoder.main_input_name]
198 | else:
199 | generation_inputs = inputs[self.model.main_input_name]
200 |
201 | gen_kwargs["input_ids"] = generation_inputs
202 | generated_tokens = self.model.generate(**gen_kwargs)
203 | generated_tokens = generated_tokens[:, generation_inputs.size()[-1]:]
204 |
205 | # in case the batch is shorter than max length, the output should be padded
206 | if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
207 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
208 | elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < (
209 | gen_kwargs["max_new_tokens"] + 1
210 | ):
211 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)
212 |
213 | loss = None
214 |
215 | if self.args.prediction_loss_only:
216 | return (loss, None, None)
217 |
218 | if has_labels:
219 | labels = inputs["labels"]
220 | if gen_kwargs.get("max_length") is not None and labels.shape[-1] < gen_kwargs["max_length"]:
221 | labels = self._pad_tensors_to_max_len(labels, gen_kwargs["max_length"])
222 | elif gen_kwargs.get("max_new_tokens") is not None and labels.shape[-1] < (
223 | gen_kwargs["max_new_tokens"] + 1
224 | ):
225 | labels = self._pad_tensors_to_max_len(labels, (gen_kwargs["max_new_tokens"] + 1))
226 | else:
227 | labels = None
228 |
229 | return (loss, generated_tokens, labels)
230 |
231 | def _pad_tensors_to_max_len(self, tensor, max_length):
232 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
233 | # If PAD token is not defined at least EOS token has to be defined
234 | pad_token_id = (
235 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
236 | )
237 | else:
238 | if self.model.config.pad_token_id is not None:
239 | pad_token_id = self.model.config.pad_token_id
240 | else:
241 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
242 |
243 | padded_tensor = pad_token_id * torch.ones(
244 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device
245 | )
246 | padded_tensor[:, : tensor.shape[-1]] = tensor
247 | return padded_tensor
248 |
--------------------------------------------------------------------------------
/ptuning/web_demo.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
3 | import gradio as gr
4 | import mdtex2html
5 |
6 | import torch
7 | import transformers
8 | from transformers import (
9 | AutoConfig,
10 | AutoModel,
11 | AutoTokenizer,
12 | AutoTokenizer,
13 | DataCollatorForSeq2Seq,
14 | HfArgumentParser,
15 | Seq2SeqTrainingArguments,
16 | set_seed,
17 | )
18 |
19 | from arguments import ModelArguments, DataTrainingArguments
20 |
21 |
22 | model = None
23 | tokenizer = None
24 |
25 | """Override Chatbot.postprocess"""
26 |
27 |
28 | def postprocess(self, y):
29 | if y is None:
30 | return []
31 | for i, (message, response) in enumerate(y):
32 | y[i] = (
33 | None if message is None else mdtex2html.convert((message)),
34 | None if response is None else mdtex2html.convert(response),
35 | )
36 | return y
37 |
38 |
39 | gr.Chatbot.postprocess = postprocess
40 |
41 |
42 | def parse_text(text):
43 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
44 | lines = text.split("\n")
45 | lines = [line for line in lines if line != ""]
46 | count = 0
47 | for i, line in enumerate(lines):
48 | if "```" in line:
49 | count += 1
50 | items = line.split('`')
51 | if count % 2 == 1:
52 | lines[i] = f''
53 | else:
54 | lines[i] = f'
'
55 | else:
56 | if i > 0:
57 | if count % 2 == 1:
58 | line = line.replace("`", "\`")
59 | line = line.replace("<", "<")
60 | line = line.replace(">", ">")
61 | line = line.replace(" ", " ")
62 | line = line.replace("*", "*")
63 | line = line.replace("_", "_")
64 | line = line.replace("-", "-")
65 | line = line.replace(".", ".")
66 | line = line.replace("!", "!")
67 | line = line.replace("(", "(")
68 | line = line.replace(")", ")")
69 | line = line.replace("$", "$")
70 | lines[i] = " "+line
71 | text = "".join(lines)
72 | return text
73 |
74 |
75 | def predict(input, chatbot, max_length, top_p, temperature, history):
76 | chatbot.append((parse_text(input), ""))
77 | for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
78 | temperature=temperature):
79 | chatbot[-1] = (parse_text(input), parse_text(response))
80 |
81 | yield chatbot, history
82 |
83 |
84 | def reset_user_input():
85 | return gr.update(value='')
86 |
87 |
88 | def reset_state():
89 | return [], []
90 |
91 |
92 | with gr.Blocks() as demo:
93 | gr.HTML("""ChatGLM """)
94 |
95 | chatbot = gr.Chatbot()
96 | with gr.Row():
97 | with gr.Column(scale=4):
98 | with gr.Column(scale=12):
99 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
100 | container=False)
101 | with gr.Column(min_width=32, scale=1):
102 | submitBtn = gr.Button("Submit", variant="primary")
103 | with gr.Column(scale=1):
104 | emptyBtn = gr.Button("Clear History")
105 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
106 | top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
107 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
108 |
109 | history = gr.State([])
110 |
111 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
112 | show_progress=True)
113 | submitBtn.click(reset_user_input, [], [user_input])
114 |
115 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
116 |
117 |
118 |
119 | def main():
120 | global model, tokenizer
121 |
122 | parser = HfArgumentParser((
123 | ModelArguments))
124 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
125 | # If we pass only one argument to the script and it's the path to a json file,
126 | # let's parse it to get our arguments.
127 | model_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))[0]
128 | else:
129 | model_args = parser.parse_args_into_dataclasses()[0]
130 |
131 | tokenizer = AutoTokenizer.from_pretrained(
132 | model_args.model_name_or_path, trust_remote_code=True)
133 | config = AutoConfig.from_pretrained(
134 | model_args.model_name_or_path, trust_remote_code=True)
135 |
136 | config.pre_seq_len = model_args.pre_seq_len
137 | config.prefix_projection = model_args.prefix_projection
138 |
139 | if model_args.ptuning_checkpoint is not None:
140 | print(f"Loading prefix_encoder weight from {model_args.ptuning_checkpoint}")
141 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
142 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
143 | new_prefix_state_dict = {}
144 | for k, v in prefix_state_dict.items():
145 | if k.startswith("transformer.prefix_encoder."):
146 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
147 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
148 | else:
149 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
150 |
151 | if model_args.quantization_bit is not None:
152 | print(f"Quantized to {model_args.quantization_bit} bit")
153 | model = model.quantize(model_args.quantization_bit)
154 |
155 | if model_args.pre_seq_len is not None:
156 | # P-tuning v2
157 | model = model.half().cuda()
158 | model.transformer.prefix_encoder.float().cuda()
159 |
160 | model = model.eval()
161 | demo.queue().launch(share=False, inbrowser=True)
162 |
163 |
164 |
165 | if __name__ == "__main__":
166 | main()
--------------------------------------------------------------------------------
/ptuning/web_demo.sh:
--------------------------------------------------------------------------------
1 | PRE_SEQ_LEN=128
2 |
3 | CUDA_VISIBLE_DEVICES=0 python3 web_demo.py \
4 | --model_name_or_path THUDM/chatglm-6b \
5 | --ptuning_checkpoint output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000 \
6 | --pre_seq_len $PRE_SEQ_LEN
7 |
8 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | protobuf
2 | transformers==4.27.1
3 | cpm_kernels
4 | torch>=1.10
5 | gradio
6 | mdtex2html
7 | sentencepiece
8 | accelerate
9 | paddlepaddle
10 | uvicorn
11 | requests
12 | fastapi
13 |
--------------------------------------------------------------------------------
/resources/WECHAT.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
扫码关注公众号,加入「ChatGLM交流群」
5 |
Scan the QR code to follow the official account and join the "ChatGLM Discussion Group"
6 |
7 |
8 |
--------------------------------------------------------------------------------
/resources/cli-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/resources/cli-demo.png
--------------------------------------------------------------------------------
/resources/web-demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/resources/web-demo.gif
--------------------------------------------------------------------------------
/resources/web-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/resources/web-demo.png
--------------------------------------------------------------------------------
/resources/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luxinfeng/audioConversation-ChatGLM/914af6d620992f0a1a9d06097787a553f0e7f592/resources/wechat.jpg
--------------------------------------------------------------------------------
/start.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # 启动model_api.py
4 | echo "正在启动model_api.py ..."
5 | python model_api.py &
6 |
7 | # 等待5秒钟,确保model_api.py已经启动完成
8 | sleep 30
9 |
10 | # 启动web_ui.py
11 | echo "正在启动web_ui.py ..."
12 | python web_ui.py &
13 |
14 | # 等待5秒钟,确保web_ui.py已经启动完成
15 | sleep 30
16 |
17 | # 输出提示
18 | echo "启动成功!"
19 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Tuple, Union, Optional
3 |
4 | from torch.nn import Module
5 | from transformers import AutoModel
6 |
7 |
8 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
9 | # transformer.word_embeddings 占用1层
10 | # transformer.final_layernorm 和 lm_head 占用1层
11 | # transformer.layers 占用 28 层
12 | # 总共30层分配到num_gpus张卡上
13 | num_trans_layers = 28
14 | per_gpu_layers = 30 / num_gpus
15 |
16 | # bugfix: 在linux中调用torch.embedding传入的weight,input不在同一device上,导致RuntimeError
17 | # windows下 model.device 会被设置成 transformer.word_embeddings.device
18 | # linux下 model.device 会被设置成 lm_head.device
19 | # 在调用chat或者stream_chat时,input_ids会被放到model.device上
20 | # 如果transformer.word_embeddings.device和model.device不同,则会导致RuntimeError
21 | # 因此这里将transformer.word_embeddings,transformer.final_layernorm,lm_head都放到第一张卡上
22 | device_map = {'transformer.word_embeddings': 0,
23 | 'transformer.final_layernorm': 0, 'lm_head': 0}
24 |
25 | used = 2
26 | gpu_target = 0
27 | for i in range(num_trans_layers):
28 | if used >= per_gpu_layers:
29 | gpu_target += 1
30 | used = 0
31 | assert gpu_target < num_gpus
32 | device_map[f'transformer.layers.{i}'] = gpu_target
33 | used += 1
34 |
35 | return device_map
36 |
37 |
38 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
39 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
40 | if num_gpus < 2 and device_map is None:
41 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
42 | else:
43 | from accelerate import dispatch_model
44 |
45 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
46 |
47 | if device_map is None:
48 | device_map = auto_configure_device_map(num_gpus)
49 |
50 | model = dispatch_model(model, device_map=device_map)
51 |
52 | return model
53 |
54 |
55 |
--------------------------------------------------------------------------------
/web_demo.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoTokenizer
2 | import gradio as gr
3 | import mdtex2html
4 |
5 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
6 | model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
7 | model = model.eval()
8 |
9 | """Override Chatbot.postprocess"""
10 |
11 |
12 | def postprocess(self, y):
13 | if y is None:
14 | return []
15 | for i, (message, response) in enumerate(y):
16 | y[i] = (
17 | None if message is None else mdtex2html.convert((message)),
18 | None if response is None else mdtex2html.convert(response),
19 | )
20 | return y
21 |
22 |
23 | gr.Chatbot.postprocess = postprocess
24 |
25 |
26 | def parse_text(text):
27 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
28 | lines = text.split("\n")
29 | lines = [line for line in lines if line != ""]
30 | count = 0
31 | for i, line in enumerate(lines):
32 | if "```" in line:
33 | count += 1
34 | items = line.split('`')
35 | if count % 2 == 1:
36 | lines[i] = f''
37 | else:
38 | lines[i] = f'
'
39 | else:
40 | if i > 0:
41 | if count % 2 == 1:
42 | line = line.replace("`", "\`")
43 | line = line.replace("<", "<")
44 | line = line.replace(">", ">")
45 | line = line.replace(" ", " ")
46 | line = line.replace("*", "*")
47 | line = line.replace("_", "_")
48 | line = line.replace("-", "-")
49 | line = line.replace(".", ".")
50 | line = line.replace("!", "!")
51 | line = line.replace("(", "(")
52 | line = line.replace(")", ")")
53 | line = line.replace("$", "$")
54 | lines[i] = " "+line
55 | text = "".join(lines)
56 | return text
57 |
58 |
59 | def predict(input, chatbot, max_length, top_p, temperature, history):
60 | chatbot.append((parse_text(input), ""))
61 | for response, history in model.stream_chat(tokenizer, input, history, max_length=max_length, top_p=top_p,
62 | temperature=temperature):
63 | chatbot[-1] = (parse_text(input), parse_text(response))
64 |
65 | yield chatbot, history
66 |
67 |
68 | def reset_user_input():
69 | return gr.update(value='')
70 |
71 |
72 | def reset_state():
73 | return [], []
74 |
75 |
76 | with gr.Blocks() as demo:
77 | gr.HTML("""ChatGLM """)
78 |
79 | chatbot = gr.Chatbot()
80 | with gr.Row():
81 | with gr.Column(scale=4):
82 | with gr.Column(scale=12):
83 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
84 | container=False)
85 | with gr.Column(min_width=32, scale=1):
86 | submitBtn = gr.Button("Submit", variant="primary")
87 | with gr.Column(scale=1):
88 | emptyBtn = gr.Button("Clear History")
89 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
90 | top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
91 | temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
92 |
93 | history = gr.State([])
94 |
95 | submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
96 | show_progress=True)
97 | submitBtn.click(reset_user_input, [], [user_input])
98 |
99 | emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
100 |
101 | demo.queue().launch(share=False, inbrowser=True)
102 |
--------------------------------------------------------------------------------
/web_index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | Voice Chat
8 |
58 |
59 |
60 |
61 |
Voice Chat
62 |
Record
63 |
Stop
64 |
65 |
66 |
132 |
133 |
134 |
--------------------------------------------------------------------------------
/web_ui.py:
--------------------------------------------------------------------------------
1 | import json
2 | import subprocess
3 |
4 | from paddlespeech.cli.asr.infer import ASRExecutor
5 | from paddlespeech.cli.tts.infer import TTSExecutor
6 | from fastapi import FastAPI, File, UploadFile
7 | from fastapi.responses import FileResponse
8 | from fastapi.middleware.cors import CORSMiddleware
9 |
10 |
11 | import requests
12 | import uvicorn
13 | import paddle
14 |
15 | origins = [
16 | '*'
17 | ]
18 |
19 |
20 | app = FastAPI()
21 | app.add_middleware(
22 | CORSMiddleware,
23 | allow_origins=origins,
24 | allow_credentials=True,
25 | allow_methods=["*"],
26 | allow_headers=["*"],
27 | )
28 | text2audio = TTSExecutor()
29 | audio2text = ASRExecutor()
30 |
31 |
32 |
33 | @app.post("/chat")
34 | async def chat(audio_file: UploadFile):
35 | """处理语音聊天请求"""
36 |
37 | print(audio_file.filename)
38 |
39 | """将上传的音频数据保存到临时文件"""
40 | with open("tmp.webm", "wb") as f:
41 | f.write(await audio_file.read())
42 |
43 | transformAudio('tmp.webm', 'tmp.wav')
44 |
45 | # 识别语音数据
46 | request_text = recognize_speech("tmp.wav")
47 |
48 | response_text = generateResponse(request_text)
49 |
50 | # 合成回复语音数据
51 | speech_data = generate_speech(response_text)
52 |
53 | # 返回回复语音数据
54 | return FileResponse(speech_data, media_type="audio/wav")
55 |
56 |
57 |
58 | def generateResponse(message):
59 | url = "http://127.0.0.1:8000/"
60 |
61 | payload = {"prompt": message, "history": [["please speak in english, no more than 15 words", "ok"]]}
62 | headers = {"Content-Type": "application/json"}
63 |
64 | response = requests.post(url, json=payload, headers=headers)
65 |
66 | response_text = response.text
67 |
68 | data2 = json.loads(response_text)
69 |
70 | print(data2['response'])
71 |
72 | return data2['response']
73 |
74 |
75 | def recognize_speech(audio_data):
76 | result = audio2text(audio_file=audio_data, lang='en', device=paddle.get_device(), model='transformer_librispeech')
77 | return result
78 |
79 |
80 | def generate_speech(message_text):
81 | message_audio = text2audio(message_text, lang='en', am = 'fastspeech2_vctk', voc='hifigan_vctk', device=paddle.get_device(), spk_id=7)
82 | return message_audio
83 |
84 |
85 | def transformAudio(web_file, wav_file):
86 | command = ['rm', '-rf', wav_file]
87 | subprocess.run(command)
88 | command = ['ffmpeg', '-i', web_file, '-ac', '1', '-ar', '16000', wav_file]
89 | subprocess.run(command, stdout=subprocess.PIPE, stdin=subprocess.PIPE)
90 |
91 |
92 | if __name__ == '__main__':
93 | uvicorn.run(app, host='0.0.0.0', port=7060, workers=1)
94 |
--------------------------------------------------------------------------------