├── LICENSE.txt
├── MODEL_LICENSE.txt
├── README.md
├── __init__.py
├── api.py
├── api_hf.py
├── checkpoints
├── 300
│ └── 需要下载pt文件.txt
├── latest
└── model_config.json
├── cli_demo.py
├── cli_demo_hf.py
├── examples
├── 1.jpeg
├── 2.jpeg
├── 3.jpeg
├── chat_example1.png
├── chat_example2.png
├── chat_example3.png
├── example_inputs.jsonl
├── thu.png
└── web_demo.png
├── fewshot-data
├── 2p.png
├── dataset.json
├── ghost.jpg
├── justice.png
├── katong.png
├── kobe.png
├── man.jpg
├── meme.png
├── music.png
├── panda.jpg
├── passport.png
├── pattern.png
├── pig.png
├── push.png
├── rou.png
├── rub.png
├── tianye.png
├── titan.png
├── tower.png
├── traf.png
└── woman.png
├── finetune.py
├── finetune_visualglm.py
├── finetune_visualglm.sh
├── img.png
├── lora_mixin.py
├── model
├── __init__.py
├── blip2.py
├── chat.py
├── infer_util.py
└── visualglm.py
├── predict.py
├── predict_lora.py
├── requirements.txt
├── requirements_wo_ds.txt
├── web_demo.py
└── web_demo_hf.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright Zhengxiao Du
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/MODEL_LICENSE.txt:
--------------------------------------------------------------------------------
1 | The GLM-130B License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means the VisualGLM-6B Model Team that distributes its Software.
6 |
7 | “Software” means the VisualGLM-6B model parameters made available under this license.
8 |
9 | 2. License Grant
10 |
11 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes.
12 |
13 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
14 |
15 | 3. Restriction
16 |
17 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes.
18 |
19 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
20 |
21 | 4. Disclaimer
22 |
23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
24 |
25 | 5. Limitation of Liability
26 |
27 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
28 |
29 | 6. Dispute Resolution
30 |
31 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
32 |
33 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Visualglm-image-to-text
2 | ## 介绍
3 |
4 | 使用了清华的Visualglm语言模型进行lora finetune,做了个简单的面相预测
5 |
6 | 补充了训练需要的一些文件,补全训练缺少latest等文件问题
7 |
8 | 补充一些训练和预测代码。
9 |
10 | ## 样例
11 | 对面相进行预测
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | ## 使用
21 |
22 | ### 模型推理
23 |
24 | 使用pip安装依赖
25 | ```
26 | pip install -r requirements.txt
27 | ```
28 | 尽量使用标准PyPI源以下载较新的sat包,TUNA源等可能同步较慢。`pip install -i https://pypi.org/simple -r requirements.txt`。
29 | 此时默认会安装`deepspeed`库(支持`sat`库训练),此库对于模型推理并非必要,同时部分Windows环境安装此库时会遇到问题。如果想绕过`deepspeed`安装,我们可以将命令改为
30 | ```
31 | pip install -r requirements_wo_ds.txt
32 | pip install --no-deps "SwissArmyTransformer>=0.3.6"
33 | ```
34 |
35 | ### 模型微调
36 | 1,需要在checkpoints/300目录下下载mp_rank_00_model_states.pt文件,获取途径如下
37 |
38 | wget https://huggingface.co/wangrongsheng/XrayGLM-300/resolve/main/300/mp_rank_00_model_states.pt
39 |
40 | 2,visualglm-6b 文件下载路径如下
41 |
42 | <https://huggingface.co/THUDM/visualglm-6b/tree/main>
43 |
44 | 然后执行
45 | bash finetune_visualglm.sh
46 |
47 | ### 模型预测
48 | 使用原模型:python predict.py
49 |
50 | 使用finetune后模型:python predict_lora.py
51 |
52 | ## 注意:文件里的路径都需要做相应调整,改到自己的目录下。
53 |
54 | # Visualglm-image-to-text
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/__init__.py
--------------------------------------------------------------------------------
/api.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import uvicorn
4 | from fastapi import FastAPI, Request
5 | from model import is_chinese, get_infer_setting, generate_input, chat
6 | import datetime
7 |
# Index of the GPU handed to get_infer_setting().
gpu_number = 0

# Build the model and tokenizer once at import time; the request handler in
# this module reuses these module-level objects for every call.
model, tokenizer = get_infer_setting(gpu_device=gpu_number)

# Application object the POST route is registered on.
app = FastAPI()
@app.post('/')
async def visual_glm(request: Request):
    """Answer one chat turn about an image.

    The POST body is a JSON object with the keys ``text`` (the prompt),
    ``image`` (the encoded image, passed through to ``generate_input``) and
    ``history`` (the conversation so far). Any key of ``input_para`` below
    may also be present in the body to override its default.

    Returns:
        A dict holding the generated ``result``, the ``history`` returned by
        ``chat``, a ``status`` of 200 and the server ``time`` as
        ``YYYY-mm-dd HH:MM:SS``.

    Raises:
        KeyError: if ``text``, ``image`` or ``history`` is missing from the
            request body.
    """
    # request.json() already returns freshly parsed Python objects, so the
    # former json.dumps()/json.loads() round trip added nothing.
    request_data = await request.json()
    print("Start to process request")

    input_text = request_data['text']
    input_image_encoded = request_data['image']
    history = request_data['history']

    input_para = {
        "max_length": 2048,
        "min_length": 50,
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2
    }
    # Let any key of the request body override the defaults above.
    input_para.update(request_data)

    is_zh = is_chinese(input_text)
    input_data = generate_input(input_text, input_image_encoded, history, input_para)
    input_image = input_data['input_image']
    gen_kwargs = input_data['gen_kwargs']
    # The first positional argument is None: the image is supplied through
    # the ``image`` keyword instead.
    answer, history, _ = chat(
        None,
        model,
        tokenizer,
        input_text,
        history=history,
        image=input_image,
        max_length=gen_kwargs['max_length'],
        top_p=gen_kwargs['top_p'],
        top_k=gen_kwargs['top_k'],
        temperature=gen_kwargs['temperature'],
        english=not is_zh,
    )

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return {
        "result": answer,
        "history": history,
        "status": 200,
        "time": timestamp
    }


if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
--------------------------------------------------------------------------------
/api_hf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from transformers import AutoTokenizer, AutoModel
4 | import uvicorn
5 | from fastapi import FastAPI, Request
6 | import datetime
7 | from model import process_image
8 |
# Hub id shared by the tokenizer and the model weights.
_MODEL_ID = "THUDM/visualglm-6b"

# Both objects are created once at import time and reused by the request
# handler in this module.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True).half().cuda()


# Application object the POST route is registered on.
app = FastAPI()
@app.post('/')
async def visual_glm(request: Request):
    """Answer one chat turn about an image using the Hugging Face model.

    The POST body is a JSON object; ``text``, ``image`` and ``history`` are
    read with ``dict.get`` so a missing key is forwarded as ``None``.

    Returns:
        A dict holding the final ``result`` text, the ``history`` produced by
        the model for this turn, a ``status`` of 200 and the server ``time``
        as ``YYYY-mm-dd HH:MM:SS``. If the model yields nothing, ``result``
        is an empty string and ``history`` is the one sent by the client.
    """
    # request.json() already returns parsed Python objects; no extra
    # dumps()/loads() round trip is needed.
    request_data = await request.json()
    print("Start to process request")

    history = request_data.get("history")
    image_encoded = request_data.get("image")
    query = request_data.get("text")
    image_path = process_image(image_encoded)

    answer = ""
    try:
        # stream_chat yields (response, history) pairs; keep the last pair so
        # the caller receives the updated history rather than the one it sent.
        for answer, history in model.stream_chat(tokenizer, image_path, query, history=history):
            pass
    finally:
        # Remove the file at image_path even when generation fails.
        if image_path and os.path.isfile(image_path):
            os.remove(image_path)

    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    return {
        "result": answer,
        "history": history,
        "status": 200,
        "time": timestamp
    }


if __name__ == "__main__":
    uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
--------------------------------------------------------------------------------
/checkpoints/300/需要下载pt文件.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/checkpoints/300/需要下载pt文件.txt
--------------------------------------------------------------------------------
/checkpoints/latest:
--------------------------------------------------------------------------------
1 | 300
--------------------------------------------------------------------------------
/checkpoints/model_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_class": "FineTuneVisualGLMModel",
3 | "tokenizer_type": "THUDM/chatglm-6b",
4 | "num_layers": 28,
5 | "hidden_size": 4096,
6 | "num_attention_heads": 32,
7 | "vocab_size": 130528,
8 | "layernorm_order": "post",
9 | "model_parallel_size": 1,
10 | "max_sequence_length": 2048,
11 | "pre_seq_len": 128,
12 | "lora_rank": 10,
13 | "use_ptuning": false,
14 | "use_lora": true,
15 | "image_length": 32,
16 | "eva_args": {
17 | "num_layers": 39,
18 | "hidden_size": 1408,
19 | "num_attention_heads": 16,
20 | "vocab_size": 1,
21 | "layernorm_order": "pre",
22 | "model_parallel_size": 1,
23 | "max_sequence_length": 257,
24 | "inner_hidden_size": 6144,
25 | "use_final_layernorm": false,
26 | "layernorm_epsilon": 1e-06,
27 | "image_size": [
28 | 224,
29 | 224
30 | ],
31 | "pre_len": 1,
32 | "post_len": 0,
33 | "in_channels": 3,
34 | "num_classes": 0,
35 | "patch_size": 14
36 | },
37 | "qformer_args": {
38 | "num_layers": 12,
39 | "hidden_size": 768,
40 | "num_attention_heads": 12,
41 | "vocab_size": 32,
42 | "layernorm_order": "post",
43 | "model_parallel_size": 1,
44 | "max_sequence_length": 0,
45 | "is_decoder": [
46 | true,
47 | false,
48 | true,
49 | false,
50 | true,
51 | false,
52 | true,
53 | false,
54 | true,
55 | false,
56 | true,
57 | false
58 | ],
59 | "cross_attn_hidden_size": 1408,
60 | "layernorm_epsilon": 1e-12
61 | },
62 | "bos_token_id": 130004,
63 | "mask_token_id": 130000,
64 | "gmask_token_id": 130001,
65 | "image_size": [
66 | 224,
67 | 224
68 | ],
69 | "pre_len": 1,
70 | "post_len": 0,
71 | "in_channels": 3,
72 | "patch_size": 14
73 | }
--------------------------------------------------------------------------------
/cli_demo.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 |
3 | import os
4 | import sys
5 | import torch
6 | import argparse
7 | from transformers import AutoTokenizer
8 | from sat.model.mixins import CachedAutoregressiveMixin
9 | from sat.quantization.kernels import quantize
10 |
11 | from model import VisualGLMModel, chat
12 | from finetune_visualglm import FineTuneVisualGLMModel
13 | from sat.model import AutoModel
14 |
15 |
def main():
    """Run the interactive VisualGLM-6B command-line chat.

    Parses the sampling and quantization options, loads the model named by
    ``--from_pretrained`` and then loops: ask for an image path or URL, chat
    about it until the user types ``clear`` (start over with a new image) or
    ``stop`` (quit).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--english", action='store_true', help='only output English')
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
    parser.add_argument("--from_pretrained", type=str, default="visualglm-6b", help='pretrained ckpt')
    parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round')
    parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round')
    args = parser.parse_args()

    # load model
    # GPU initialization is requested only when CUDA is available and no
    # quantization was asked for; otherwise the model is loaded on the CPU.
    init_on_gpu = bool(torch.cuda.is_available() and args.quant is None)
    load_args = argparse.Namespace(
        fp16=True,
        skip_init=True,
        use_gpu_initialization=init_on_gpu,
        device='cuda' if init_on_gpu else 'cpu',
    )
    model, model_args = AutoModel.from_pretrained(args.from_pretrained, args=load_args)
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)

    if torch.cuda.is_available():
        model = model.cuda()

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

    tokenizer = AutoTokenizer.from_pretrained("../chatglm-6b", trust_remote_code=True)

    # Pick every language-dependent string once instead of branching at each
    # use site.
    if args.english:
        print('Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
        image_prompt = "Please enter the image path or URL (press Enter for plain text conversation): "
        user_prompt = "User: "
        first_query = args.prompt_en
        sep = 'A:'
    else:
        print('欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
        image_prompt = "请输入图像路径或URL(回车进入纯文本对话): "
        user_prompt = "用户:"
        first_query = args.prompt_zh
        sep = '答:'

    with torch.no_grad():
        while True:
            # Each pass of the outer loop starts a fresh conversation.
            history = None
            cache_image = None
            image_path = input(image_prompt)

            if image_path == 'stop':
                break
            # With an image, open with the canned first-round prompt;
            # otherwise ask the user for the first message.
            query = first_query if len(image_path) > 0 else input(user_prompt)

            while True:
                if query == "clear":
                    break
                if query == "stop":
                    sys.exit(0)
                try:
                    response, history, cache_image = chat(
                        image_path,
                        model,
                        tokenizer,
                        query,
                        history=history,
                        image=cache_image,
                        max_length=args.max_length,
                        top_p=args.top_p,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        english=args.english,
                        invalid_slices=[slice(63823, 130000)] if args.english else []
                    )
                except Exception as e:
                    # Best effort: report the failure and go back to the
                    # image prompt.
                    print(e)
                    break
                print("VisualGLM-6B:"+response.split(sep)[-1].strip())
                # Later turns pass cache_image (returned by chat) instead of
                # the path.
                image_path = None
                query = input(user_prompt)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/cli_demo_hf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import signal
4 | from transformers import AutoTokenizer, AutoModel
5 |
# Hub id shared by the tokenizer and the model weights.
_MODEL_ID = "THUDM/visualglm-6b"

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(_MODEL_ID, trust_remote_code=True).half().cuda().eval()

# Shell command used to wipe the terminal before each redraw.
os_name = platform.system()
if os_name == 'Windows':
    clear_command = 'cls'
else:
    clear_command = 'clear'

# Set by signal_handler() to ask main() to stop streaming the current reply.
stop_stream = False
13 |
14 |
def build_prompt(history, prefix):
    """Render the conversation as the text shown on screen.

    Args:
        history: iterable of ``(query, response)`` pairs.
        prefix: text placed before the first turn.

    Returns:
        ``prefix`` followed, for every pair, by a user block and a
        VisualGLM-6B block, each introduced by a blank line.
    """
    turns = [
        f"\n\n用户:{query}\n\nVisualGLM-6B:{response}"
        for query, response in history
    ]
    return prefix + "".join(turns)
21 |
22 |
def signal_handler(signal, frame):
    """Signal callback that flags the running stream for cancellation.

    Both arguments are ignored. Note that the ``signal`` parameter shadows
    the ``signal`` module inside this function.
    """
    global stop_stream
    stop_stream = True
26 |
27 |
def main():
    """Run the interactive Hugging Face VisualGLM-6B chat loop.

    The outer loop asks for an image path and starts a new conversation; the
    inner loop streams one reply per user message. Typing ``clear`` starts
    over with a new image, ``stop`` ends the program, and Ctrl+C while a
    reply is streaming cuts that reply short.

    Raises:
        SystemExit: when the user types ``stop`` at the message prompt.
    """
    global stop_stream
    while True:
        history = []
        prefix = "欢迎使用 VisualGLM-6B 模型,输入图片路径和内容即可进行对话,clear 清空对话历史,stop 终止程序"
        print(prefix)
        image_path = input("\n请输入图片路径:")
        if image_path == "stop":
            break
        prefix = prefix + "\n" + image_path
        query = "描述这张图片。"
        while True:
            # Install the handler before streaming starts so that Ctrl+C
            # during the first chunks stops the reply instead of killing the
            # program.
            signal.signal(signal.SIGINT, signal_handler)
            # Discard an interrupt received while waiting at a prompt;
            # otherwise this reply would be cut off at its first chunk.
            stop_stream = False
            count = 0
            for _response, history in model.stream_chat(tokenizer, image_path, query, history=history):
                if stop_stream:
                    stop_stream = False
                    break
                count += 1
                # Redraw the screen on every 8th chunk.
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history, prefix), flush=True)
            # Final redraw with the complete (or interrupted) reply.
            os.system(clear_command)
            print(build_prompt(history, prefix), flush=True)
            query = input("\n用户:")
            command = query.strip()
            if command == "clear":
                break
            if command == "stop":
                # exit() is a helper injected by the site module and may be
                # absent; SystemExit is always available.
                raise SystemExit(0)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/examples/1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/1.jpeg
--------------------------------------------------------------------------------
/examples/2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/2.jpeg
--------------------------------------------------------------------------------
/examples/3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/3.jpeg
--------------------------------------------------------------------------------
/examples/chat_example1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/chat_example1.png
--------------------------------------------------------------------------------
/examples/chat_example2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/chat_example2.png
--------------------------------------------------------------------------------
/examples/chat_example3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/chat_example3.png
--------------------------------------------------------------------------------
/examples/example_inputs.jsonl:
--------------------------------------------------------------------------------
1 | {"id":1, "text": "描述一下这个场景", "image": "examples/1.jpeg"}
2 | {"id":2, "text": "这是什么东西", "image": "examples/2.jpeg"}
3 | {"id":3, "text": "这张图片描述了什么", "image": "examples/3.jpeg"}
--------------------------------------------------------------------------------
/examples/thu.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/thu.png
--------------------------------------------------------------------------------
/examples/web_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/web_demo.png
--------------------------------------------------------------------------------
/fewshot-data/2p.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/2p.png
--------------------------------------------------------------------------------
/fewshot-data/dataset.json:
--------------------------------------------------------------------------------
1 | [
2 | {"img": "fewshot-data/2p.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是蒙蒙细雨。"},
3 | {"img": "fewshot-data/pig.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是虚化的。"},
4 | {"img": "fewshot-data/meme.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是蓝色的木质地板。"},
5 | {"img": "fewshot-data/passport.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是棕黄色木质桌子。"},
6 | {"img": "fewshot-data/tower.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是黄昏的天空、云彩和繁华的城市高楼。"},
7 | {"img": "fewshot-data/rub.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是太阳、大树、蓝天白云。"},
8 | {"img": "fewshot-data/push.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是蓝天和沙漠。"},
9 | {"img": "fewshot-data/traf.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是城市街道。"},
10 | {"img": "fewshot-data/music.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是一个音乐混音器。"},
11 | {"img": "fewshot-data/pattern.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是小区的楼房和街道。"},
12 | {"img": "fewshot-data/rou.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是大理石桌子和一个盘子。"},
13 | {"img": "fewshot-data/katong.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是绿色的草地。"},
14 | {"img": "fewshot-data/man.jpg", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是城市的街道和高楼。"},
15 | {"img": "fewshot-data/kobe.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是虚化的观众席。"},
16 | {"img": "fewshot-data/panda.jpg", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是纯白的。"},
17 | {"img": "fewshot-data/titan.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是一座雕像。"},
18 | {"img": "fewshot-data/woman.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是纯蓝的。"},
19 | {"img": "fewshot-data/ghost.jpg", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是一个房间。"},
20 | {"img": "fewshot-data/justice.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是天空和阳光。"},
21 | {"img": "fewshot-data/tianye.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是金黄的田野。"}
22 | ]
--------------------------------------------------------------------------------
/fewshot-data/ghost.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/ghost.jpg
--------------------------------------------------------------------------------
/fewshot-data/justice.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/justice.png
--------------------------------------------------------------------------------
/fewshot-data/katong.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/katong.png
--------------------------------------------------------------------------------
/fewshot-data/kobe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/kobe.png
--------------------------------------------------------------------------------
/fewshot-data/man.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/man.jpg
--------------------------------------------------------------------------------
/fewshot-data/meme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/meme.png
--------------------------------------------------------------------------------
/fewshot-data/music.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/music.png
--------------------------------------------------------------------------------
/fewshot-data/panda.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/panda.jpg
--------------------------------------------------------------------------------
/fewshot-data/passport.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/passport.png
--------------------------------------------------------------------------------
/fewshot-data/pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/pattern.png
--------------------------------------------------------------------------------
/fewshot-data/pig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/pig.png
--------------------------------------------------------------------------------
/fewshot-data/push.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/push.png
--------------------------------------------------------------------------------
/fewshot-data/rou.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/rou.png
--------------------------------------------------------------------------------
/fewshot-data/rub.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/rub.png
--------------------------------------------------------------------------------
/fewshot-data/tianye.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/tianye.png
--------------------------------------------------------------------------------
/fewshot-data/titan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/titan.png
--------------------------------------------------------------------------------
/fewshot-data/tower.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/tower.png
--------------------------------------------------------------------------------
/fewshot-data/traf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/traf.png
--------------------------------------------------------------------------------
/fewshot-data/woman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/woman.png
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer
3 | from model import chat
4 | from model import VisualGLMModel
5 | from sat.model.mixins import CachedAutoregressiveMixin
6 |
# Load the tokenizer and the model from the sibling "../visualglm-6b" directory.
tokenizer = AutoTokenizer.from_pretrained("../visualglm-6b", trust_remote_code=True)
load_args = argparse.Namespace(fp16=True, skip_init=True)
model, model_args = VisualGLMModel.from_pretrained("../visualglm-6b", args=load_args)

# Attach the cached autoregressive mixin before calling chat().
model.add_mixin('auto_regressive', CachedAutoregressiveMixin)

# Run a single-turn query against one image and print the reply.
image_path = './fewshot-data/龙眼.jpeg'
prompt = '描述这张图片'
responses, history, cache_image = chat(image_path, model, tokenizer, prompt, history=[])
print(responses)
--------------------------------------------------------------------------------
/finetune_visualglm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 |
5 | from sat import mpu, get_args, get_tokenizer
6 | from sat.training.deepspeed_training import training_main
7 | from model import VisualGLMModel
8 | from sat.model.finetune import PTuningV2Mixin
9 | from sat.model.finetune.lora_mixin import LoraMixin
10 | from transformers import AutoTokenizer, AutoModel
11 | from model import VisualGLMModel
12 |
13 |
class FineTuneVisualGLMModel(VisualGLMModel):
    """VisualGLM model that optionally attaches P-Tuning v2 and/or LoRA mixins.

    Which mixins are attached is controlled by ``args.use_ptuning`` and
    ``args.use_lora``; only the parameters belonging to the enabled mixins
    are left trainable by :meth:`disable_untrainable_params`.
    """

    def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
        if args.use_ptuning:
            prefix_mixin = PTuningV2Mixin(
                args.num_layers,
                args.hidden_size // args.num_attention_heads,
                args.num_attention_heads,
                args.pre_seq_len,
            )
            self.add_mixin("ptuning", prefix_mixin)
        if args.use_lora:
            # For a "normal" Transformer, LoRA can be used with head_first=False (the default).
            lora_mixin = LoraMixin(
                args.num_layers,
                args.lora_rank,
                head_first=True,
                num_attention_heads=args.num_attention_heads,
                hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads,
                # Only layers 0 and 14 receive LoRA adapters.
                layer_range=list(range(0, 28, 14)),
            )
            self.add_mixin("lora", lora_mixin, reinit=True)
        self.args = args

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register the fine-tuning options on ``parser`` and defer to the parent class."""
        group = parser.add_argument_group('VisualGLM-finetune', 'VisualGLM finetune Configurations')
        for flag, default in (('--pre_seq_len', 8), ('--lora_rank', 10)):
            group.add_argument(flag, type=int, default=default)
        for flag in ('--use_ptuning', '--use_lora'):
            group.add_argument(flag, action="store_true")
        return super().add_model_specific_args(parser)

    def disable_untrainable_params(self):
        """Freeze every parameter whose name matches none of the enabled fine-tuning tags.

        The names of the parameters that stay trainable are printed.
        """
        trainable_tags = []
        if self.args.use_ptuning:
            trainable_tags.append('ptuning')
        if self.args.use_lora:
            trainable_tags += ['matrix_A', 'matrix_B']

        for param_name, param in self.named_parameters():
            lowered = param_name.lower()
            # Case-insensitive substring match against the parameter name.
            if any(tag.lower() in lowered for tag in trainable_tags):
                print(param_name)
            else:
                param.requires_grad_(False)
54 |
55 |
def get_batch(data_iterator, args, timers):
    """Fetch the next batch and broadcast its tensors through ``mpu.broadcast_data``.

    Returns a ``(tokens, labels, image, pre_image)`` tuple. The image is cast
    to half precision when ``args.fp16`` is set.
    """
    # Time only the data-loader fetch itself.
    timers('data loader').start()
    data = next(data_iterator) if data_iterator is not None else None
    timers('data loader').stop()

    # Integer fields and the float image are broadcast separately because
    # each broadcast call takes a single dtype.
    int_fields = mpu.broadcast_data(['input_ids', 'labels'], data, torch.int64)
    image_fields = mpu.broadcast_data(['image'], data, torch.float32)

    tokens = int_fields['input_ids'].long()
    labels = int_fields['labels'].long()
    img = image_fields['image'].half() if args.fp16 else image_fields['image']

    # NOTE(review): `data` is None when no iterator was supplied, in which case
    # this lookup raises TypeError; `pre_image` is not broadcast. Confirm this
    # path is never taken with model parallelism > 1.
    return tokens, labels, img, data['pre_image']
78 |
79 |
80 | from torch.nn import CrossEntropyLoss
81 |
82 |
def forward_step(data_iterator, model, args, timers):
    """Run one forward pass and compute the next-token cross-entropy loss.

    Args:
        data_iterator: Iterator handed to ``get_batch`` (may be ``None``).
        model: Callable model; its first output is used as the logits.
        args: Parsed training arguments, forwarded to ``get_batch``.
        timers: Timer factory; ``timers(name)`` must expose ``start``/``stop``.

    Returns:
        A ``(loss, metrics)`` tuple where ``metrics`` is ``{'loss': loss}`` and
        ``loss`` is cast back to the dtype of the model's logits.
    """
    # Get the batch.
    timers('batch generator').start()
    tokens, labels, image, pre_image = get_batch(data_iterator, args, timers)
    timers('batch generator').stop()

    logits = model(input_ids=tokens, image=image, pre_image=pre_image)[0]
    dtype = logits.dtype
    # Compute the loss in float32 regardless of the model's dtype.
    lm_logits = logits.to(torch.float32)

    # Shift so that tokens < n predict n.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens; positions labelled -100 are ignored.
    loss_fct = CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    # The float32 logits are not cast back: the result was never used and the
    # cast only allocated another full-size copy of the logits.
    loss = loss.to(dtype)
    return loss, {'loss': loss}
106 |
107 |
108 | from model.blip2 import BlipImageEvalProcessor
109 | from torch.utils.data import Dataset
110 | import json
111 | from PIL import Image
112 |
113 |
class FewShotDataset(Dataset):
    """In-memory few-shot image/question/answer dataset read from a JSON file.

    The JSON file must contain a list of objects with the keys ``img`` (image
    path), ``prompt`` (question) and ``label`` (answer). Every sample is
    preprocessed and tokenized eagerly in ``__init__``.

    Args:
        path: Path of the JSON file.
        processor: Callable applied to each RGB ``PIL.Image``.
        tokenizer: Tokenizer providing ``encode``, ``pad_token_id``,
            ``bos_token_id`` and ``build_inputs_with_special_tokens``.
        args: Namespace with ``max_source_length``, ``max_target_length``,
            ``image_length`` and ``ignore_pad_token_for_loss``.
    """

    def __init__(self, path, processor, tokenizer, args):
        max_seq_length = args.max_source_length + args.max_target_length
        with open(path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.images = []
        self.input_ids = []
        self.labels = []
        # Defined up front so the attribute exists even for an empty dataset.
        self.pre_image = None
        for item in data:
            # Close the underlying file handle as soon as the RGB copy exists.
            with Image.open(item['img']) as raw_image:
                image = processor(raw_image.convert('RGB'))
            # NOTE(review): this literal was an unterminated string in the
            # source, presumably an "<img>" tag stripped during extraction.
            # Confirm against the prompt format used at inference time, and
            # check whether the matching "</img>" was also stripped from the
            # question prefix below.
            input0 = tokenizer.encode("<img>", add_special_tokens=False)
            # Placeholder ids reserving `image_length` positions for the image.
            input1 = [tokenizer.pad_token_id] * args.image_length
            input2 = tokenizer.encode("问:" + item['prompt'] + "\n答:", add_special_tokens=False)
            a_ids = sum([input0, input1, input2], [])
            b_ids = tokenizer.encode(text=item['label'], add_special_tokens=False)
            # Truncate so that source + target + special tokens fit the budget.
            if len(a_ids) > args.max_source_length - 1:
                a_ids = a_ids[: args.max_source_length - 1]
            if len(b_ids) > args.max_target_length - 2:
                b_ids = b_ids[: args.max_target_length - 2]
            # Number of tokens that precede the image placeholder.
            pre_image = len(input0)
            input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

            # Everything before the BOS token is context and excluded from the loss.
            context_length = input_ids.index(tokenizer.bos_token_id)
            mask_position = context_length - 1
            labels = [-100] * context_length + input_ids[mask_position + 1:]

            pad_len = max_seq_length - len(input_ids)
            input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
            labels = labels + [tokenizer.pad_token_id] * pad_len
            if args.ignore_pad_token_for_loss:
                labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]
            self.images.append(image)
            self.input_ids.append(input_ids)
            self.labels.append(labels)
            self.pre_image = pre_image

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return {
            "image": self.images[idx],
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx],
            "pre_image": self.pre_image
        }
160 |
161 |
def create_dataset_function(path, args):
    """Build the few-shot dataset stored at ``path``.

    The tokenizer is loaded from "../visualglm-6b" and images are processed
    by a ``BlipImageEvalProcessor`` constructed with 224.
    """
    text_tokenizer = AutoTokenizer.from_pretrained("../visualglm-6b", trust_remote_code=True)
    eval_processor = BlipImageEvalProcessor(224)
    return FewShotDataset(path, eval_processor, text_tokenizer, args)
168 |
169 |
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean; ``type=bool`` would treat "False" as True."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', ''):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


    # Script-specific options are parsed first; the remaining argv goes to sat.
    py_parser = argparse.ArgumentParser(add_help=False)
    py_parser.add_argument('--max_source_length', type=int)
    py_parser.add_argument('--max_target_length', type=int)
    py_parser.add_argument('--ignore_pad_token_for_loss', type=_str2bool, default=True)
    py_parser.add_argument('--source_prefix', type=str, default="")
    py_parser = FineTuneVisualGLMModel.add_model_specific_args(py_parser)
    known, args_list = py_parser.parse_known_args()
    args = get_args(args_list)
    args = argparse.Namespace(**vars(args), **vars(known))
    model_type = './checkpoints'
    model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args)
    # The return value is unused here; the call is kept in case it has side effects.
    tokenizer = get_tokenizer(args)


    def data_collator(examples):
        """Stack per-sample fields into batch tensors.

        ``pre_image`` is taken from the last example in the batch.
        """
        for example in examples:
            example['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long)
            example['labels'] = torch.tensor(example['labels'], dtype=torch.long)
        return {
            'input_ids': torch.stack([example['input_ids'] for example in examples]),
            'labels': torch.stack([example['labels'] for example in examples]),
            'image': torch.stack([example['image'] for example in examples]),
            # Read explicitly from the last example instead of a leaked loop variable.
            'pre_image': examples[-1]['pre_image'],
        }


    training_main(args, model_cls=model, forward_step_function=forward_step,
                  create_dataset_function=create_dataset_function, collate_fn=data_collator)
--------------------------------------------------------------------------------
/finetune_visualglm.sh:
--------------------------------------------------------------------------------
#! /bin/bash
# Launch finetune_visualglm.py through deepspeed with LoRA enabled.
NUM_WORKERS=1
NUM_GPUS_PER_WORKER=8
MP_SIZE=1

# Quote path expansions so directories containing spaces do not break the script.
script_path=$(realpath "$0")
script_dir=$(dirname "$script_path")
main_dir=$(dirname "$script_dir")
MODEL_TYPE="visualglm-6b"
MODEL_ARGS="--max_source_length 64 \
    --max_target_length 256 \
    --lora_rank 10\
    --pre_seq_len 4"

# OPTIONS_SAT="SAT_HOME=$1" #"SAT_HOME=/raid/dm/sat_models"
OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2"
# The second assignment wins; swap the two lines to use the multi-node hostfile.
HOST_FILE_PATH="hostfile"
HOST_FILE_PATH="hostfile_single"

# Training and evaluation use the same few-shot file.
train_data="./fewshot-data/dataset.json"
eval_data="./fewshot-data/dataset.json"


gpt_options=" \
       --experiment-name finetune-$MODEL_TYPE \
       --model-parallel-size ${MP_SIZE} \
       --mode finetune \
       --train-iters 300 \
       --resume-dataloader \
       $MODEL_ARGS \
       --train-data ${train_data} \
       --valid-data ${eval_data} \
       --distributed-backend nccl \
       --lr-decay-style cosine \
       --warmup .02 \
       --checkpoint-activations \
       --save-interval 300 \
       --eval-interval 10000 \
       --save "./checkpoints" \
       --split 1 \
       --eval-iters 10 \
       --eval-batch-size 8 \
       --zero-stage 1 \
       --lr 0.0001 \
       --batch-size 20 \
       --skip-init \
       --fp16 \
       --use_lora
"



# OPTIONS_SAT is empty unless the assignment above is uncommented.
run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_visualglm.py ${gpt_options}"
echo ${run_cmd}
eval ${run_cmd}

set +x
58 |
--------------------------------------------------------------------------------
/img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/img.png
--------------------------------------------------------------------------------
/lora_mixin.py:
--------------------------------------------------------------------------------
1 | """
2 | In this mixin, I use a different implementation than sat/model/finetune/lora.py
3 | I just use a fake linear layer to replace any model with lora mixin.
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | from sat.model.base_model import BaseMixin
9 | import math
10 | from sat.helpers import print_all
11 | from sat.model.transformer import RowParallelLinear, ColumnParallelLinear
12 |
13 |
class HackLinear(nn.Linear):
    """``nn.Linear`` whose state-dict loading copies only the entries that are present.

    Missing ``weight``/``bias`` keys are skipped silently; nothing is recorded
    in ``missing_keys``, ``unexpected_keys`` or ``error_msgs``.
    """

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        for attr_name in ('weight', 'bias'):
            key = prefix + attr_name
            if key in state_dict:
                getattr(self, attr_name).data.copy_(state_dict[key])
21 |
22 |
class HackRowParallelLinear(RowParallelLinear):
    """``RowParallelLinear`` whose state-dict loading copies only the entries that are present."""

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        for attr_name in ('weight', 'bias'):
            key = prefix + attr_name
            if key in state_dict:
                getattr(self, attr_name).data.copy_(state_dict[key])
30 |
31 |
class HackColumnParallelLinear(ColumnParallelLinear):
    """``ColumnParallelLinear`` whose state-dict loading copies only the entries that are present."""

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        for attr_name in ('weight', 'bias'):
            key = prefix + attr_name
            if key in state_dict:
                getattr(self, attr_name).data.copy_(state_dict[key])
39 |
40 |
# The 4-bit helpers are optional: if bitsandbytes cannot be imported, a warning
# is printed and `HackLinearNF4` stays undefined.
try:
    from bitsandbytes.nn import LinearNF4


    def copy_nested_list(src, dst):
        """Copy ``src`` into ``dst`` element-wise, recursing into nested lists.

        Tensors are copied in place; any other element is rebound in ``dst``.
        """
        for idx, target in enumerate(dst):
            kind = type(target)
            if kind is torch.Tensor:
                target.copy_(src[idx])
            elif kind is list:
                copy_nested_list(src[idx], target)
            else:
                dst[idx] = src[idx]


    class HackLinearNF4(LinearNF4):
        """``LinearNF4`` that also saves and restores the weight's ``quant_state``."""

        def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                  error_msgs):
            weight_key = prefix + 'weight'
            bias_key = prefix + 'bias'
            if weight_key in state_dict:
                self.weight.data.copy_(state_dict[weight_key])
                # A uint8 weight means it is already quantized, so its
                # quantization state must be restored as well.
                if self.weight.data.dtype == torch.uint8:
                    copy_nested_list(state_dict[prefix + 'quant_state'], self.weight.quant_state)
            if bias_key in state_dict:
                self.bias.data.copy_(state_dict[bias_key])

        def _save_to_state_dict(self, destination, prefix, keep_vars):
            super()._save_to_state_dict(destination, prefix, keep_vars)
            destination[prefix + 'quant_state'] = self.weight.quant_state
except Exception as exception:
    print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING')
70 |
71 |
class HackParameterList(nn.ParameterList):
    """``nn.ParameterList`` whose state-dict loading copies only the entries that are present."""

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        for index, param in enumerate(self):
            key = prefix + str(index)
            if key in state_dict:
                param.data.copy_(state_dict[key])
78 |
79 |
# Maps each supported linear class to the "Hack" subclass that LoraLinear
# instantiates in its place, plus the extra keyword arguments passed to that
# subclass's constructor.
map_cls = {
    nn.Linear: (HackLinear, {}),
    ColumnParallelLinear: (HackColumnParallelLinear, {'gather_output': False}),
    RowParallelLinear: (HackRowParallelLinear, {'input_is_parallel': True})
}
85 |
86 |
class LoraLinear(nn.Module):
    """Linear layer plus a trainable low-rank (LoRA) update.

    The wrapped layer is stored as ``self.original``. The output dimension is
    split into ``partition`` equal slices and each slice gets its own pair of
    low-rank matrices: ``matrix_A[i]`` of shape ``(r, in_dim)`` and
    ``matrix_B[i]`` of shape ``(out_dim // partition, r)``. ``matrix_B`` is
    zero-initialized, so a freshly built layer adds nothing to the output of
    ``self.original``.
    """

    def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False,
                 num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False):
        """
        Args:
            original_cls: Class of the layer being replaced; must be a key of
                ``map_cls`` unless ``qlora`` is set.
            partition: Number of equal slices of the output dimension, each
                with its own low-rank pair (e.g. 3 for a fused query/key/value).
            in_dim: Input feature size.
            out_dim: Output feature size.
            r: LoRA rank.
            lora_alpha: Scaling numerator; the update is scaled by ``lora_alpha / r``.
            lora_dropout: Dropout probability applied to the LoRA input.
            head_first: Use the default (False) only when the fused output is
                laid out in plain query/key/value order. For a per-head layout
                such as ChatGLM's, set this to True and pass the head arguments.
            num_attention_heads: Required when ``head_first`` is True.
            hidden_size_per_attention_head: Required when ``head_first`` is True.
            qlora: Build the wrapped layer as a 4-bit ``HackLinearNF4``.
        """
        super().__init__()
        if lora_dropout and lora_dropout > 0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        self.r = r
        self.lora_alpha = lora_alpha
        self.scaling = self.lora_alpha / self.r
        if qlora:
            try:
                self.original = HackLinearNF4(in_dim, out_dim)
            except Exception as err:
                # Catch Exception rather than everything so KeyboardInterrupt
                # and SystemExit still propagate, and keep the original cause.
                raise Exception(
                    'Build 4bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still meet error after installation, try running `from bitsandbytes.nn import LinearNF4` with python and fix the error.') from err
        else:
            base_cls, kwargs = map_cls[original_cls]
            self.original = base_cls(in_dim, out_dim, **kwargs)
        self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(partition)])
        self.matrix_B = HackParameterList(
            [nn.Parameter(torch.empty((out_dim // partition, r))) for _ in range(partition)])
        for i in range(partition):
            nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
            nn.init.zeros_(self.matrix_B[i])
        self.head_first = head_first
        self.partition = partition
        if head_first:
            assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
            self.num_attention_heads = num_attention_heads
            self.hidden_size_per_attention_head = hidden_size_per_attention_head

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        # This is not a perfect version, because it doesn't handle errors and unexpected keys.
        if prefix + 'weight' in state_dict:
            # Checkpoint of a plain linear layer: load it into the wrapped layer.
            self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
                                                unexpected_keys, error_msgs)
        else:
            # Checkpoint saved from a LoraLinear.
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                                          error_msgs)

    def forward(self, x):
        """Return the wrapped layer's output plus the scaled low-rank update."""
        mixed_raw_layer = self.original(x)
        lora_outputs = []
        for i in range(self.partition):
            lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling)
        if self.head_first:
            # Split each slice into heads before concatenating so the update
            # lines up with the per-head layout of the wrapped layer's output.
            new_tensor_shape = lora_outputs[0].size()[:-1] + (
                self.num_attention_heads,
                self.hidden_size_per_attention_head,
            )
            for i in range(self.partition):
                lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape)
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size())
        else:
            mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1)

        return mixed_raw_layer
153 |
154 |
def replace_linear_with_lora(lin, partition, r, *args, **kw_args):
    """Build a ``LoraLinear`` with the same type and shape as ``lin``.

    Extra positional and keyword arguments are forwarded to ``LoraLinear``.
    Linear layers without a bias are not supported for now.
    """
    weight_shape = lin.weight.shape
    out_dim, in_dim = weight_shape
    layer_type = type(lin)
    del lin
    return LoraLinear(layer_type, partition, in_dim, out_dim, r, *args, **kw_args)
161 |
162 |
def merge_linear_lora(lin):
    """Fold a LoRA layer's low-rank update into a standalone linear layer.

    Returns a new layer whose weight is the wrapped layer's weight plus the
    scaled ``matrix_A``/``matrix_B`` products and whose bias is the wrapped
    layer's bias. The result is moved to CUDA when it is available.
    """
    base = lin.original
    if base.weight.data.dtype is not torch.uint8:
        weight = base.weight
        out_dim, in_dim = weight.shape
        new_lin = nn.Linear(in_dim, out_dim)
    else:
        # uint8 storage means the weight is quantized: dequantize it first.
        import bitsandbytes.functional as F
        weight = F.dequantize_fp4(base.weight.data, base.weight.quant_state).to(base.bias.data.dtype)
        out_dim, in_dim = weight.shape
        new_lin = HackLinearNF4(in_dim, out_dim)
    new_lin.bias.data = base.bias.data

    # One (in_dim, out_dim // partition) update per partition, computed in float32.
    deltas = [
        lin.matrix_A[idx].data.T.float() @ lin.matrix_B[idx].data.T.float() * lin.scaling
        for idx in range(lin.partition)
    ]
    if lin.head_first:
        # Interleave the partitions per head, matching LoraLinear.forward.
        ini_shape = deltas[0].shape
        per_head = [delta.view(ini_shape[0], lin.num_attention_heads, -1) for delta in deltas]
        merged_delta = torch.cat(per_head, -1).view(ini_shape[0], lin.partition * ini_shape[1])
    else:
        merged_delta = torch.cat(deltas, -1)

    new_lin.weight.data = weight + merged_delta.T.to(base.bias.data.dtype)
    if torch.cuda.is_available():
        return new_lin.cuda()
    return new_lin
186 |
187 |
class LoraMixin(BaseMixin):
    """Mixin that swaps attention linear layers for ``LoraLinear`` wrappers.

    ``reinit`` performs the replacement on the layers listed in
    ``layer_range``; ``merge_lora`` folds the learned updates back into plain
    linear layers.
    """

    def __init__(self,
                 layer_num,
                 r: int = 0,
                 lora_alpha: int = 1,
                 lora_dropout: float = 0.,
                 layer_range=None,
                 head_first=False,
                 num_attention_heads=None,
                 hidden_size_per_attention_head=None,
                 qlora=False,
                 cross_attention=True):
        """
        Args:
            layer_num: Number of transformer layers; used to build the default
                ``layer_range``.
            r: LoRA rank. Must be non-zero because the scaling is ``lora_alpha / r``.
            lora_alpha: Scaling numerator.
            lora_dropout: Dropout probability applied to the LoRA input.
            layer_range: Indices of the layers to adapt; all layers when None.
            head_first: Forwarded to the fused query/key/value ``LoraLinear``.
            num_attention_heads: Forwarded to the fused query/key/value ``LoraLinear``.
            hidden_size_per_attention_head: Forwarded to the fused
                query/key/value ``LoraLinear``.
            qlora: Also replace the transformer's linear layers with 4-bit layers.
            cross_attention: Adapt cross-attention projections of decoder layers too.
        """
        super().__init__()
        self.r = r
        self.lora_alpha = lora_alpha
        self.lora_dropout = lora_dropout

        if layer_range is None:
            layer_range = [i for i in range(layer_num)]
        self.layer_range = layer_range

        self.scaling = self.lora_alpha / self.r
        self.head_first = head_first
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = hidden_size_per_attention_head
        self.qlora = qlora
        self.cross_attention = cross_attention

    def reinit(self, parent_model):
        """Replace the attention projections of ``parent_model`` with LoRA layers."""
        for i in self.layer_range:
            print(f'replacing layer {i} attention with lora')
            parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora(
                parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout,
                qlora=self.qlora)
            parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora(
                parent_model.transformer.layers[i].attention.query_key_value, 3, self.r, self.lora_alpha,
                self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads,
                hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora)
            if self.cross_attention and parent_model.transformer.layers[i].is_decoder:
                print(f'replacing layer {i} cross attention with lora')
                parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora(
                    parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha,
                    self.lora_dropout, qlora=self.qlora)
                parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora(
                    parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha,
                    self.lora_dropout, qlora=self.qlora)
                parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora(
                    parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha,
                    self.lora_dropout, qlora=self.qlora)
        if self.qlora:
            print('replacing chatglm linear layer with 4bit')

            def replace_linear_with_nf4(model, name=None, cache=None):
                """Recursively swap plain/parallel linear modules for ``HackLinearNF4``.

                ``cache`` maps already-visited modules to their replacements so a
                module registered under several names is replaced consistently.
                """
                # Fresh dict per top-level call instead of a shared mutable default.
                if cache is None:
                    cache = {}
                if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear):
                    out_dim, in_dim = model.weight.shape
                    return HackLinearNF4(in_dim, out_dim)
                names = set()
                for child_name, child in model.named_children():
                    if child_name not in names:
                        if child in cache:
                            new_child = cache[child]
                        else:
                            new_child = replace_linear_with_nf4(child, name=child_name, cache=cache)
                            cache[child] = new_child
                        setattr(model, child_name, new_child)
                        names.add(child_name)
                # Keep sweeping until no unvisited name remains: names that
                # aliased an already-replaced module only show up afterwards.
                flag = True
                while flag:
                    flag = False
                    for child_name, child in model.named_children():
                        if child_name not in names:
                            setattr(model, child_name, cache[child])
                            names.add(child_name)
                            flag = True
                return model

            replace_linear_with_nf4(parent_model.transformer, None, {})

    def merge_lora(self):
        """Fold every LoRA layer created by ``reinit`` back into a plain linear layer."""
        for i in self.layer_range:
            print(f'merge layer {i} lora attention back to linear')
            self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense)
            self.transformer.layers[i].attention.query_key_value = merge_linear_lora(
                self.transformer.layers[i].attention.query_key_value)
            # Mirror reinit(): cross-attention layers are wrapped only when
            # self.cross_attention is set, so merge only in that case.
            if self.cross_attention and self.transformer.layers[i].is_decoder:
                print(f'merge layer {i} lora cross attention back to linear')
                self.transformer.layers[i].cross_attention.dense = merge_linear_lora(
                    self.transformer.layers[i].cross_attention.dense)
                self.transformer.layers[i].cross_attention.query = merge_linear_lora(
                    self.transformer.layers[i].cross_attention.query)
                self.transformer.layers[i].cross_attention.key_value = merge_linear_lora(
                    self.transformer.layers[i].cross_attention.key_value)
280 |
281 |
if __name__ == '__main__':
    # Manual smoke test: a plain-linear checkpoint and a LoRA checkpoint should
    # both load into a LoraLinear. Inspect out1/out2/out3 at the breakpoints.
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.child = nn.Linear(100, 200)

        def forward(self, x):
            return self.child(x)


    model = Model()
    torch.save(model.state_dict(), "linear.pt")
    x = torch.randn(2, 100)
    out1 = model(x)
    # LoraLinear takes (original_cls, partition, in_dim, out_dim, r); the
    # previous three-argument call raised TypeError.
    model.child = LoraLinear(nn.Linear, 1, 100, 200, 10)
    model.load_state_dict(torch.load("linear.pt"), strict=False)
    out2 = model(x)
    torch.save(model.state_dict(), "lora.pt")
    ckpt = torch.load("lora.pt")
    breakpoint()
    model.load_state_dict(ckpt, strict=False)
    out3 = model(x)
    breakpoint()
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .chat import chat
2 | from .infer_util import *
3 | from .blip2 import BlipImageEvalProcessor
4 |
--------------------------------------------------------------------------------
/model/blip2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from sat.model import ViTModel, BaseModel
5 | from sat.model import BaseMixin
6 | from sat import AutoModel
7 | from copy import deepcopy
8 | from torchvision import transforms
9 | from torchvision.transforms.functional import InterpolationMode
10 |
class LNFinalyMixin(BaseMixin):
    """Mixin whose ``final_forward`` hook runs its input through a LayerNorm."""

    def __init__(self, hidden_size):
        """Create the ``ln_vision`` LayerNorm over ``hidden_size`` features."""
        super().__init__()
        self.ln_vision = nn.LayerNorm(hidden_size)

    def final_forward(self, logits, **kw_args):
        """Return ``logits`` normalized by ``ln_vision``; extra keyword args are ignored."""
        normalized = self.ln_vision(logits)
        return normalized
18 |
class EVAViT(ViTModel):
    """ViTModel variant whose "cls" mixin is replaced by ``LNFinalyMixin``."""

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        """Build the base ViT, then swap the "cls" mixin for a LayerNorm one."""
        super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs)
        self.del_mixin("cls")
        final_norm = LNFinalyMixin(args.hidden_size)
        self.add_mixin("cls", final_norm)

    def forward(self, image):
        """Run the base forward on ``image`` with placeholder ids and a 1x1 mask of ones."""
        device = image.device
        # One zero token id per batch element.
        placeholder_ids = torch.zeros(image.size(0), 1, dtype=torch.long, device=device)
        ones_mask = torch.tensor([[1.]], dtype=image.dtype, device=device)
        return super().forward(
            input_ids=placeholder_ids,
            position_ids=None,
            attention_mask=ones_mask,
            image=image,
        )
30 |
class QFormer(BaseModel):
    """BaseModel using GELU, with position embeddings removed and logits passed through."""

    def __init__(self, args, transformer=None, parallel_output=True, **kwargs):
        """Build the base model with GELU activation and drop its position embeddings."""
        super().__init__(
            args,
            transformer=transformer,
            parallel_output=parallel_output,
            activation_func=nn.functional.gelu,
            **kwargs,
        )
        self.transformer.position_embeddings = None

    def final_forward(self, logits, **kw_args):
        """Return ``logits`` unchanged."""
        return logits

    def position_embedding_forward(self, position_ids, **kw_args):
        """Always return ``None`` (no position embedding is used)."""
        return None

    def forward(self, encoder_outputs):
        """Run the base forward with ids 0..31 per sample, attending to ``encoder_outputs``."""
        device = encoder_outputs.device
        dtype = encoder_outputs.dtype
        num_samples = encoder_outputs.size(0)

        # The same ids 0..31 are shared (expanded) across the batch.
        query_ids = torch.arange(32, dtype=torch.long, device=device)
        query_ids = query_ids.unsqueeze(0).expand(num_samples, -1)

        self_mask = torch.tensor([[1.]], dtype=dtype, device=device)
        cross_mask = torch.tensor([[1.]], dtype=dtype, device=device)
        return super().forward(
            input_ids=query_ids,
            position_ids=None,
            attention_mask=self_mask,
            encoder_outputs=encoder_outputs,
            cross_attention_mask=cross_mask,
        )
48 |
49 |
class BLIP2(torch.nn.Module):
    """Chains a ViT, a QFormer and a 768 -> 4096 linear projection."""

    def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs):
        """Use the supplied ``vit`` / ``qformer`` or build them from the argument dicts.

        Args:
            eva_args: keyword arguments for ``EVAViT.get_args`` (used when ``vit`` is None).
            qformer_args: keyword arguments for ``QFormer.get_args`` (used when ``qformer`` is None).
            vit: optional pre-built vision model.
            qformer: optional pre-built QFormer.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.vit = EVAViT(EVAViT.get_args(**eva_args)) if vit is None else vit
        self.qformer = QFormer(QFormer.get_args(**qformer_args)) if qformer is None else qformer

        # Place the projection on the device / dtype of the QFormer's first parameter.
        reference = next(self.qformer.parameters())
        self.glm_proj = nn.Linear(768, 4096).to(reference.device).to(reference.dtype)

    def forward(self, image, **kwargs):
        """Return ``glm_proj(qformer(vit(image)[0])[0])``; extra keyword args are ignored."""
        vision_features = self.vit(image)[0]
        query_features = self.qformer(vision_features)[0]
        return self.glm_proj(query_features)
68 |
class BlipImageBaseProcessor():
    """Holds the ``transforms.Normalize`` step shared by the image processors."""

    def __init__(self, mean=None, std=None):
        """Build ``self.normalize``, falling back to the built-in mean / std when None."""
        channel_mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
        channel_std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
        self.normalize = transforms.Normalize(channel_mean, channel_std)
77 |
class BlipImageEvalProcessor(BlipImageBaseProcessor):
    """Evaluation-time image pipeline: bicubic resize, tensor conversion, normalization."""

    def __init__(self, image_size=384, mean=None, std=None):
        """Compose the transform; the image is resized to ``(image_size, image_size)``."""
        super().__init__(mean=mean, std=std)

        resize = transforms.Resize(
            (image_size, image_size), interpolation=InterpolationMode.BICUBIC
        )
        steps = [resize, transforms.ToTensor(), self.normalize]
        self.transform = transforms.Compose(steps)

    def __call__(self, item):
        """Apply the composed transform to ``item`` and return the result."""
        return self.transform(item)
94 |
--------------------------------------------------------------------------------
/model/chat.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | '''
3 | @File : chat.py
4 | @Time : 2023/05/08 19:10:08
5 | @Author : Ming Ding
6 | @Contact : dm18@mails.tsinghua.edu.cn
7 | '''
8 |
9 | import os
10 | import sys
11 | import re
12 | from functools import partial
13 | from typing import Optional, Tuple, Union, List, Callable, Dict, Any
14 | import requests
15 | from PIL import Image
16 | from io import BytesIO
17 |
18 | import torch
19 | from sat.generation.autoregressive_sampling import filling_sequence, BaseStrategy
20 |
21 | from .blip2 import BlipImageEvalProcessor
22 |
def get_masks_and_position_ids_glm(seq, mask_position, context_length):
    '''Build GLM-style tokens, attention mask and 2D position ids for one sequence.

    Args:
        seq: 1-D tensor of token ids, [seq_len].
        mask_position: int, position id assigned to every slot after the context.
        context_length: int, number of leading context tokens.
    Returns:
        tokens: [1, seq_len], ``seq`` with a leading batch axis.
        attention_mask: [1, 1, seq_len, seq_len], lower-triangular with the
            first ``context_length`` columns fully set to 1.
        position_ids: long tensor [1, 2, seq_len]; row 0 counts 0..context_length-1
            then repeats ``mask_position``; row 1 is 0 over the context then
            counts 1, 2, ... over the remaining slots.
    '''
    seq_len = len(seq)
    tokens = seq.unsqueeze(0)
    device = tokens.device

    # Causal mask, except that every position may see the whole context.
    attention_mask = torch.tril(torch.ones((1, seq_len, seq_len), device=device))
    attention_mask[..., :context_length] = 1
    attention_mask = attention_mask.unsqueeze(1)

    # 2D position ids
    position_ids = torch.zeros(2, seq_len, device=device, dtype=torch.long)
    position_ids[0, :context_length] = torch.arange(0, context_length, device=device)
    position_ids[0, context_length:] = mask_position
    position_ids[1, context_length:] = torch.arange(1, seq_len - context_length + 1, device=device)

    return tokens, attention_mask, position_ids.unsqueeze(0)
49 |
def process_response(response):
    """Clean up a decoded model response.

    Strips surrounding whitespace, replaces the ``[[训练时间]]`` placeholder with
    ``2023年``, then rewrites punctuation that sits directly next to a CJK
    character (U+4E00..U+9FFF) using the ``punkts`` table.

    Args:
        response: str, the raw decoded text.
    Returns:
        str, the cleaned text.
    """
    response = response.strip()
    response = response.replace("[[训练时间]]", "2023年")
    # Each entry is [regex pattern, replacement]. The "?" pattern is a raw string:
    # "\?" in a normal string literal is an invalid escape sequence.
    # NOTE(review): as written every pair maps a character to itself; presumably the
    # replacements were meant to be the full-width forms -- confirm against upstream.
    punkts = [
        [",", ","],
        ["!", "!"],
        [":", ":"],
        [";", ";"],
        [r"\?", "?"],
    ]
    for pattern, replacement in punkts:
        # Punctuation directly after a CJK character ...
        response = re.sub(r"([\u4e00-\u9fff])%s" % pattern, r"\1%s" % replacement, response)
        # ... and punctuation directly before one.
        response = re.sub(r"%s([\u4e00-\u9fff])" % pattern, r"%s\1" % replacement, response)
    return response
64 |
def process_image(text, image=None):
    '''Extract and load the image referenced by the last ``<img>...</img>`` tag in text.

    Args:
        text: str, prompt text that may contain ``<img>path-or-url</img>``.
        image: Optional, an already loaded image. Must be None when the tag
            carries a path / url.
    Returns:
        text: the prompt with the path removed from it.
        image_position: int, index just past the last ``<img>``; 4 when the
            text has no ``<img>`` tag (``rfind`` returns -1).
        image: the processed image with a leading batch axis, or the ``image``
            argument unchanged when it is None or not a PIL image.
    '''
    image_position = text.rfind("<img>") + 5
    # extract path from <img></img> using re
    matches = re.findall(r"<img>(.*?)</img>", text)
    # No tag at all (empty list) and an empty tag both mean "no path".
    image_path = matches[-1] if matches and matches[-1] else None
    if image_path is not None:
        assert image is None, "image and image_path cannot be both not None."
        text = text.replace(image_path, "")
        image_path = image_path.strip()
        # url
        if image_path.startswith("http"):
            response = requests.get(image_path, timeout=10)
            image = Image.open(BytesIO(response.content))
        # local path
        else:
            image = Image.open(image_path)
    if image is not None and isinstance(image, Image.Image):
        processor = BlipImageEvalProcessor(224)
        image = processor(image.convert('RGB'))
        image = image.unsqueeze(0)
    return text, image_position, image
91 |
92 |
def chat(image_path, model, tokenizer,
         query: str, history: List[Tuple[str, str]] = None, image: Image = None,
         max_length: int = 1024, top_p=0.7, top_k=30, temperature=0.95, repetition_penalty=1.2,
         invalid_slices=None, english=False
         ):
    """Run one round of image-grounded chat and return the reply.

    Args:
        image_path: path / url placed inside the ``<img>`` tag of the prompt, or a
            falsy value when the image is supplied through ``image``.
        model: the model; must expose ``parameters()`` and ``image_length``.
        tokenizer: tokenizer used to encode the prompt and decode the output.
        query: the new user question.
        history: previous ``(query, response)`` pairs; None / empty for a fresh chat.
        image: optional image forwarded to ``process_image``.
        max_length: total length of the sequence buffer handed to ``filling_sequence``.
        top_p, top_k, temperature, repetition_penalty: forwarded to ``BaseStrategy``.
        invalid_slices: forwarded to ``BaseStrategy``; None means an empty list.
        english: use the ``Q:`` / ``A:`` prompt format instead of ``问:`` / ``答:``.
    Returns:
        (response, history, torch_image): the reply text, the history extended
        with ``(query, response)``, and the image returned by ``process_image``
        after being moved to the model's dtype / device (or None).
    """
    if not history:
        history = []
    # A fresh list per call instead of a shared mutable default argument.
    if invalid_slices is None:
        invalid_slices = []
    if image_path:
        prompt = "<img>{}</img>".format(image_path)
    else:
        prompt = "<img></img>"
    if english:
        for old_query, response in history:
            prompt += "Q:{}\nA:{}\n".format(old_query, response)
        prompt += "Q:{}\nA:".format(query)
    else:
        for old_query, response in history:
            prompt += "问:{}\n答:{}\n".format(old_query, response)
        prompt += "问:{}\n答:".format(query)
    # ---------------
    # tokenizer, this is an example of huggingface tokenizer.
    # input str, output['input_ids'] = tensor([[tokenized str, gmask, sop]])
    prompt, image_position, torch_image = process_image(prompt, image=image)
    first_param = next(model.parameters())
    device = first_param.device
    if torch_image is not None:
        torch_image = torch_image.to(first_param.dtype).to(device)
    if image_position < 5:  # no image
        inputs = tokenizer([prompt], return_tensors="pt").to(device)['input_ids'][0]
        pre_image = 0
    else:
        # Text up to and including "<img>", then image_length pad ids, then the rest.
        input0 = tokenizer.encode(prompt[:image_position], add_special_tokens=False)
        input1 = [tokenizer.pad_token_id] * model.image_length
        input2 = tokenizer.encode(prompt[image_position:], add_special_tokens=False)
        inputs = input0 + input1 + input2
        inputs = torch.tensor(tokenizer.build_inputs_with_special_tokens(inputs)).to(device)
        pre_image = len(input0)
    # ---------------
    # Next, we manually set the format to keep flexibility.
    mask_position = len(inputs) - 2
    context_length = len(inputs) - 1  # all before sop
    get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=context_length)
    # Pad the buffer with -1 up to max_length.
    seq = torch.cat(
        [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0
    )
    # ---------------
    strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id],
                            invalid_slices=invalid_slices, repetition_penalty=repetition_penalty)
    output = filling_sequence(
        model, seq,
        batch_size=1,
        get_masks_and_position_ids=get_func,
        strategy=strategy,
        pre_image=pre_image,
        image=torch_image,
    )[0]  # drop memory

    # ---------------
    # port from inference_glm.py, more general than chat mode
    # clip -1s and fill back generated things into seq
    if type(output) is not list:
        output_list = output.tolist()
    else:
        output_list = output
    for i, output in enumerate(output_list):
        if type(output) is not list:
            output = output.tolist()
        try:
            unfinished = output.index(-1)
        except ValueError:
            unfinished = len(output)
        if output[unfinished - 1] == tokenizer.eos_token_id:
            unfinished -= 1
        bog = output.index(tokenizer.bos_token_id)
        # Splice the tokens after bos into the mask position.
        output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog]
    # ---------------

    response = tokenizer.decode(output_list[0])
    sep = 'A:' if english else '答:'
    # Keep only the text after the last answer marker.
    response = process_response(response).split(sep)[-1].strip()
    history = history + [(query, response)]
    return response, history, torch_image
176 |
--------------------------------------------------------------------------------
/model/infer_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from io import BytesIO
4 | import base64
5 | import re
6 | import argparse
7 | import torch
8 | from transformers import AutoTokenizer
9 | from sat.model.mixins import CachedAutoregressiveMixin
10 | from sat.quantization.kernels import quantize
11 | import hashlib
12 | from .visualglm import VisualGLMModel
13 |
def get_infer_setting(gpu_device=0, quant=None):
    """Load the VisualGLM model and the ChatGLM tokenizer for inference.

    Args:
        gpu_device: value written to the ``CUDA_VISIBLE_DEVICES`` environment variable.
        quant: None for no quantization, or 4 / 8 to quantize ``model.transformer``.
    Returns:
        (model, tokenizer): the model in eval mode on CUDA, and the tokenizer.
    Raises:
        AssertionError: if ``quant`` is not None, 4 or 8.
    """
    # Reject an unsupported bit width before the model is loaded.
    if quant not in (None, 4, 8):
        raise AssertionError("quant must be one of None, 4, 8; got {!r}".format(quant))
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_device)
    args = argparse.Namespace(
        fp16=True,
        skip_init=True,
        # A model that will be quantized is loaded on CPU first.
        device='cuda' if quant is None else 'cpu',
    )
    model, args = VisualGLMModel.from_pretrained('visualglm-6b', args)
    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    if quant is not None:
        quantize(model.transformer, quant)
    model.eval()
    model = model.cuda()
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    return model, tokenizer
30 |
def is_chinese(text):
    """Return the first match of a run of characters in U+4E00..U+9FA5, or None.

    The result is a ``re.Match`` object (truthy) when ``text`` contains at
    least one such character, otherwise None.
    """
    return re.search(u'[\u4e00-\u9fa5]+', text)
34 |
def generate_input(input_text, input_image_prompt, history=None, input_para=None, image_is_encoded=True):
    """Bundle a query, its image, the history and generation parameters into a dict.

    Args:
        input_text: the user query.
        input_image_prompt: a base64-encoded image when ``image_is_encoded`` is
            True, otherwise an image object used as-is.
        history: previous conversation turns; None means a new empty list.
        input_para: generation parameters, stored under ``gen_kwargs``.
        image_is_encoded: whether ``input_image_prompt`` must be base64-decoded.
    Returns:
        dict with keys ``input_query``, ``input_image``, ``history``, ``gen_kwargs``.
    """
    # A fresh list per call: a shared default list would be handed out in the
    # returned dict and leak between calls.
    if history is None:
        history = []
    if not image_is_encoded:
        image = input_image_prompt
    else:
        decoded_image = base64.b64decode(input_image_prompt)
        image = Image.open(BytesIO(decoded_image))

    input_data = {'input_query': input_text, 'input_image': image, 'history': history, 'gen_kwargs': input_para}
    return input_data
44 |
45 |
def process_image(image_encoded):
    """Decode a base64 image and store it as ``./examples/<sha256>.png``.

    The file name is the SHA-256 hex digest of ``image.tobytes()``, so the same
    pixel data is written only once.

    Args:
        image_encoded: base64-encoded image data.
    Returns:
        str, absolute path of the stored PNG file.
    """
    decoded_image = base64.b64decode(image_encoded)
    image = Image.open(BytesIO(decoded_image))
    image_hash = hashlib.sha256(image.tobytes()).hexdigest()
    image_path = f'./examples/{image_hash}.png'
    if not os.path.isfile(image_path):
        # The relative directory may not exist in the current working directory.
        os.makedirs(os.path.dirname(image_path), exist_ok=True)
        image.save(image_path)
    return os.path.abspath(image_path)
--------------------------------------------------------------------------------
/model/visualglm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sat.model.official import ChatGLMModel
3 | from sat.model.base_model import BaseMixin
4 | from copy import deepcopy
5 | import json
6 | from .blip2 import BLIP2
7 |
8 | from sat.resources.urls import MODEL_URLS
# Register the download URL for the 'visualglm-6b' name in sat's model URL table.
MODEL_URLS['visualglm-6b'] = 'https://cloud.tsinghua.edu.cn/f/348b98dffcc940b6a09d/?dl=1'
10 |
class ImageMixin(BaseMixin):
    """Mixin that substitutes BLIP2 image embeddings for a span of input tokens."""

    def __init__(self, args):
        """Keep a private copy of ``args`` and build the BLIP2 image encoder."""
        super().__init__()
        self.args = deepcopy(args)
        self.model = BLIP2(args.eva_args, args.qformer_args)

    def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
        """Embed ``input_ids``, splicing image embeddings in at ``pre_image``.

        When ``kw_args["pre_image"]`` exceeds the sequence length, or no
        ``image`` keyword is given, plain word embeddings are returned.
        Otherwise the ``image_length`` ids starting at ``pre_image`` are dropped
        and the BLIP2 output takes their place.
        """
        image_start = kw_args["pre_image"]
        if image_start > input_ids.shape[1] or kw_args.get("image", None) is None:
            return self.transformer.word_embeddings(input_ids)

        image_emb = self.model(**kw_args)
        # Split into: ids before the image, the placeholder ids, ids after it.
        image_end = image_start + self.args.image_length
        pre_id, _placeholders, post_id = torch.tensor_split(input_ids, [image_start, image_end], dim=1)
        embed = self.transformer.word_embeddings
        pre_txt_emb = embed(pre_id)
        post_txt_emb = embed(post_id)
        return torch.cat([pre_txt_emb, image_emb, post_txt_emb], dim=1)
26 |
class VisualGLMModel(ChatGLMModel):
    """ChatGLMModel with an ``ImageMixin`` registered under the name "eva"."""

    def __init__(self, args, transformer=None, **kwargs):
        """Build the base model, record ``image_length`` and attach the image mixin."""
        super().__init__(args, transformer=transformer, **kwargs)
        self.image_length = args.image_length
        image_mixin = ImageMixin(args)
        self.add_mixin("eva", image_mixin)

    @classmethod
    def add_model_specific_args(cls, parser):
        """Add the VisualGLM options to ``parser``, then defer to the parent class."""
        group = parser.add_argument_group('VisualGLM', 'VisualGLM Configurations')
        group.add_argument('--image_length', type=int, default=32)
        # Both sub-model configurations are given as JSON strings.
        for flag in ('--eva_args', '--qformer_args'):
            group.add_argument(flag, type=json.loads, default={})
        return super().add_model_specific_args(parser)
40 |
41 |
--------------------------------------------------------------------------------
/predict.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModel
# Load the tokenizer and the half-precision model from the local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("../visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("../visualglm-6b", trust_remote_code=True).half().cuda()
image_path = "./fewshot-data/男性面相2.jpeg"
print('输入:',image_path)

# Independent single-turn questions about the same image: (question, printed label).
single_turn_questions = [
    ("从眼睛看这张照片可能的面相?", "从眼睛看这张照片可能的面相:"),
    ("从鼻子描述这张图片人物可能的面相?", "从鼻子描述这张图片人物可能的面相:"),
    ("从嘴巴描述这张图片人物可能的面相?", "从嘴巴描述这张图片人物可能的面相:"),
    ("从天庭描述这张图片人物可能的面相?", "从天庭描述这张图片人物可能的面相:"),
]
for question, label in single_turn_questions:
    response, history = model.chat(tokenizer, image_path, question, history=[])
    print(label+response)

# Follow-up question that continues from the history of the last call above.
response, history = model.chat(tokenizer, image_path, "这个人的精神状态怎么样", history=history)
print('这个人的精神状态怎么样',response)
16 | # response, history = model.chat(tokenizer, image_path, "喜欢这张图片可能是什么样的年龄性别职业", history=history)
17 | # print('什么样的人会喜欢:',response)
18 |
--------------------------------------------------------------------------------
/predict_lora.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 |
3 | import os
4 | import sys
5 | import torch
6 | import argparse
7 | from transformers import AutoTokenizer
8 | from sat.model.mixins import CachedAutoregressiveMixin
9 | from sat.quantization.kernels import quantize
10 |
11 | from model import VisualGLMModel, chat
12 | from finetune_visualglm import FineTuneVisualGLMModel
13 | from sat.model import AutoModel
14 |
15 |
def main():
    """Load a fine-tuned checkpoint and ask a fixed list of questions about one image.

    The first question passes the image path; later questions pass no path and
    reuse the image returned by the previous ``chat`` call. The conversation
    history is carried from one question to the next.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
    parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
    parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling')
    parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
    parser.add_argument("--english", action='store_true', help='only output English')
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
    parser.add_argument("--from_pretrained", type=str, default="checkpoints/finetune-visualglm-6b-06-06-00-05/",
                        help='pretrained ckpt')
    args = parser.parse_args()

    # load model
    model, model_args = AutoModel.from_pretrained(
        args.from_pretrained,
        args=argparse.Namespace(
            fp16=True,
            skip_init=True,
            use_gpu_initialization=True,
            device='cuda'
        ))
    model = model.to(torch.float32)
    model = model.eval()

    if args.quant:
        quantize(model.transformer, args.quant)

    model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
    tokenizer = AutoTokenizer.from_pretrained("../chatglm-6b", trust_remote_code=True)

    queries = [
        '从鼻子描述这张图片人物可能的面相?',
        '从眼睛看这张照片可能的面相?',
        '从嘴巴描述这张图片人物可能的面相?',
        '从天庭描述这张图片人物可能的面相?',
    ]
    sep = 'A:' if args.english else '答:'
    with torch.no_grad():
        history = None
        cache_image = None
        image_path = './fewshot-data/虎眼.jpeg'
        for query in queries:
            response, history, cache_image = chat(
                image_path,
                model,
                tokenizer,
                query,
                history=history,
                image=cache_image,
                max_length=args.max_length,
                top_p=args.top_p,
                temperature=args.temperature,
                top_k=args.top_k,
                english=args.english,
                invalid_slices=[slice(63823, 130000)] if args.english else []
            )
            print(query + ': ' + response.split(sep)[-1].strip())
            # Only the first round passes the path; later rounds reuse cache_image.
            image_path = None


if __name__ == "__main__":
    main()
125 |
126 | # python predict_lora.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | SwissArmyTransformer>=0.3.6
2 | torch>1.10.0
3 | torchvision
4 | transformers>=4.27.1
5 | mdtex2html
6 | gradio
--------------------------------------------------------------------------------
/requirements_wo_ds.txt:
--------------------------------------------------------------------------------
1 | torch>1.10.0
2 | torchvision
3 | transformers>=4.27.1
4 | mdtex2html
5 | gradio
6 | sentencepiece
7 | tensorboardX
8 | datasets
9 | cpm_kernels
10 | einops
--------------------------------------------------------------------------------
/web_demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import gradio as gr
4 | from PIL import Image
5 | import os
6 | import json
7 | from model import is_chinese, get_infer_setting, generate_input, chat
8 |
def generate_text_with_image(input_text, image, history=None, request_data=None, is_zh=True):
    """Ask the global model one question about ``image`` and return its answer.

    Args:
        input_text: the user question.
        image: the image handed on to ``chat`` (not base64 encoded).
        history: previous ``(query, response)`` pairs; None means no history.
        request_data: overrides for the default generation parameters; None
            means no overrides.
        is_zh: truthy to use the Chinese prompt format, falsy for English.
    Returns:
        The answer text returned by ``chat``.
    """
    # None defaults instead of shared mutable default arguments.
    if history is None:
        history = []
    if request_data is None:
        request_data = {}
    input_para = {
        "max_length": 2048,
        "min_length": 50,  # kept for compatibility; chat() takes no such argument
        "temperature": 0.8,
        "top_p": 0.4,
        "top_k": 100,
        "repetition_penalty": 1.2
    }
    input_para.update(request_data)

    input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False)
    input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
    answer, history, _ = chat(
        None, model, tokenizer, input_text, history=history, image=input_image,
        max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'],
        top_k=gen_kwargs['top_k'], temperature=gen_kwargs['temperature'],
        # Previously configured here but never passed on to chat().
        repetition_penalty=gen_kwargs['repetition_penalty'],
        english=not is_zh)
    return answer
26 |
27 |
def request_model(input_text, temperature, top_p, image_prompt, result_previous):
    """Gradio callback: answer ``input_text`` about the uploaded image.

    Args:
        input_text: the question typed by the user.
        temperature, top_p: sampling parameters forwarded to generation.
        image_prompt: path of the uploaded image, or None when nothing is uploaded.
        result_previous: the chatbot history as a sequence of (query, answer) pairs.
    Returns:
        (textbox_value, history): the new textbox content and the updated
        history. The textbox keeps ``input_text`` only when the image is missing.
    """
    # Drop history rows where either side is empty.
    result_text = [
        (entry[0], entry[1])
        for entry in result_previous
        if entry[0] != "" and entry[1] != ""
    ]
    print(f"history {result_text}")

    is_zh = is_chinese(input_text)

    def localized(zh_message, en_message):
        # Choose the message matching the language of the question.
        return zh_message if is_zh else en_message

    if image_prompt is None:
        notice = localized('图片为空!请上传图片并重试。', 'Image empty! Please upload a image and retry.')
        result_text.append((input_text, notice))
        return input_text, result_text
    if input_text == "":
        result_text.append((input_text, 'Text empty! Please enter text and retry.'))
        return "", result_text

    request_para = {"temperature": temperature, "top_p": top_p}
    image = Image.open(image_prompt)
    try:
        answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh)
    except Exception as e:
        # Any generation failure is shown to the user as a timeout message.
        print(f"error: {e}")
        notice = localized('超时!请稍等几分钟再重试。', 'Timeout! Please wait a few minutes and retry.')
        result_text.append((input_text, notice))
        return "", result_text

    result_text.append((input_text, answer))
    print(result_text)
    return "", result_text
61 |
62 |
# Markdown heading shown at the top of the demo page.
DESCRIPTION = '''# VisualGLM'''

# Usage hints shown in the UI: English text, and a Chinese version of the same hints.
MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.'
MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。'

# Attribution note rendered as Markdown in the demo.
NOTES = 'This app is adapted from https://github.com/THUDM/VisualGLM-6B. It would be recommended to check out the repo if you want to see the detail of our model and training process.'
69 |
70 |
def clear_fn(value):
    """Return reset values: empty text, the greeting-only history, and no image.

    ``value`` is ignored.
    """
    greeting_history = [("", "Hi, What do you want to know about this image?")]
    return "", greeting_history, None
73 |
def clear_fn2(value):
    """Return the greeting-only chat history; ``value`` is ignored."""
    greeting = ("", "Hi, What do you want to know about this image?")
    return [greeting]
76 |
77 |
def main(args):
    """Load the model, build the Gradio demo UI and launch it.

    Args:
        args: parsed command-line namespace; ``args.quant`` is passed to
            ``get_infer_setting`` and ``args.share`` to ``demo.launch``.
    """
    gr.close_all()
    # The callbacks above read these module-level names.
    global model, tokenizer
    model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant)

    # NOTE(review): the block nesting below is reconstructed -- confirm against the original file.
    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            # Left column: text input, buttons, image upload and sampling sliders.
            with gr.Column(scale=4.5):
                with gr.Group():
                    input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
                    with gr.Row():
                        run_button = gr.Button('Generate')
                        clear_button = gr.Button('Clear')

                    image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
                with gr.Row():
                    temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
                    top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
                with gr.Group():
                    with gr.Row():
                        maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1)
            # Right column: the conversation history, seeded with a greeting row.
            with gr.Column(scale=5.5):
                result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550)

        gr.Markdown(NOTES)

        print(gr.__version__)
        # Both the button and pressing ENTER in the textbox trigger generation.
        run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                         outputs=[input_text, result_text])
        input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text],
                          outputs=[input_text, result_text])
        # The clear callbacks ignore their argument; clear_button is passed only as the input.
        clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
        # Uploading or removing an image resets the conversation history.
        image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
        image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])

        print(gr.__version__)


    demo.queue(concurrency_count=10)
    demo.launch(share=args.share)
119 |
120 |
# Script entry point: parse the command-line options and start the demo.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    # Optional quantization bit width (8 or 4); None leaves the model unquantized.
    parser.add_argument("--quant", choices=[8, 4], type=int, default=None)
    # Forwarded to demo.launch(share=...).
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()

    main(args)
--------------------------------------------------------------------------------
/web_demo_hf.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoTokenizer
2 | import gradio as gr
3 | import mdtex2html
4 |
# Load the tokenizer and model at import time; the model is cast to half
# precision, moved to CUDA and put in eval mode.
tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda()
model = model.eval()
8 |
9 | """Override Chatbot.postprocess"""
10 |
11 |
def postprocess(self, y):
    """Convert every (message, response) pair in ``y`` with ``mdtex2html``.

    ``y`` is updated in place and returned; None entries stay None. When ``y``
    itself is None an empty list is returned.
    """
    if y is None:
        return []

    def render(text):
        # None marks a missing side of the pair and is passed through.
        return None if text is None else mdtex2html.convert(text)

    for index, (message, response) in enumerate(y):
        y[index] = (render(message), render(response))
    return y
21 |
22 |
23 | gr.Chatbot.postprocess = postprocess
24 |
25 |
26 | def parse_text(text):
27 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
28 | lines = text.split("\n")
29 | lines = [line for line in lines if line != ""]
30 | count = 0
31 | for i, line in enumerate(lines):
32 | if "```" in line:
33 | count += 1
34 | items = line.split('`')
35 | if count % 2 == 1:
36 | lines[i] = f'
'
37 | else:
38 | lines[i] = f'
'
39 | else:
40 | if i > 0:
41 | if count % 2 == 1:
42 | line = line.replace("`", "\`")
43 | line = line.replace("<", "<")
44 | line = line.replace(">", ">")
45 | line = line.replace(" ", " ")
46 | line = line.replace("*", "*")
47 | line = line.replace("_", "_")
48 | line = line.replace("-", "-")
49 | line = line.replace(".", ".")
50 | line = line.replace("!", "!")
51 | line = line.replace("(", "(")
52 | line = line.replace(")", ")")
53 | line = line.replace("$", "$")
54 | lines[i] = "