├── LICENSE.txt ├── MODEL_LICENSE.txt ├── README.md ├── __init__.py ├── api.py ├── api_hf.py ├── checkpoints ├── 300 │ └── 需要下载pt文件.txt ├── latest └── model_config.json ├── cli_demo.py ├── cli_demo_hf.py ├── examples ├── 1.jpeg ├── 2.jpeg ├── 3.jpeg ├── chat_example1.png ├── chat_example2.png ├── chat_example3.png ├── example_inputs.jsonl ├── thu.png └── web_demo.png ├── fewshot-data ├── 2p.png ├── dataset.json ├── ghost.jpg ├── justice.png ├── katong.png ├── kobe.png ├── man.jpg ├── meme.png ├── music.png ├── panda.jpg ├── passport.png ├── pattern.png ├── pig.png ├── push.png ├── rou.png ├── rub.png ├── tianye.png ├── titan.png ├── tower.png ├── traf.png └── woman.png ├── finetune.py ├── finetune_visualglm.py ├── finetune_visualglm.sh ├── img.png ├── lora_mixin.py ├── model ├── __init__.py ├── blip2.py ├── chat.py ├── infer_util.py └── visualglm.py ├── predict.py ├── predict_lora.py ├── requirements.txt ├── requirements_wo_ds.txt ├── web_demo.py └── web_demo_hf.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright Zhengxiao Du 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /MODEL_LICENSE.txt: -------------------------------------------------------------------------------- 1 | The GLM-130B License 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means the VisualGLM-6B Model Team that distributes its Software. 6 | 7 | “Software” means the VisualGLM-6B model parameters made available under this license. 8 | 9 | 2. 
License Grant 10 | 11 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software solely for your non-commercial research purposes. 12 | 13 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 14 | 15 | 3. Restriction 16 | 17 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any commercial, military, or illegal purposes. 18 | 19 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. 20 | 21 | 4. Disclaimer 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 24 | 25 | 5. Limitation of Liability 26 | 27 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 28 | 29 | 6. Dispute Resolution 30 | 31 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. 32 | 33 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Visualglm-image-to-text
2 | ## Introduction
3 | 
4 | A simple "face reading" (physiognomy) prediction demo, built by LoRA fine-tuning Tsinghua's VisualGLM language model.
5 | 
6 | Adds some files required for training, fixing the problem of missing training files such as `latest`.
7 | 
8 | Adds some training and prediction code.
9 | 
10 | ## Examples
11 | Face-reading prediction:
12 | 
13 | image
14 | 
15 | 
16 | image
17 | 
18 | 
19 | 
20 | ## Usage
21 | 
22 | ### Inference
23 | 
24 | Install the dependencies with pip:
25 | ```
26 | pip install -r requirements.txt
27 | ```
28 | Prefer the standard PyPI index so that a recent `sat` package is downloaded; mirrors such as TUNA may lag behind: `pip install -i https://pypi.org/simple -r requirements.txt`.
29 | By default this also installs the `deepspeed` library (used by `sat` for training). It is not required for inference, and installing it fails in some Windows environments. To skip `deepspeed`, change the commands to:
30 | ```
31 | pip install -r requirements_wo_ds.txt
32 | pip install --no-deps "SwissArmyTransformer>=0.3.6"
33 | ```
34 | 
35 | ### Fine-tuning
36 | 1. Download the mp_rank_00_model_states.pt file into the checkpoints/300 directory:
37 | 
38 | wget https://huggingface.co/wangrongsheng/XrayGLM-300/resolve/main/300/mp_rank_00_model_states.pt
39 | 
40 | 2. Download the visualglm-6b weights from:
41 | 
42 | [https://huggingface.co/THUDM/visualglm-6b/tree/main]
43 | 
44 | Then run:
45 | bash finetune_visualglm.sh
46 | 
47 | ### Prediction
48 | With the original model: python predict.py
49 | 
50 | With the fine-tuned model: python predict_lora.py
51 | 
52 | ## Note: adjust the paths in these files to your own directories.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/__init__.py
--------------------------------------------------------------------------------
/api.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import uvicorn
4 | from fastapi import FastAPI, Request
5 | from model import is_chinese, get_infer_setting, generate_input, chat
6 | import datetime
7 | 
8 | gpu_number = 0
9 | model, tokenizer = get_infer_setting(gpu_device=gpu_number)
10 | 
11 | app = FastAPI()
12 | @app.post('/')
13 | async def visual_glm(request: Request):
14 |     json_post_raw = await request.json()
15 |     print("Start to process request")
16 | 
17 |     json_post = json.dumps(json_post_raw)
18 |     request_data = json.loads(json_post)
19 |     input_text, input_image_encoded, history = request_data['text'], request_data['image'], request_data['history']
20 |     input_para = {
21 |         "max_length": 2048,
22 |         "min_length": 50,
23 |         "temperature": 0.8,
24 |         "top_p": 0.4,
25 |         "top_k": 100,
26 |         "repetition_penalty": 1.2
27 |     }
28 |     input_para.update(request_data)
29 | 
30 |     is_zh = is_chinese(input_text)
31 |     input_data = generate_input(input_text, input_image_encoded, history, input_para)
32 |     input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs']
33 |     answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \
34 |                               max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \
35 |                               top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh)
36 | 
37 |     now = datetime.datetime.now()
38 |     time = now.strftime("%Y-%m-%d %H:%M:%S")
39 |     response = {
40 |         "result": answer,
41 |         "history": history,
42 |         "status": 200,
43 |         "time": time
44 |     }
45 |     return response
46 | 
47 | 
48 | if __name__ == '__main__':
49 |     uvicorn.run(app, host='0.0.0.0', port=8080, workers=1)
--------------------------------------------------------------------------------
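The JSON contract of the `api.py` endpoint above is easiest to see from a client's point of view. The snippet below is an illustrative sketch, not part of the repository: it assumes the server started by `python api.py` is reachable at http://127.0.0.1:8080 and uses `examples/1.jpeg` as the input image; the keys `text`, `image` (base64-encoded) and `history` mirror the fields read inside `visual_glm`, and any extra keys (e.g. `temperature`) are merged into the default generation parameters.

```
# Illustrative client for api.py (not part of the repository).
import base64
import requests

with open("examples/1.jpeg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "text": "描述这张图片。",  # prompt; api.py switches to English decoding when the text is not Chinese
    "image": image_b64,        # base64-encoded image, decoded server-side by generate_input()
    "history": [],             # previous (query, answer) turns for multi-turn chat
    "temperature": 0.8,        # optional overrides, merged into input_para by api.py
    "top_p": 0.4,
}

resp = requests.post("http://127.0.0.1:8080/", json=payload, timeout=300)
result = resp.json()
print(result["result"])    # generated answer
print(result["history"])   # updated history; pass it back in the next request
```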
/api_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from transformers import AutoTokenizer, AutoModel 4 | import uvicorn 5 | from fastapi import FastAPI, Request 6 | import datetime 7 | from model import process_image 8 | 9 | tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True) 10 | model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda() 11 | 12 | 13 | app = FastAPI() 14 | @app.post('/') 15 | async def visual_glm(request: Request): 16 | json_post_raw = await request.json() 17 | print("Start to process request") 18 | 19 | json_post = json.dumps(json_post_raw) 20 | request_data = json.loads(json_post) 21 | 22 | history = request_data.get("history") 23 | image_encoded = request_data.get("image") 24 | query = request_data.get("text") 25 | image_path = process_image(image_encoded) 26 | 27 | result = model.stream_chat(tokenizer, image_path, query, history=history) 28 | last_result = None 29 | for value in result: 30 | last_result = value 31 | answer = last_result[0] 32 | 33 | if os.path.isfile(image_path): 34 | os.remove(image_path) 35 | now = datetime.datetime.now() 36 | time = now.strftime("%Y-%m-%d %H:%M:%S") 37 | response = { 38 | "result": answer, 39 | "history": history, 40 | "status": 200, 41 | "time": time 42 | } 43 | return response 44 | 45 | 46 | if __name__ == "__main__": 47 | uvicorn.run(app, host='0.0.0.0', port=8080, workers=1) -------------------------------------------------------------------------------- /checkpoints/300/需要下载pt文件.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/checkpoints/300/需要下载pt文件.txt -------------------------------------------------------------------------------- /checkpoints/latest: -------------------------------------------------------------------------------- 1 | 300 -------------------------------------------------------------------------------- /checkpoints/model_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_class": "FineTuneVisualGLMModel", 3 | "tokenizer_type": "THUDM/chatglm-6b", 4 | "num_layers": 28, 5 | "hidden_size": 4096, 6 | "num_attention_heads": 32, 7 | "vocab_size": 130528, 8 | "layernorm_order": "post", 9 | "model_parallel_size": 1, 10 | "max_sequence_length": 2048, 11 | "pre_seq_len": 128, 12 | "lora_rank": 10, 13 | "use_ptuning": false, 14 | "use_lora": true, 15 | "image_length": 32, 16 | "eva_args": { 17 | "num_layers": 39, 18 | "hidden_size": 1408, 19 | "num_attention_heads": 16, 20 | "vocab_size": 1, 21 | "layernorm_order": "pre", 22 | "model_parallel_size": 1, 23 | "max_sequence_length": 257, 24 | "inner_hidden_size": 6144, 25 | "use_final_layernorm": false, 26 | "layernorm_epsilon": 1e-06, 27 | "image_size": [ 28 | 224, 29 | 224 30 | ], 31 | "pre_len": 1, 32 | "post_len": 0, 33 | "in_channels": 3, 34 | "num_classes": 0, 35 | "patch_size": 14 36 | }, 37 | "qformer_args": { 38 | "num_layers": 12, 39 | "hidden_size": 768, 40 | "num_attention_heads": 12, 41 | "vocab_size": 32, 42 | "layernorm_order": "post", 43 | "model_parallel_size": 1, 44 | "max_sequence_length": 0, 45 | "is_decoder": [ 46 | true, 47 | false, 48 | true, 49 | false, 50 | true, 51 | false, 52 | true, 53 | false, 54 | true, 55 | false, 56 | true, 57 | false 58 | ], 59 | "cross_attn_hidden_size": 1408, 60 | 
"layernorm_epsilon": 1e-12 61 | }, 62 | "bos_token_id": 130004, 63 | "mask_token_id": 130000, 64 | "gmask_token_id": 130001, 65 | "image_size": [ 66 | 224, 67 | 224 68 | ], 69 | "pre_len": 1, 70 | "post_len": 0, 71 | "in_channels": 3, 72 | "patch_size": 14 73 | } -------------------------------------------------------------------------------- /cli_demo.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import torch 6 | import argparse 7 | from transformers import AutoTokenizer 8 | from sat.model.mixins import CachedAutoregressiveMixin 9 | from sat.quantization.kernels import quantize 10 | 11 | from model import VisualGLMModel, chat 12 | from finetune_visualglm import FineTuneVisualGLMModel 13 | from sat.model import AutoModel 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence') 19 | parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling') 20 | parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling') 21 | parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling') 22 | parser.add_argument("--english", action='store_true', help='only output English') 23 | parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits') 24 | parser.add_argument("--from_pretrained", type=str, default="visualglm-6b", help='pretrained ckpt') 25 | parser.add_argument("--prompt_zh", type=str, default="描述这张图片。", help='Chinese prompt for the first round') 26 | parser.add_argument("--prompt_en", type=str, default="Describe the image.", help='English prompt for the first round') 27 | args = parser.parse_args() 28 | 29 | # load model 30 | model, model_args = AutoModel.from_pretrained( 31 | args.from_pretrained, 32 | args=argparse.Namespace( 33 | fp16=True, 34 | skip_init=True, 35 | use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False, 36 | device='cuda' if (torch.cuda.is_available() and args.quant is None) else 'cpu', 37 | )) 38 | model = model.eval() 39 | 40 | if args.quant: 41 | quantize(model.transformer, args.quant) 42 | 43 | if torch.cuda.is_available(): 44 | model = model.cuda() 45 | 46 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 47 | 48 | tokenizer = AutoTokenizer.from_pretrained("../chatglm-6b", trust_remote_code=True) 49 | if not args.english: 50 | print('欢迎使用 VisualGLM-6B 模型,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序') 51 | else: 52 | print('Welcome to VisualGLM-6B model. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. 
Type "clear" to start over, or "stop" to end the program.') 53 | with torch.no_grad(): 54 | while True: 55 | history = None 56 | cache_image = None 57 | if not args.english: 58 | image_path = input("请输入图像路径或URL(回车进入纯文本对话): ") 59 | else: 60 | image_path = input("Please enter the image path or URL (press Enter for plain text conversation): ") 61 | 62 | if image_path == 'stop': 63 | break 64 | if len(image_path) > 0: 65 | query = args.prompt_en if args.english else args.prompt_zh 66 | else: 67 | if not args.english: 68 | query = input("用户:") 69 | else: 70 | query = input("User: ") 71 | while True: 72 | if query == "clear": 73 | break 74 | if query == "stop": 75 | sys.exit(0) 76 | try: 77 | response, history, cache_image = chat( 78 | image_path, 79 | model, 80 | tokenizer, 81 | query, 82 | history=history, 83 | image=cache_image, 84 | max_length=args.max_length, 85 | top_p=args.top_p, 86 | temperature=args.temperature, 87 | top_k=args.top_k, 88 | english=args.english, 89 | invalid_slices=[slice(63823, 130000)] if args.english else [] 90 | ) 91 | except Exception as e: 92 | print(e) 93 | break 94 | sep = 'A:' if args.english else '答:' 95 | print("VisualGLM-6B:"+response.split(sep)[-1].strip()) 96 | image_path = None 97 | if not args.english: 98 | query = input("用户:") 99 | else: 100 | query = input("User: ") 101 | 102 | 103 | if __name__ == "__main__": 104 | main() -------------------------------------------------------------------------------- /cli_demo_hf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import signal 4 | from transformers import AutoTokenizer, AutoModel 5 | 6 | tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True) 7 | model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda() 8 | model = model.eval() 9 | 10 | os_name = platform.system() 11 | clear_command = 'cls' if os_name == 'Windows' else 'clear' 12 | stop_stream = False 13 | 14 | 15 | def build_prompt(history, prefix): 16 | prompt = prefix 17 | for query, response in history: 18 | prompt += f"\n\n用户:{query}" 19 | prompt += f"\n\nVisualGLM-6B:{response}" 20 | return prompt 21 | 22 | 23 | def signal_handler(signal, frame): 24 | global stop_stream 25 | stop_stream = True 26 | 27 | 28 | def main(): 29 | global stop_stream 30 | while True: 31 | history = [] 32 | prefix = "欢迎使用 VisualGLM-6B 模型,输入图片路径和内容即可进行对话,clear 清空对话历史,stop 终止程序" 33 | print(prefix) 34 | image_path = input("\n请输入图片路径:") 35 | if image_path == "stop": 36 | break 37 | prefix = prefix + "\n" + image_path 38 | query = "描述这张图片。" 39 | while True: 40 | count = 0 41 | for response, history in model.stream_chat(tokenizer, image_path, query, history=history): 42 | if stop_stream: 43 | stop_stream = False 44 | break 45 | else: 46 | count += 1 47 | if count % 8 == 0: 48 | os.system(clear_command) 49 | print(build_prompt(history, prefix), flush=True) 50 | signal.signal(signal.SIGINT, signal_handler) 51 | os.system(clear_command) 52 | print(build_prompt(history, prefix), flush=True) 53 | query = input("\n用户:") 54 | if query.strip() == "clear": 55 | break 56 | if query.strip() == "stop": 57 | stop_stream = True 58 | exit(0) 59 | # if query.strip() == "clear": 60 | # history = [] 61 | # os.system(clear_command) 62 | # print(prefix) 63 | # continue 64 | 65 | 66 | if __name__ == "__main__": 67 | main() -------------------------------------------------------------------------------- /examples/1.jpeg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/1.jpeg -------------------------------------------------------------------------------- /examples/2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/2.jpeg -------------------------------------------------------------------------------- /examples/3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/3.jpeg -------------------------------------------------------------------------------- /examples/chat_example1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/chat_example1.png -------------------------------------------------------------------------------- /examples/chat_example2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/chat_example2.png -------------------------------------------------------------------------------- /examples/chat_example3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/chat_example3.png -------------------------------------------------------------------------------- /examples/example_inputs.jsonl: -------------------------------------------------------------------------------- 1 | {"id":1, "text": "描述一下这个场景", "image": "examples/1.jpeg"} 2 | {"id":2, "text": "这是什么东西", "image": "examples/2.jpeg"} 3 | {"id":3, "text": "这张图片描述了什么", "image": "examples/3.jpeg"} -------------------------------------------------------------------------------- /examples/thu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/thu.png -------------------------------------------------------------------------------- /examples/web_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/examples/web_demo.png -------------------------------------------------------------------------------- /fewshot-data/2p.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/2p.png -------------------------------------------------------------------------------- /fewshot-data/dataset.json: -------------------------------------------------------------------------------- 1 | [ 2 | {"img": "fewshot-data/2p.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是蒙蒙细雨。"}, 3 | {"img": "fewshot-data/pig.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是是虚化的。"}, 4 | 
{"img": "fewshot-data/meme.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是蓝色的木质地板。"}, 5 | {"img": "fewshot-data/passport.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是棕黄色木质桌子。"}, 6 | {"img": "fewshot-data/tower.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是黄昏的天空、云彩和繁华的城市高楼。"}, 7 | {"img": "fewshot-data/rub.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是太阳、大树、蓝天白云。"}, 8 | {"img": "fewshot-data/push.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是蓝天和沙漠。"}, 9 | {"img": "fewshot-data/traf.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是城市街道。"}, 10 | {"img": "fewshot-data/music.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是一个音乐混音器。"}, 11 | {"img": "fewshot-data/pattern.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是小区的楼房和街道。"}, 12 | {"img": "fewshot-data/rou.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是大理石桌子和一个盘子。"}, 13 | {"img": "fewshot-data/katong.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是绿色的草地。"}, 14 | {"img": "fewshot-data/man.jpg", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是城市的街道和高楼。"}, 15 | {"img": "fewshot-data/kobe.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是虚化的观众席。"}, 16 | {"img": "fewshot-data/panda.jpg", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是纯白的。"}, 17 | {"img": "fewshot-data/titan.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是一座雕像。"}, 18 | {"img": "fewshot-data/woman.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是纯蓝的。"}, 19 | {"img": "fewshot-data/ghost.jpg", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是一个房间。"}, 20 | {"img": "fewshot-data/justice.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是天空和阳光。"}, 21 | {"img": "fewshot-data/tianye.png", "prompt": "这张图片的背景里有什么内容?", "label": "这张图片的背景是金黄的田野。"} 22 | ] -------------------------------------------------------------------------------- /fewshot-data/ghost.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/ghost.jpg -------------------------------------------------------------------------------- /fewshot-data/justice.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/justice.png -------------------------------------------------------------------------------- /fewshot-data/katong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/katong.png -------------------------------------------------------------------------------- /fewshot-data/kobe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/kobe.png -------------------------------------------------------------------------------- /fewshot-data/man.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/man.jpg -------------------------------------------------------------------------------- /fewshot-data/meme.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/meme.png -------------------------------------------------------------------------------- /fewshot-data/music.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/music.png -------------------------------------------------------------------------------- /fewshot-data/panda.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/panda.jpg -------------------------------------------------------------------------------- /fewshot-data/passport.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/passport.png -------------------------------------------------------------------------------- /fewshot-data/pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/pattern.png -------------------------------------------------------------------------------- /fewshot-data/pig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/pig.png -------------------------------------------------------------------------------- /fewshot-data/push.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/push.png -------------------------------------------------------------------------------- /fewshot-data/rou.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/rou.png -------------------------------------------------------------------------------- /fewshot-data/rub.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/rub.png -------------------------------------------------------------------------------- /fewshot-data/tianye.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/tianye.png -------------------------------------------------------------------------------- /fewshot-data/titan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/titan.png -------------------------------------------------------------------------------- /fewshot-data/tower.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/tower.png
--------------------------------------------------------------------------------
/fewshot-data/traf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/traf.png
--------------------------------------------------------------------------------
/fewshot-data/woman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/fewshot-data/woman.png
--------------------------------------------------------------------------------
/finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer
3 | from model import chat
4 | from model import VisualGLMModel
5 | from sat.model.mixins import CachedAutoregressiveMixin
6 | 
7 | tokenizer = AutoTokenizer.from_pretrained("../visualglm-6b", trust_remote_code=True)
8 | model, model_args = VisualGLMModel.from_pretrained("../visualglm-6b", args=argparse.Namespace(fp16=True, skip_init=True))
9 | 
10 | model.add_mixin('auto_regressive', CachedAutoregressiveMixin())
11 | image_path = './fewshot-data/龙眼.jpeg'
12 | prompt = '描述这张图片'
13 | responses, history, cache_image = chat(image_path, model, tokenizer, prompt, history=[])
14 | print(responses)
--------------------------------------------------------------------------------
/finetune_visualglm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | 
5 | from sat import mpu, get_args, get_tokenizer
6 | from sat.training.deepspeed_training import training_main
7 | from model import VisualGLMModel
8 | from sat.model.finetune import PTuningV2Mixin
9 | from sat.model.finetune.lora_mixin import LoraMixin
10 | from transformers import AutoTokenizer, AutoModel
11 | from model import VisualGLMModel
12 | 
13 | 
14 | class FineTuneVisualGLMModel(VisualGLMModel):
15 |     def __init__(self, args, transformer=None, parallel_output=True, **kw_args):
16 |         super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kw_args)
17 |         if args.use_ptuning:
18 |             self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads,
19 |                                                      args.num_attention_heads, args.pre_seq_len))
20 |         if args.use_lora:
21 |             # If you use lora on other "normal" Transformer, just use it with head_first=False (by default)
22 |             self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, head_first=True,
23 |                                              num_attention_heads=args.num_attention_heads,
24 |                                              hidden_size_per_attention_head=args.hidden_size // args.num_attention_heads,
25 |                                              layer_range=list(range(0, 28, 14))), reinit=True)
26 |             # self.get_mixin("eva").model.glm_proj = replace_linear_with_lora(self.get_mixin("eva").model.glm_proj, LoraLinear, args.lora_rank)
27 |         self.args = args
28 | 
29 |     @classmethod
30 |     def add_model_specific_args(cls, parser):
31 |         group = parser.add_argument_group('VisualGLM-finetune', 'VisualGLM finetune Configurations')
32 |         group.add_argument('--pre_seq_len', type=int, default=8)
33 |         group.add_argument('--lora_rank', type=int,
default=10) 34 | group.add_argument('--use_ptuning', action="store_true") 35 | group.add_argument('--use_lora', action="store_true") 36 | return super().add_model_specific_args(parser) 37 | 38 | def disable_untrainable_params(self): 39 | enable = [] 40 | if self.args.use_ptuning: 41 | enable.extend(['ptuning']) 42 | if self.args.use_lora: 43 | enable.extend(['matrix_A', 'matrix_B']) 44 | for n, p in self.named_parameters(): 45 | flag = False 46 | for e in enable: 47 | if e.lower() in n.lower(): 48 | flag = True 49 | break 50 | if not flag: 51 | p.requires_grad_(False) 52 | else: 53 | print(n) 54 | 55 | 56 | def get_batch(data_iterator, args, timers): 57 | # Items and their type. 58 | keys = ['input_ids', 'labels'] 59 | datatype = torch.int64 60 | 61 | # Broadcast data. 62 | timers('data loader').start() 63 | if data_iterator is not None: 64 | data = next(data_iterator) 65 | else: 66 | data = None 67 | timers('data loader').stop() 68 | data_b = mpu.broadcast_data(keys, data, datatype) 69 | data_i = mpu.broadcast_data(['image'], data, torch.float32) 70 | # Unpack. 71 | tokens = data_b['input_ids'].long() 72 | labels = data_b['labels'].long() 73 | img = data_i['image'] 74 | if args.fp16: 75 | img = img.half() 76 | 77 | return tokens, labels, img, data['pre_image'] 78 | 79 | 80 | from torch.nn import CrossEntropyLoss 81 | 82 | 83 | def forward_step(data_iterator, model, args, timers): 84 | """Forward step.""" 85 | 86 | # Get the batch. 87 | timers('batch generator').start() 88 | tokens, labels, image, pre_image = get_batch( 89 | data_iterator, args, timers) 90 | timers('batch generator').stop() 91 | 92 | logits = model(input_ids=tokens, image=image, pre_image=pre_image)[0] 93 | dtype = logits.dtype 94 | lm_logits = logits.to(torch.float32) 95 | 96 | # Shift so that tokens < n predict n 97 | shift_logits = lm_logits[..., :-1, :].contiguous() 98 | shift_labels = labels[..., 1:].contiguous() 99 | # Flatten the tokens 100 | loss_fct = CrossEntropyLoss(ignore_index=-100) 101 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 102 | 103 | lm_logits = lm_logits.to(dtype) 104 | loss = loss.to(dtype) 105 | return loss, {'loss': loss} 106 | 107 | 108 | from model.blip2 import BlipImageEvalProcessor 109 | from torch.utils.data import Dataset 110 | import json 111 | from PIL import Image 112 | 113 | 114 | class FewShotDataset(Dataset): 115 | def __init__(self, path, processor, tokenizer, args): 116 | max_seq_length = args.max_source_length + args.max_target_length 117 | with open(path, 'r', encoding='utf-8') as f: 118 | data = json.load(f) 119 | self.images = [] 120 | self.input_ids = [] 121 | self.labels = [] 122 | for item in data: 123 | image = processor(Image.open(item['img']).convert('RGB')) 124 | input0 = tokenizer.encode("", add_special_tokens=False) 125 | input1 = [tokenizer.pad_token_id] * args.image_length 126 | input2 = tokenizer.encode("问:" + item['prompt'] + "\n答:", add_special_tokens=False) 127 | a_ids = sum([input0, input1, input2], []) 128 | b_ids = tokenizer.encode(text=item['label'], add_special_tokens=False) 129 | if len(a_ids) > args.max_source_length - 1: 130 | a_ids = a_ids[: args.max_source_length - 1] 131 | if len(b_ids) > args.max_target_length - 2: 132 | b_ids = b_ids[: args.max_target_length - 2] 133 | pre_image = len(input0) 134 | input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) 135 | 136 | context_length = input_ids.index(tokenizer.bos_token_id) 137 | mask_position = context_length - 1 138 | labels = [-100] * 
context_length + input_ids[mask_position + 1:] 139 | 140 | pad_len = max_seq_length - len(input_ids) 141 | input_ids = input_ids + [tokenizer.pad_token_id] * pad_len 142 | labels = labels + [tokenizer.pad_token_id] * pad_len 143 | if args.ignore_pad_token_for_loss: 144 | labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] 145 | self.images.append(image) 146 | self.input_ids.append(input_ids) 147 | self.labels.append(labels) 148 | self.pre_image = pre_image 149 | 150 | def __len__(self): 151 | return len(self.images) 152 | 153 | def __getitem__(self, idx): 154 | return { 155 | "image": self.images[idx], 156 | "input_ids": self.input_ids[idx], 157 | "labels": self.labels[idx], 158 | "pre_image": self.pre_image 159 | } 160 | 161 | 162 | def create_dataset_function(path, args): 163 | tokenizer = AutoTokenizer.from_pretrained("../visualglm-6b", trust_remote_code=True) 164 | image_processor = BlipImageEvalProcessor(224) 165 | 166 | dataset = FewShotDataset(path, image_processor, tokenizer, args) 167 | return dataset 168 | 169 | 170 | if __name__ == '__main__': 171 | py_parser = argparse.ArgumentParser(add_help=False) 172 | py_parser.add_argument('--max_source_length', type=int) 173 | py_parser.add_argument('--max_target_length', type=int) 174 | py_parser.add_argument('--ignore_pad_token_for_loss', type=bool, default=True) 175 | py_parser.add_argument('--source_prefix', type=str, default="") 176 | py_parser = FineTuneVisualGLMModel.add_model_specific_args(py_parser) 177 | known, args_list = py_parser.parse_known_args() 178 | args = get_args(args_list) 179 | args = argparse.Namespace(**vars(args), **vars(known)) 180 | model_type = './checkpoints' 181 | model, args = FineTuneVisualGLMModel.from_pretrained(model_type, args) 182 | tokenizer = get_tokenizer(args) 183 | label_pad_token_id = -100 184 | 185 | 186 | def data_collator(examples): 187 | for example in examples: 188 | example['input_ids'] = torch.tensor(example['input_ids'], dtype=torch.long) 189 | example['labels'] = torch.tensor(example['labels'], dtype=torch.long) 190 | ret = { 191 | 'input_ids': torch.stack([example['input_ids'] for example in examples]), 192 | 'labels': torch.stack([example['labels'] for example in examples]), 193 | 'image': torch.stack([example['image'] for example in examples]), 194 | 'pre_image': example['pre_image'] 195 | } 196 | return ret 197 | 198 | 199 | training_main(args, model_cls=model, forward_step_function=forward_step, 200 | create_dataset_function=create_dataset_function, collate_fn=data_collator) -------------------------------------------------------------------------------- /finetune_visualglm.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | NUM_WORKERS=1 3 | NUM_GPUS_PER_WORKER=8 4 | MP_SIZE=1 5 | 6 | script_path=$(realpath $0) 7 | script_dir=$(dirname $script_path) 8 | main_dir=$(dirname $script_dir) 9 | MODEL_TYPE="visualglm-6b" 10 | MODEL_ARGS="--max_source_length 64 \ 11 | --max_target_length 256 \ 12 | --lora_rank 10\ 13 | --pre_seq_len 4" 14 | 15 | # OPTIONS_SAT="SAT_HOME=$1" #"SAT_HOME=/raid/dm/sat_models" 16 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2" 17 | HOST_FILE_PATH="hostfile" 18 | HOST_FILE_PATH="hostfile_single" 19 | 20 | train_data="./fewshot-data/dataset.json" 21 | eval_data="./fewshot-data/dataset.json" 22 | 23 | 24 | gpt_options=" \ 25 | --experiment-name finetune-$MODEL_TYPE \ 26 | --model-parallel-size ${MP_SIZE} \ 27 | --mode finetune \ 28 | --train-iters 300 \ 29 | --resume-dataloader \ 30 | $MODEL_ARGS \ 31 | --train-data ${train_data} \ 32 | --valid-data ${eval_data} \ 33 | --distributed-backend nccl \ 34 | --lr-decay-style cosine \ 35 | --warmup .02 \ 36 | --checkpoint-activations \ 37 | --save-interval 300 \ 38 | --eval-interval 10000 \ 39 | --save "./checkpoints" \ 40 | --split 1 \ 41 | --eval-iters 10 \ 42 | --eval-batch-size 8 \ 43 | --zero-stage 1 \ 44 | --lr 0.0001 \ 45 | --batch-size 20 \ 46 | --skip-init \ 47 | --fp16 \ 48 | --use_lora 49 | " 50 | 51 | 52 | 53 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_visualglm.py ${gpt_options}" 54 | echo ${run_cmd} 55 | eval ${run_cmd} 56 | 57 | set +x 58 | -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Pillars-Creation/Visualglm-image-to-text/120b78d155a278c9e98d5a99ba9d3ad572b413bd/img.png -------------------------------------------------------------------------------- /lora_mixin.py: -------------------------------------------------------------------------------- 1 | """ 2 | In this mixin, I use a different implementation than sat/model/finetune/lora.py 3 | I just use a fake linear layer to replace any model with lora mixin. 
4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from sat.model.base_model import BaseMixin 9 | import math 10 | from sat.helpers import print_all 11 | from sat.model.transformer import RowParallelLinear, ColumnParallelLinear 12 | 13 | 14 | class HackLinear(nn.Linear): 15 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 16 | error_msgs): 17 | if prefix + 'weight' in state_dict: 18 | self.weight.data.copy_(state_dict[prefix + 'weight']) 19 | if prefix + 'bias' in state_dict: 20 | self.bias.data.copy_(state_dict[prefix + 'bias']) 21 | 22 | 23 | class HackRowParallelLinear(RowParallelLinear): 24 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 25 | error_msgs): 26 | if prefix + 'weight' in state_dict: 27 | self.weight.data.copy_(state_dict[prefix + 'weight']) 28 | if prefix + 'bias' in state_dict: 29 | self.bias.data.copy_(state_dict[prefix + 'bias']) 30 | 31 | 32 | class HackColumnParallelLinear(ColumnParallelLinear): 33 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 34 | error_msgs): 35 | if prefix + 'weight' in state_dict: 36 | self.weight.data.copy_(state_dict[prefix + 'weight']) 37 | if prefix + 'bias' in state_dict: 38 | self.bias.data.copy_(state_dict[prefix + 'bias']) 39 | 40 | 41 | try: 42 | from bitsandbytes.nn import LinearNF4 43 | 44 | 45 | def copy_nested_list(src, dst): 46 | for i in range(len(dst)): 47 | if type(dst[i]) is torch.Tensor: 48 | dst[i].copy_(src[i]) 49 | elif type(dst[i]) is list: 50 | copy_nested_list(src[i], dst[i]) 51 | else: 52 | dst[i] = src[i] 53 | 54 | 55 | class HackLinearNF4(LinearNF4): 56 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 57 | error_msgs): 58 | if prefix + 'weight' in state_dict: 59 | self.weight.data.copy_(state_dict[prefix + 'weight']) 60 | if self.weight.data.dtype == torch.uint8: 61 | copy_nested_list(state_dict[prefix + 'quant_state'], self.weight.quant_state) 62 | if prefix + 'bias' in state_dict: 63 | self.bias.data.copy_(state_dict[prefix + 'bias']) 64 | 65 | def _save_to_state_dict(self, destination, prefix, keep_vars): 66 | super()._save_to_state_dict(destination, prefix, keep_vars) 67 | destination[prefix + 'quant_state'] = self.weight.quant_state 68 | except Exception as exception: 69 | print_all("Failed to load bitsandbytes:" + str(exception), level='WARNING') 70 | 71 | 72 | class HackParameterList(nn.ParameterList): 73 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 74 | error_msgs): 75 | for i in range(len(self)): 76 | if prefix + str(i) in state_dict: 77 | self[i].data.copy_(state_dict[prefix + str(i)]) 78 | 79 | 80 | map_cls = { 81 | nn.Linear: (HackLinear, {}), 82 | ColumnParallelLinear: (HackColumnParallelLinear, {'gather_output': False}), 83 | RowParallelLinear: (HackRowParallelLinear, {'input_is_parallel': True}) 84 | } 85 | 86 | 87 | class LoraLinear(nn.Module): 88 | def __init__(self, original_cls, partition, in_dim, out_dim, r, lora_alpha=1., lora_dropout=0., head_first=False, 89 | num_attention_heads=None, hidden_size_per_attention_head=None, qlora=False): 90 | """ 91 | You can use safely with this layer, ONLY WHEN query_key_value output is query_key_value order. 
92 |         If you use a different order (e.g. ChatGLM, where query/key/value are interleaved per attention head), set head_first=True and pass num_attention_heads and hidden_size_per_attention_head.
93 |         """
94 |         super().__init__()
95 |         if lora_dropout and lora_dropout > 0:
96 |             self.lora_dropout = nn.Dropout(p=lora_dropout)
97 |         else:
98 |             self.lora_dropout = lambda x: x
99 |         self.r = r
100 |         self.lora_alpha = lora_alpha
101 |         self.scaling = self.lora_alpha / self.r
102 |         if qlora:
103 |             try:
104 |                 self.original = HackLinearNF4(in_dim, out_dim)
105 |             except:
106 |                 raise Exception(
107 |                     'Build 4bit layer failed. You need to install the latest bitsandbytes. Try `pip install bitsandbytes`. If you still meet error after installation, try running `from bitsandbytes.nn import LinearNF4` with python and fix the error.')
108 |         else:
109 |             base_cls, kwargs = map_cls[original_cls]
110 |             self.original = base_cls(in_dim, out_dim, **kwargs)
111 |         self.matrix_A = HackParameterList([nn.Parameter(torch.empty((r, in_dim))) for _ in range(partition)])
112 |         self.matrix_B = HackParameterList(
113 |             [nn.Parameter(torch.empty((out_dim // partition, r))) for _ in range(partition)])
114 |         for i in range(partition):
115 |             nn.init.kaiming_uniform_(self.matrix_A[i], a=math.sqrt(5))
116 |             nn.init.zeros_(self.matrix_B[i])
117 |         self.head_first = head_first
118 |         self.partition = partition
119 |         if head_first:
120 |             assert num_attention_heads is not None and hidden_size_per_attention_head is not None, "You should set num_attention_heads and hidden_size_per_attention_head if you use head_first=True!"
121 |             self.num_attention_heads = num_attention_heads
122 |             self.hidden_size_per_attention_head = hidden_size_per_attention_head
123 | 
124 |     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
125 |                               error_msgs):
126 |         # This is not a perfect version, because it doesn't handle errors and unexpected keys.
127 | if prefix + 'weight' in state_dict: 128 | # load from normal Linear 129 | self.original._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, 130 | unexpected_keys, error_msgs) 131 | else: 132 | # load from LoraLinear 133 | super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, 134 | error_msgs) 135 | 136 | def forward(self, x): 137 | mixed_raw_layer = self.original(x) 138 | lora_outputs = [] 139 | for i in range(self.partition): 140 | lora_outputs.append((self.lora_dropout(x) @ self.matrix_A[i].T @ self.matrix_B[i].T) * self.scaling) 141 | if self.head_first: 142 | new_tensor_shape = lora_outputs[0].size()[:-1] + ( 143 | self.num_attention_heads, 144 | self.hidden_size_per_attention_head, 145 | ) 146 | for i in range(self.partition): 147 | lora_outputs[i] = lora_outputs[i].view(*new_tensor_shape) 148 | mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1).view(*mixed_raw_layer.size()) 149 | else: 150 | mixed_raw_layer = mixed_raw_layer + torch.cat(lora_outputs, -1) 151 | 152 | return mixed_raw_layer 153 | 154 | 155 | def replace_linear_with_lora(lin, partition, r, *args, **kw_args): 156 | # not supported for linear without bias for now 157 | out_dim, in_dim = lin.weight.shape 158 | original_cls = type(lin) 159 | del lin 160 | return LoraLinear(original_cls, partition, in_dim, out_dim, r, *args, **kw_args) 161 | 162 | 163 | def merge_linear_lora(lin): 164 | if lin.original.weight.data.dtype is not torch.uint8: 165 | weight = lin.original.weight 166 | out_dim, in_dim = weight.shape 167 | new_lin = nn.Linear(in_dim, out_dim) 168 | else: 169 | import bitsandbytes.functional as F 170 | weight = F.dequantize_fp4(lin.original.weight.data, lin.original.weight.quant_state).to( 171 | lin.original.bias.data.dtype) 172 | out_dim, in_dim = weight.shape 173 | new_lin = HackLinearNF4(in_dim, out_dim) 174 | new_lin.bias.data = lin.original.bias.data 175 | new_qkv = [] 176 | for i in range(lin.partition): 177 | new_qkv.append(lin.matrix_A[i].data.T.float() @ lin.matrix_B[i].data.T.float() * lin.scaling) 178 | if lin.head_first: 179 | ini_shape = new_qkv[0].shape 180 | new_qkv = [x.view(ini_shape[0], lin.num_attention_heads, -1) for x in new_qkv] 181 | new_qkv = torch.cat(new_qkv, -1).view(ini_shape[0], lin.partition * ini_shape[1]) 182 | else: 183 | new_qkv = torch.cat(new_qkv, -1) 184 | new_lin.weight.data = weight + new_qkv.T.to(lin.original.bias.data.dtype) 185 | return new_lin.cuda() if torch.cuda.is_available() else new_lin 186 | 187 | 188 | class LoraMixin(BaseMixin): 189 | def __init__(self, 190 | layer_num, 191 | r: int = 0, 192 | lora_alpha: int = 1, 193 | lora_dropout: float = 0., 194 | layer_range=None, 195 | head_first=False, 196 | num_attention_heads=None, 197 | hidden_size_per_attention_head=None, 198 | qlora=False, 199 | cross_attention=True): 200 | super().__init__() 201 | self.r = r 202 | self.lora_alpha = lora_alpha 203 | self.lora_dropout = lora_dropout 204 | 205 | if layer_range is None: 206 | layer_range = [i for i in range(layer_num)] 207 | self.layer_range = layer_range 208 | 209 | self.scaling = self.lora_alpha / self.r 210 | self.head_first = head_first 211 | self.num_attention_heads = num_attention_heads 212 | self.hidden_size_per_attention_head = hidden_size_per_attention_head 213 | self.qlora = qlora 214 | self.cross_attention = cross_attention 215 | 216 | def reinit(self, parent_model): 217 | for i in self.layer_range: 218 | print(f'replacing layer {i} attention with lora') 219 | 
parent_model.transformer.layers[i].attention.dense = replace_linear_with_lora( 220 | parent_model.transformer.layers[i].attention.dense, 1, self.r, self.lora_alpha, self.lora_dropout, 221 | qlora=self.qlora) 222 | parent_model.transformer.layers[i].attention.query_key_value = replace_linear_with_lora( 223 | parent_model.transformer.layers[i].attention.query_key_value, 3, self.r, self.lora_alpha, 224 | self.lora_dropout, head_first=self.head_first, num_attention_heads=self.num_attention_heads, 225 | hidden_size_per_attention_head=self.hidden_size_per_attention_head, qlora=self.qlora) 226 | if self.cross_attention and parent_model.transformer.layers[i].is_decoder: 227 | print(f'replacing layer {i} cross attention with lora') 228 | parent_model.transformer.layers[i].cross_attention.dense = replace_linear_with_lora( 229 | parent_model.transformer.layers[i].cross_attention.dense, 1, self.r, self.lora_alpha, 230 | self.lora_dropout, qlora=self.qlora) 231 | parent_model.transformer.layers[i].cross_attention.query = replace_linear_with_lora( 232 | parent_model.transformer.layers[i].cross_attention.query, 1, self.r, self.lora_alpha, 233 | self.lora_dropout, qlora=self.qlora) 234 | parent_model.transformer.layers[i].cross_attention.key_value = replace_linear_with_lora( 235 | parent_model.transformer.layers[i].cross_attention.key_value, 2, self.r, self.lora_alpha, 236 | self.lora_dropout, qlora=self.qlora) 237 | if self.qlora: 238 | print('replacing chatglm linear layer with 4bit') 239 | 240 | def replace_linear_with_nf4(model, name=None, cache={}): 241 | if type(model) in (nn.Linear, RowParallelLinear, ColumnParallelLinear): 242 | out_dim, in_dim = model.weight.shape 243 | return HackLinearNF4(in_dim, out_dim) 244 | names = set() 245 | for name, child in model.named_children(): 246 | if name not in names: 247 | if child in cache: 248 | new_child = cache[child] 249 | else: 250 | new_child = replace_linear_with_nf4(child, name=name, cache=cache) 251 | cache[child] = new_child 252 | setattr(model, name, new_child) 253 | names.add(name) 254 | flag = True 255 | while flag: 256 | flag = False 257 | for name, child in model.named_children(): 258 | if name not in names: 259 | setattr(model, name, cache[child]) 260 | names.add(name) 261 | flag = True 262 | return model 263 | 264 | replace_linear_with_nf4(parent_model.transformer, None, {}) 265 | 266 | def merge_lora(self): 267 | for i in self.layer_range: 268 | print(f'merge layer {i} lora attention back to linear') 269 | self.transformer.layers[i].attention.dense = merge_linear_lora(self.transformer.layers[i].attention.dense) 270 | self.transformer.layers[i].attention.query_key_value = merge_linear_lora( 271 | self.transformer.layers[i].attention.query_key_value) 272 | if self.transformer.layers[i].is_decoder: 273 | print(f'merge layer {i} lora cross attention back to linear') 274 | self.transformer.layers[i].cross_attention.dense = merge_linear_lora( 275 | self.transformer.layers[i].cross_attention.dense) 276 | self.transformer.layers[i].cross_attention.query = merge_linear_lora( 277 | self.transformer.layers[i].cross_attention.query) 278 | self.transformer.layers[i].cross_attention.key_value = merge_linear_lora( 279 | self.transformer.layers[i].cross_attention.key_value) 280 | 281 | 282 | if __name__ == '__main__': 283 | class Model(nn.Module): 284 | def __init__(self): 285 | super().__init__() 286 | self.child = nn.Linear(100, 200) 287 | 288 | def forward(self, x): 289 | return self.child(x) 290 | 291 | 292 | model = Model() 293 | 
torch.save(model.state_dict(), "linear.pt") 294 | x = torch.randn(2, 100) 295 | out1 = model(x) 296 | model.child = LoraLinear(100, 200, 10) 297 | model.load_state_dict(torch.load("linear.pt"), strict=False) 298 | out2 = model(x) 299 | torch.save(model.state_dict(), "lora.pt") 300 | ckpt = torch.load("lora.pt") 301 | breakpoint() 302 | model.load_state_dict(ckpt, strict=False) 303 | out3 = model(x) 304 | breakpoint() -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .chat import chat 2 | from .infer_util import * 3 | from .blip2 import BlipImageEvalProcessor 4 | -------------------------------------------------------------------------------- /model/blip2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from sat.model import ViTModel, BaseModel 5 | from sat.model import BaseMixin 6 | from sat import AutoModel 7 | from copy import deepcopy 8 | from torchvision import transforms 9 | from torchvision.transforms.functional import InterpolationMode 10 | 11 | class LNFinalyMixin(BaseMixin): 12 | def __init__(self, hidden_size): 13 | super().__init__() 14 | self.ln_vision = nn.LayerNorm(hidden_size) 15 | 16 | def final_forward(self, logits, **kw_args): 17 | return self.ln_vision(logits) 18 | 19 | class EVAViT(ViTModel): 20 | def __init__(self, args, transformer=None, parallel_output=True, **kwargs): 21 | super().__init__(args, transformer=transformer, parallel_output=parallel_output, **kwargs) 22 | self.del_mixin("cls") 23 | self.add_mixin("cls", LNFinalyMixin(args.hidden_size)) 24 | 25 | def forward(self, image): 26 | batch_size = image.size(0) 27 | input_ids = torch.zeros(batch_size, 1, dtype=torch.long, device=image.device) 28 | attention_mask = torch.tensor([[1.]], dtype=image.dtype, device=image.device) 29 | return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, image=image) 30 | 31 | class QFormer(BaseModel): 32 | def __init__(self, args, transformer=None, parallel_output=True, **kwargs): 33 | super().__init__(args, transformer=transformer, parallel_output=parallel_output, activation_func=nn.functional.gelu, **kwargs) 34 | self.transformer.position_embeddings = None 35 | 36 | def final_forward(self, logits, **kw_args): 37 | return logits 38 | 39 | def position_embedding_forward(self, position_ids, **kw_args): 40 | return None 41 | 42 | def forward(self, encoder_outputs): 43 | batch_size = encoder_outputs.size(0) 44 | input_ids = torch.arange(32, dtype=torch.long, device=encoder_outputs.device).unsqueeze(0).expand(batch_size, -1) 45 | attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device) 46 | cross_attention_mask = torch.tensor([[1.]], dtype=encoder_outputs.dtype, device=encoder_outputs.device) 47 | return super().forward(input_ids=input_ids, position_ids=None, attention_mask=attention_mask, encoder_outputs=encoder_outputs, cross_attention_mask=cross_attention_mask) 48 | 49 | 50 | class BLIP2(torch.nn.Module): 51 | def __init__(self, eva_args, qformer_args, vit=None, qformer=None, **kwargs): 52 | super().__init__() 53 | if vit is not None: 54 | self.vit = vit 55 | else: 56 | self.vit = EVAViT(EVAViT.get_args(**eva_args)) 57 | if qformer is not None: 58 | self.qformer = qformer 59 | else: 60 | self.qformer = QFormer(QFormer.get_args(**qformer_args)) 61 | 62 | self.glm_proj = 
nn.Linear(768, 4096).to(self.qformer.parameters().__next__().device).to(self.qformer.parameters().__next__().dtype) 63 | 64 | def forward(self, image, **kwargs): 65 | enc = self.vit(image)[0] 66 | out = self.qformer(enc)[0] 67 | return self.glm_proj(out) 68 | 69 | class BlipImageBaseProcessor(): 70 | def __init__(self, mean=None, std=None): 71 | if mean is None: 72 | mean = (0.48145466, 0.4578275, 0.40821073) 73 | if std is None: 74 | std = (0.26862954, 0.26130258, 0.27577711) 75 | 76 | self.normalize = transforms.Normalize(mean, std) 77 | 78 | class BlipImageEvalProcessor(BlipImageBaseProcessor): 79 | def __init__(self, image_size=384, mean=None, std=None): 80 | super().__init__(mean=mean, std=std) 81 | 82 | self.transform = transforms.Compose( 83 | [ 84 | transforms.Resize( 85 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 86 | ), 87 | transforms.ToTensor(), 88 | self.normalize, 89 | ] 90 | ) 91 | 92 | def __call__(self, item): 93 | return self.transform(item) 94 | -------------------------------------------------------------------------------- /model/chat.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : chat.py 4 | @Time : 2023/05/08 19:10:08 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | import os 10 | import sys 11 | import re 12 | from functools import partial 13 | from typing import Optional, Tuple, Union, List, Callable, Dict, Any 14 | import requests 15 | from PIL import Image 16 | from io import BytesIO 17 | 18 | import torch 19 | from sat.generation.autoregressive_sampling import filling_sequence, BaseStrategy 20 | 21 | from .blip2 import BlipImageEvalProcessor 22 | 23 | def get_masks_and_position_ids_glm(seq, mask_position, context_length): 24 | '''GLM model, different from GPT. 25 | Args: 26 | seq: torch.IntTensor, [seq_len] 27 | mask_position: int, the position of the masked place. 28 | context_length: int, the length of context. 29 | Returns: 30 | tokens: torch.IntTensor, [1, seq_len] 31 | attention_mask: torch.FloatTensor, [1, seq_len, seq_len] 32 | position_ids: torch.IntTensor, [2, seq_len] 33 | ''' 34 | tokens = seq.unsqueeze(0) 35 | 36 | attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device) 37 | attention_mask.tril_() 38 | attention_mask[..., :context_length] = 1 39 | attention_mask.unsqueeze_(1) 40 | 41 | # 2D position ids 42 | position_ids = torch.zeros(2, len(seq), device=tokens.device, dtype=torch.long) 43 | torch.arange(0, context_length, out=position_ids[0, :context_length]) 44 | position_ids[0, context_length:] = mask_position 45 | torch.arange(1, len(seq) - context_length + 1, out=position_ids[1, context_length:]) 46 | 47 | position_ids = position_ids.unsqueeze(0) 48 | return tokens, attention_mask, position_ids 49 | 50 | def process_response(response): 51 | response = response.strip() 52 | response = response.replace("[[训练时间]]", "2023年") 53 | punkts = [ 54 | [",", ","], 55 | ["!", "!"], 56 | [":", ":"], 57 | [";", ";"], 58 | ["\?", "?"], 59 | ] 60 | for item in punkts: 61 | response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response) 62 | response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response) 63 | return response 64 | 65 | def process_image(text, image=None): 66 | '''Process image in text. 67 | Args: 68 | text: str, text. 69 | image: Optional, image path / url / PIL image. 
 70 |     '''
 71 |     image_position = text.rfind("<img>") + 5
 72 |     # extract path from <img></img> using re
 73 |     image_path = re.findall(r"<img>(.*?)</img>", text)
 74 |     image_path = image_path[-1] if image_path[-1] else None
 75 |     if image_path is not None:
 76 |         assert image is None, "image and image_path cannot both be non-None."
 77 |         text = text.replace(image_path, "")
 78 |         image_path = image_path.strip()
 79 |         # url
 80 |         if image_path.startswith("http"):
 81 |             response = requests.get(image_path, timeout=10)
 82 |             image = Image.open(BytesIO(response.content))
 83 |         # local path
 84 |         else:
 85 |             image = Image.open(image_path)
 86 |     if image is not None and isinstance(image, Image.Image):
 87 |         processor = BlipImageEvalProcessor(224)
 88 |         image = processor(image.convert('RGB'))
 89 |         image = image.unsqueeze(0)
 90 |     return text, image_position, image
 91 | 
 92 | 
 93 | def chat(image_path, model, tokenizer, 
 94 |          query: str, history: List[Tuple[str, str]] = None, image: Image = None, 
 95 |          max_length: int = 1024, top_p=0.7, top_k=30, temperature=0.95, repetition_penalty=1.2, 
 96 |          invalid_slices=[], english=False 
 97 |          ):
 98 |     if not history:
 99 |         history = []
100 |     if image_path:
101 |         prompt = "<img>{}</img>".format(image_path if image_path else "")
102 |     else:
103 |         prompt = "<img></img>"
104 |     if english:
105 |         for i, (old_query, response) in enumerate(history):
106 |             prompt += "Q:{}\nA:{}\n".format(old_query, response)
107 |         prompt += "Q:{}\nA:".format(query)
108 |     else:
109 |         for i, (old_query, response) in enumerate(history):
110 |             prompt += "问:{}\n答:{}\n".format(old_query, response)
111 |         prompt += "问:{}\n答:".format(query)
112 |     # ---------------
113 |     # tokenizer, this is an example of huggingface tokenizer.
114 |     # input str, output['input_ids'] = tensor([[tokenized str, gmask, sop]])
115 |     prompt, image_position, torch_image = process_image(prompt, image=image)
116 |     if torch_image is not None:
117 |         torch_image = torch_image.to(next(model.parameters()).dtype).to(next(model.parameters()).device)
118 |     if image_position < 5: # no image
119 |         inputs = tokenizer([prompt], return_tensors="pt").to(model.parameters().__next__().device)['input_ids'][0]
120 |         pre_image = 0
121 |     else:
122 |         input0 = tokenizer.encode(prompt[:image_position], add_special_tokens=False)
123 |         input1 = [tokenizer.pad_token_id] * model.image_length
124 |         input2 = tokenizer.encode(prompt[image_position:], add_special_tokens=False)
125 |         inputs = sum([input0, input1, input2], [])
126 |         inputs = torch.tensor(tokenizer.build_inputs_with_special_tokens(inputs)).to(model.parameters().__next__().device)
127 |         pre_image = len(input0)
128 |     # ---------------
129 |     # Next, we manually set the format to keep flexibility.
130 | mask_position = len(inputs) - 2 131 | context_length = len(inputs) - 1 # all before sop 132 | get_func = partial(get_masks_and_position_ids_glm, mask_position=mask_position, context_length=context_length) 133 | seq = torch.cat( 134 | [inputs, torch.tensor([-1]*(max_length-len(inputs)), device=inputs.device)], dim=0 135 | ) 136 | # --------------- 137 | # from sat.generation.sampling_strategies import BeamSearchStrategy 138 | # strategy = BeamSearchStrategy(num_beams, length_penalty=1., prefer_min_length=5, end_tokens=[tokenizer.eos_token_id], consider_end=True, no_repeat_ngram_size=5, stop_n_iter_unchanged=30, temperature=temperature, top_p=top_p, top_k=60, repetition_penalty=1.1) 139 | strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id], 140 | invalid_slices=invalid_slices, repetition_penalty=repetition_penalty) 141 | output = filling_sequence( 142 | model, seq, 143 | batch_size=1, 144 | get_masks_and_position_ids=get_func, 145 | strategy=strategy, 146 | pre_image=pre_image, 147 | image=torch_image, 148 | )[0] # drop memory 149 | 150 | # --------------- 151 | # port from inference_glm.py, more general than chat mode 152 | # clip -1s and fill back generated things into seq 153 | if type(output) is not list: 154 | output_list = output.tolist() 155 | else: 156 | output_list = output 157 | for i in range(len(output_list)): 158 | output = output_list[i] 159 | if type(output) is not list: 160 | output = output.tolist() 161 | try: 162 | unfinished = output.index(-1) 163 | except ValueError: 164 | unfinished = len(output) 165 | if output[unfinished - 1] == tokenizer.eos_token_id: 166 | unfinished -= 1 167 | bog = output.index(tokenizer.bos_token_id) 168 | output_list[i] = output[:mask_position] + output[bog + 1:unfinished] + output[mask_position + 1:bog] 169 | # --------------- 170 | 171 | response = tokenizer.decode(output_list[0]) 172 | sep = 'A:' if english else '答:' 173 | response = process_response(response).split(sep)[-1].strip() 174 | history = history + [(query, response)] 175 | return response, history, torch_image 176 | -------------------------------------------------------------------------------- /model/infer_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from io import BytesIO 4 | import base64 5 | import re 6 | import argparse 7 | import torch 8 | from transformers import AutoTokenizer 9 | from sat.model.mixins import CachedAutoregressiveMixin 10 | from sat.quantization.kernels import quantize 11 | import hashlib 12 | from .visualglm import VisualGLMModel 13 | 14 | def get_infer_setting(gpu_device=0, quant=None): 15 | os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_device) 16 | args = argparse.Namespace( 17 | fp16=True, 18 | skip_init=True, 19 | device='cuda' if quant is None else 'cpu', 20 | ) 21 | model, args = VisualGLMModel.from_pretrained('visualglm-6b', args) 22 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 23 | assert quant in [None, 4, 8] 24 | if quant is not None: 25 | quantize(model.transformer, quant) 26 | model.eval() 27 | model = model.cuda() 28 | tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True) 29 | return model, tokenizer 30 | 31 | def is_chinese(text): 32 | zh_pattern = re.compile(u'[\u4e00-\u9fa5]+') 33 | return zh_pattern.search(text) 34 | 35 | def generate_input(input_text, input_image_prompt, history=[], input_para=None, image_is_encoded=True): 36 | if not 
image_is_encoded: 37 | image = input_image_prompt 38 | else: 39 | decoded_image = base64.b64decode(input_image_prompt) 40 | image = Image.open(BytesIO(decoded_image)) 41 | 42 | input_data = {'input_query': input_text, 'input_image': image, 'history': history, 'gen_kwargs': input_para} 43 | return input_data 44 | 45 | 46 | def process_image(image_encoded): 47 | decoded_image = base64.b64decode(image_encoded) 48 | image = Image.open(BytesIO(decoded_image)) 49 | image_hash = hashlib.sha256(image.tobytes()).hexdigest() 50 | image_path = f'./examples/{image_hash}.png' 51 | if not os.path.isfile(image_path): 52 | image.save(image_path) 53 | return os.path.abspath(image_path) -------------------------------------------------------------------------------- /model/visualglm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sat.model.official import ChatGLMModel 3 | from sat.model.base_model import BaseMixin 4 | from copy import deepcopy 5 | import json 6 | from .blip2 import BLIP2 7 | 8 | from sat.resources.urls import MODEL_URLS 9 | MODEL_URLS['visualglm-6b'] = 'https://cloud.tsinghua.edu.cn/f/348b98dffcc940b6a09d/?dl=1' 10 | 11 | class ImageMixin(BaseMixin): 12 | def __init__(self, args): 13 | super().__init__() 14 | self.args = deepcopy(args) 15 | self.model = BLIP2(args.eva_args, args.qformer_args) 16 | 17 | def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args): 18 | if kw_args["pre_image"] > input_ids.shape[1] or kw_args.get("image", None) is None: 19 | return self.transformer.word_embeddings(input_ids) 20 | image_emb = self.model(**kw_args) 21 | # the image is inserted after 问:, override 32 pads 22 | pre_id, pads, post_id = torch.tensor_split(input_ids, [kw_args["pre_image"], kw_args["pre_image"]+self.args.image_length], dim=1) 23 | pre_txt_emb = self.transformer.word_embeddings(pre_id) 24 | post_txt_emb = self.transformer.word_embeddings(post_id) 25 | return torch.cat([pre_txt_emb, image_emb, post_txt_emb], dim=1) 26 | 27 | class VisualGLMModel(ChatGLMModel): 28 | def __init__(self, args, transformer=None, **kwargs): 29 | super().__init__(args, transformer=transformer, **kwargs) 30 | self.image_length = args.image_length 31 | self.add_mixin("eva", ImageMixin(args)) 32 | 33 | @classmethod 34 | def add_model_specific_args(cls, parser): 35 | group = parser.add_argument_group('VisualGLM', 'VisualGLM Configurations') 36 | group.add_argument('--image_length', type=int, default=32) 37 | group.add_argument('--eva_args', type=json.loads, default={}) 38 | group.add_argument('--qformer_args', type=json.loads, default={}) 39 | return super().add_model_specific_args(parser) 40 | 41 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModel 2 | tokenizer = AutoTokenizer.from_pretrained("../visualglm-6b", trust_remote_code=True) 3 | model = AutoModel.from_pretrained("../visualglm-6b", trust_remote_code=True).half().cuda() 4 | image_path = "./fewshot-data/男性面相2.jpeg" 5 | print('输入:',image_path) 6 | response, history = model.chat(tokenizer, image_path, "从眼睛看这张照片可能的面相?", history=[]) 7 | print("从眼睛看这张照片可能的面相:"+response) 8 | response, history = model.chat(tokenizer, image_path, "从鼻子描述这张图片人物可能的面相?", history=[]) 9 | print("从鼻子描述这张图片人物可能的面相:"+response) 10 | response, history = model.chat(tokenizer, image_path, "从嘴巴描述这张图片人物可能的面相?", history=[]) 11 | 
print("从嘴巴描述这张图片人物可能的面相:"+response) 12 | response, history = model.chat(tokenizer, image_path, "从天庭描述这张图片人物可能的面相?", history=[]) 13 | print("从天庭描述这张图片人物可能的面相:"+response) 14 | response, history = model.chat(tokenizer, image_path, "这个人的精神状态怎么样", history=history) 15 | print('这个人的精神状态怎么样',response) 16 | # response, history = model.chat(tokenizer, image_path, "喜欢这张图片可能是什么样的年龄性别职业", history=history) 17 | # print('什么样的人会喜欢:',response) 18 | -------------------------------------------------------------------------------- /predict_lora.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | 3 | import os 4 | import sys 5 | import torch 6 | import argparse 7 | from transformers import AutoTokenizer 8 | from sat.model.mixins import CachedAutoregressiveMixin 9 | from sat.quantization.kernels import quantize 10 | 11 | from model import VisualGLMModel, chat 12 | from finetune_visualglm import FineTuneVisualGLMModel 13 | from sat.model import AutoModel 14 | 15 | 16 | def main(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence') 19 | parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling') 20 | parser.add_argument("--top_k", type=int, default=100, help='top k for top k sampling') 21 | parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling') 22 | parser.add_argument("--english", action='store_true', help='only output English') 23 | parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits') 24 | parser.add_argument("--from_pretrained", type=str, default="checkpoints/finetune-visualglm-6b-06-06-00-05/", 25 | help='pretrained ckpt') 26 | args = parser.parse_args() 27 | 28 | # load model 29 | model, model_args = AutoModel.from_pretrained( 30 | args.from_pretrained, 31 | args=argparse.Namespace( 32 | fp16=True, 33 | skip_init=True, 34 | use_gpu_initialization=True, 35 | device='cuda' 36 | )) 37 | model = model.to(torch.float32) 38 | model = model.eval() 39 | 40 | if args.quant: 41 | quantize(model.transformer, args.quant) 42 | 43 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 44 | tokenizer = AutoTokenizer.from_pretrained("../chatglm-6b", trust_remote_code=True) 45 | with torch.no_grad(): 46 | history = None 47 | cache_image = None 48 | image_path = './fewshot-data/虎眼.jpeg' 49 | query = '从鼻子描述这张图片人物可能的面相?' 50 | response, history, cache_image = chat( 51 | image_path, 52 | model, 53 | tokenizer, 54 | query, 55 | history=history, 56 | image=cache_image, 57 | max_length=args.max_length, 58 | top_p=args.top_p, 59 | temperature=args.temperature, 60 | top_k=args.top_k, 61 | english=args.english, 62 | invalid_slices=[slice(63823, 130000)] if args.english else [] 63 | ) 64 | sep = 'A:' if args.english else '答:' 65 | print(query + ': ' + response.split(sep)[-1].strip()) 66 | image_path = None 67 | 68 | query = '从眼睛看这张照片可能的面相?' 69 | response, history, cache_image = chat( 70 | image_path, 71 | model, 72 | tokenizer, 73 | query, 74 | history=history, 75 | image=cache_image, 76 | max_length=args.max_length, 77 | top_p=args.top_p, 78 | temperature=args.temperature, 79 | top_k=args.top_k, 80 | english=args.english, 81 | invalid_slices=[slice(63823, 130000)] if args.english else [] 82 | ) 83 | sep = 'A:' if args.english else '答:' 84 | print(query + ': ' + response.split(sep)[-1].strip()) 85 | 86 | query = '从嘴巴描述这张图片人物可能的面相?' 
87 | response, history, cache_image = chat( 88 | image_path, 89 | model, 90 | tokenizer, 91 | query, 92 | history=history, 93 | image=cache_image, 94 | max_length=args.max_length, 95 | top_p=args.top_p, 96 | temperature=args.temperature, 97 | top_k=args.top_k, 98 | english=args.english, 99 | invalid_slices=[slice(63823, 130000)] if args.english else [] 100 | ) 101 | sep = 'A:' if args.english else '答:' 102 | print(query + ': ' + response.split(sep)[-1].strip()) 103 | 104 | query = '从天庭描述这张图片人物可能的面相?' 105 | response, history, cache_image = chat( 106 | image_path, 107 | model, 108 | tokenizer, 109 | query, 110 | history=history, 111 | image=cache_image, 112 | max_length=args.max_length, 113 | top_p=args.top_p, 114 | temperature=args.temperature, 115 | top_k=args.top_k, 116 | english=args.english, 117 | invalid_slices=[slice(63823, 130000)] if args.english else [] 118 | ) 119 | sep = 'A:' if args.english else '答:' 120 | print(query + ': ' + response.split(sep)[-1].strip()) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | 126 | # python predict_lora.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | SwissArmyTransformer>=0.3.6 2 | torch>1.10.0 3 | torchvision 4 | transformers>=4.27.1 5 | mdtex2html 6 | gradio -------------------------------------------------------------------------------- /requirements_wo_ds.txt: -------------------------------------------------------------------------------- 1 | torch>1.10.0 2 | torchvision 3 | transformers>=4.27.1 4 | mdtex2html 5 | gradio 6 | sentencepiece 7 | tensorboardX 8 | datasets 9 | cpm_kernels 10 | einops -------------------------------------------------------------------------------- /web_demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gradio as gr 4 | from PIL import Image 5 | import os 6 | import json 7 | from model import is_chinese, get_infer_setting, generate_input, chat 8 | 9 | def generate_text_with_image(input_text, image, history=[], request_data=dict(), is_zh=True): 10 | input_para = { 11 | "max_length": 2048, 12 | "min_length": 50, 13 | "temperature": 0.8, 14 | "top_p": 0.4, 15 | "top_k": 100, 16 | "repetition_penalty": 1.2 17 | } 18 | input_para.update(request_data) 19 | 20 | input_data = generate_input(input_text, image, history, input_para, image_is_encoded=False) 21 | input_image, gen_kwargs = input_data['input_image'], input_data['gen_kwargs'] 22 | answer, history, _ = chat(None, model, tokenizer, input_text, history=history, image=input_image, \ 23 | max_length=gen_kwargs['max_length'], top_p=gen_kwargs['top_p'], \ 24 | top_k = gen_kwargs['top_k'], temperature=gen_kwargs['temperature'], english=not is_zh) 25 | return answer 26 | 27 | 28 | def request_model(input_text, temperature, top_p, image_prompt, result_previous): 29 | result_text = [(ele[0], ele[1]) for ele in result_previous] 30 | for i in range(len(result_text)-1, -1, -1): 31 | if result_text[i][0] == "" or result_text[i][1] == "": 32 | del result_text[i] 33 | print(f"history {result_text}") 34 | 35 | is_zh = is_chinese(input_text) 36 | if image_prompt is None: 37 | if is_zh: 38 | result_text.append((input_text, '图片为空!请上传图片并重试。')) 39 | else: 40 | result_text.append((input_text, 'Image empty! Please upload a image and retry.')) 41 | return input_text, result_text 42 | elif input_text == "": 43 | result_text.append((input_text, 'Text empty! 
Please enter text and retry.')) 44 | return "", result_text 45 | 46 | request_para = {"temperature": temperature, "top_p": top_p} 47 | image = Image.open(image_prompt) 48 | try: 49 | answer = generate_text_with_image(input_text, image, result_text.copy(), request_para, is_zh) 50 | except Exception as e: 51 | print(f"error: {e}") 52 | if is_zh: 53 | result_text.append((input_text, '超时!请稍等几分钟再重试。')) 54 | else: 55 | result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.')) 56 | return "", result_text 57 | 58 | result_text.append((input_text, answer)) 59 | print(result_text) 60 | return "", result_text 61 | 62 | 63 | DESCRIPTION = '''# VisualGLM''' 64 | 65 | MAINTENANCE_NOTICE1 = 'Hint 1: If the app report "Something went wrong, connection error out", please turn off your proxy and retry.\nHint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.' 66 | MAINTENANCE_NOTICE2 = '提示1: 如果应用报了“Something went wrong, connection error out”的错误,请关闭代理并重试。\n提示2: 如果你上传了很大的图片,比如10MB大小,那将需要一些时间来上传和处理,请耐心等待。' 67 | 68 | NOTES = 'This app is adapted from https://github.com/THUDM/VisualGLM-6B. It would be recommended to check out the repo if you want to see the detail of our model and training process.' 69 | 70 | 71 | def clear_fn(value): 72 | return "", [("", "Hi, What do you want to know about this image?")], None 73 | 74 | def clear_fn2(value): 75 | return [("", "Hi, What do you want to know about this image?")] 76 | 77 | 78 | def main(args): 79 | gr.close_all() 80 | global model, tokenizer 81 | model, tokenizer = get_infer_setting(gpu_device=0, quant=args.quant) 82 | 83 | with gr.Blocks(css='style.css') as demo: 84 | gr.Markdown(DESCRIPTION) 85 | with gr.Row(): 86 | with gr.Column(scale=4.5): 87 | with gr.Group(): 88 | input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.') 89 | with gr.Row(): 90 | run_button = gr.Button('Generate') 91 | clear_button = gr.Button('Clear') 92 | 93 | image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None) 94 | with gr.Row(): 95 | temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature') 96 | top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P') 97 | with gr.Group(): 98 | with gr.Row(): 99 | maintenance_notice = gr.Markdown(MAINTENANCE_NOTICE1) 100 | with gr.Column(scale=5.5): 101 | result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")]).style(height=550) 102 | 103 | gr.Markdown(NOTES) 104 | 105 | print(gr.__version__) 106 | run_button.click(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text], 107 | outputs=[input_text, result_text]) 108 | input_text.submit(fn=request_model,inputs=[input_text, temperature, top_p, image_prompt, result_text], 109 | outputs=[input_text, result_text]) 110 | clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt]) 111 | image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text]) 112 | image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text]) 113 | 114 | print(gr.__version__) 115 | 116 | 117 | demo.queue(concurrency_count=10) 118 | demo.launch(share=args.share) 119 | 120 | 121 | if __name__ == '__main__': 122 | import argparse 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument("--quant", choices=[8, 4], type=int, default=None) 125 | 
parser.add_argument("--share", action="store_true") 126 | args = parser.parse_args() 127 | 128 | main(args) -------------------------------------------------------------------------------- /web_demo_hf.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModel, AutoTokenizer 2 | import gradio as gr 3 | import mdtex2html 4 | 5 | tokenizer = AutoTokenizer.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True) 6 | model = AutoModel.from_pretrained("THUDM/visualglm-6b", trust_remote_code=True).half().cuda() 7 | model = model.eval() 8 | 9 | """Override Chatbot.postprocess""" 10 | 11 | 12 | def postprocess(self, y): 13 | if y is None: 14 | return [] 15 | for i, (message, response) in enumerate(y): 16 | y[i] = ( 17 | None if message is None else mdtex2html.convert((message)), 18 | None if response is None else mdtex2html.convert(response), 19 | ) 20 | return y 21 | 22 | 23 | gr.Chatbot.postprocess = postprocess 24 | 25 | 26 | def parse_text(text): 27 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" 28 | lines = text.split("\n") 29 | lines = [line for line in lines if line != ""] 30 | count = 0 31 | for i, line in enumerate(lines): 32 | if "```" in line: 33 | count += 1 34 | items = line.split('`') 35 | if count % 2 == 1: 36 | lines[i] = f'
<pre><code class="language-{items[-1]}">'
 37 |             else:
 38 |                 lines[i] = f'<br></code></pre>'
 39 |         else:
 40 |             if i > 0:
 41 |                 if count % 2 == 1:
 42 |                     line = line.replace("`", "\`")
 43 |                     line = line.replace("<", "&lt;")
 44 |                     line = line.replace(">", "&gt;")
 45 |                     line = line.replace(" ", "&nbsp;")
 46 |                     line = line.replace("*", "&ast;")
 47 |                     line = line.replace("_", "&lowbar;")
 48 |                     line = line.replace("-", "&#45;")
 49 |                     line = line.replace(".", "&#46;")
 50 |                     line = line.replace("!", "&#33;")
 51 |                     line = line.replace("(", "&#40;")
 52 |                     line = line.replace(")", "&#41;")
 53 |                     line = line.replace("$", "&#36;")
 54 |                 lines[i] = "<br>
"+line 55 | text = "".join(lines) 56 | return text 57 | 58 | 59 | def predict(input, image_path, chatbot, max_length, top_p, temperature, history): 60 | if image_path is None: 61 | return [(input, "图片为空!请重新上传图片并重试。")] 62 | chatbot.append((parse_text(input), "")) 63 | for response, history in model.stream_chat(tokenizer, image_path, input, history, max_length=max_length, top_p=top_p, 64 | temperature=temperature): 65 | chatbot[-1] = (parse_text(input), parse_text(response)) 66 | 67 | yield chatbot, history 68 | 69 | 70 | def predict_new_image(image_path, chatbot, max_length, top_p, temperature): 71 | input, history = "描述这张图片。", [] 72 | chatbot.append((parse_text(input), "")) 73 | for response, history in model.stream_chat(tokenizer, image_path, input, history, max_length=max_length, 74 | top_p=top_p, 75 | temperature=temperature): 76 | chatbot[-1] = (parse_text(input), parse_text(response)) 77 | 78 | yield chatbot, history 79 | 80 | 81 | def reset_user_input(): 82 | return gr.update(value='') 83 | 84 | 85 | def reset_state(): 86 | return None, [], [] 87 | 88 | 89 | with gr.Blocks() as demo: 90 | gr.HTML("""

<h1 align="center">VisualGLM</h1>

""") 91 | 92 | image_path = gr.Image(type="filepath", label="Image Prompt", value=None) 93 | chatbot = gr.Chatbot() 94 | with gr.Row(): 95 | with gr.Column(scale=4): 96 | with gr.Column(scale=12): 97 | user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style( 98 | container=False) 99 | with gr.Column(min_width=32, scale=1): 100 | submitBtn = gr.Button("Submit", variant="primary") 101 | with gr.Column(scale=1): 102 | emptyBtn = gr.Button("Clear History") 103 | max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True) 104 | top_p = gr.Slider(0, 1, value=0.4, step=0.01, label="Top P", interactive=True) 105 | temperature = gr.Slider(0, 1, value=0.8, step=0.01, label="Temperature", interactive=True) 106 | 107 | history = gr.State([]) 108 | 109 | submitBtn.click(predict, [user_input, image_path, chatbot, max_length, top_p, temperature, history], [chatbot, history], 110 | show_progress=True) 111 | 112 | image_path.upload(predict_new_image, [image_path, chatbot, max_length, top_p, temperature], [chatbot, history], 113 | show_progress=True) 114 | image_path.clear(reset_state, outputs=[image_path, chatbot, history], show_progress=True) 115 | 116 | submitBtn.click(reset_user_input, [], [user_input]) 117 | 118 | emptyBtn.click(reset_state, outputs=[image_path, chatbot, history], show_progress=True) 119 | 120 | demo.queue().launch(share=False, inbrowser=True) --------------------------------------------------------------------------------