├── assets
│   └── modelscope_logo.png
├── infer_wqx_vl.py
├── README.md
├── web_ui_wqx_vl.py
└── web_ui_wqx.py


/assets/modelscope_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/InternLM/InternLM-WQX/HEAD/assets/modelscope_logo.png


--------------------------------------------------------------------------------
/infer_wqx_vl.py:
--------------------------------------------------------------------------------
from io import BytesIO

import numpy as np
import requests
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoModel, AutoTokenizer


def padding_336(b):
    # Pad the image height up to the next multiple of 560 with white pixels,
    # keeping the content vertically centered.
    width, height = b.size
    tar = int(np.ceil(height / 560) * 560)
    top_padding = int((tar - height) / 2)
    bottom_padding = tar - height - top_padding
    left_padding = 0
    right_padding = 0
    b = transforms.functional.pad(b, [left_padding, top_padding, right_padding, bottom_padding],
                                  fill=[255, 255, 255])

    return b


def HD_transform(img, hd_num=25):
    # Dynamic high-resolution preprocessing: resize the image so its long side
    # spans `scale` tiles of 560 px, where `scale` is the largest value with
    # scale * ceil(scale / aspect_ratio) <= hd_num, then pad the short side up
    # to a multiple of 560. Portrait images are transposed first and
    # transposed back afterwards.
    width, height = img.size
    trans = False
    if width < height:
        img = img.transpose(Image.TRANSPOSE)
        trans = True
        width, height = img.size
    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1
    new_w = int(scale * 560)
    new_h = int(new_w / ratio)

    img = transforms.functional.resize(img, [new_h, new_w])
    img = padding_336(img)
    if trans:
        img = img.transpose(Image.TRANSPOSE)

    return img


def process_query_and_image(query, image, model, HD_transform):
    # Build the multimodal input sequence: the encoded image is inserted at
    # position 0, followed by the embedded query text. `im_mask` marks the
    # positions that hold image tokens.
    def process_image(img):
        img = img.convert("RGB")
        img = HD_transform(img, hd_num=4)
        img = model.vis_processor(img).unsqueeze(0).cuda().half()
        return model.encode_img(img)

    embeds = []
    im_mask = []
    images_loc = [0]

    for i, pts in enumerate(images_loc + [len(query)]):
        subtext = query[0:pts]
        text_embeds = model.encode_text(subtext, add_special_tokens=True)
        embeds.append(text_embeds)
        im_mask.append(torch.zeros(text_embeds.shape[:2]).cuda())

        if i == 0:
            image_embeds = process_image(image)
            embeds.append(image_embeds)
            im_mask.append(torch.ones(image_embeds.shape[:2]).cuda())

    embeds = torch.cat(embeds, dim=1)
    im_mask = torch.cat(im_mask, dim=1).bool()

    return embeds, im_mask


if __name__ == "__main__":
    model_path = "internlm/internlm2-wqx-vl-20b"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda().eval()
    model.cuda().half()  # cast to fp16 for inference
    model.tokenizer = tokenizer

    image_url = "https://ks-1302698447.cos.ap-shanghai.myqcloud.com/img/phymerge.png"
    query = "体育课上两位同学在室内羽毛球场进行羽毛球比赛,羽毛球在空中上升的运动轨迹如图中虚线所示,考虑空气阻力,羽毛球加速度方向示意图可能正确的是(\u3000\u3000) \nA: \nB: \nC: \nD: "

    response = requests.get(image_url)
    image = Image.open(BytesIO(response.content))
    embeds, im_mask = process_query_and_image(query, image, model, HD_transform)

    outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
                             temperature=0.0, max_new_tokens=256, num_beams=1,
                             do_sample=False, repetition_penalty=1.0)
    output_token = outputs[0]
    output_text = model.tokenizer.decode(output_token, skip_special_tokens=True)
    print(output_text)
    # 斜向下
    # 答案是:C
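As a quick sanity check of what `HD_transform` produces, it can be run on a synthetic image. The snippet below is illustrative and not part of the repository; the input size is arbitrary:

```python
from PIL import Image
from infer_wqx_vl import HD_transform

# A hypothetical 4:3 landscape input. With hd_num=4 the largest grid that
# satisfies scale * ceil(scale / ratio) <= 4 is scale=2, so the long side is
# resized to 2 * 560 = 1120 px and the height is then padded up to the next
# multiple of 560.
img = Image.new("RGB", (1024, 768), "white")
out = HD_transform(img, hd_num=4)
print(out.size)  # (1120, 1120), i.e. a 2 x 2 grid of 560-px tiles
```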
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
<div align="center">

# InternLM2-WQX

[![license](https://raw.githubusercontent.com/InternLM/InternLM/main/assets/license.svg)](./LICENSE)

[InternLM2-WQX-20B 🤗](https://huggingface.co/internlm/internlm2-wqx-20b) | [InternLM2-WQX-VL-20B 🤗](https://huggingface.co/internlm/internlm2-wqx-vl-20b)

</div>
# Introduction

InternLM2-WQX and InternLM2-WQX-VL are the WenQuXing (文曲星) series models released by the InternLM team on the eve of the 2024 Gaokao.

The Gaokao covers a wide range of subjects and question types, and because its papers are kept strictly confidential until the exam begins, it is regarded as one of the most authoritative examinations in China and a touchstone for assessing a candidate's overall ability. As a demanding, comprehensive test designed for humans, it is now widely used by researchers to gauge the intelligence of large models. On [GAOKAO-Eval](https://github.com/open-compass/GAOKAO-Eval), an evaluation suite built from the 2024 Gaokao, the InternLM2-WQX series achieved excellent results: its overall performance is comparable to GPT-4o and surpasses a range of open-source models from China and abroad, demonstrating the strength of the series.

Documentation on data preparation for the WenQuXing series will be released soon; stay tuned.


# Model Zoo

| Model                    | HuggingFace | ModelScope | Release Date |
| ------------------------ | ----------- | ---------- | ------------ |
| **InternLM2-WQX-20B**    | [🤗 internlm2-wqx-20b](https://huggingface.co/internlm/internlm2-wqx-20b) | [internlm2-wqx-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-wqx-20b/summary) | 2024-06-04 |
| **InternLM2-WQX-VL-20B** | [🤗 internlm2-wqx-vl-20b](https://huggingface.co/internlm/internlm2-wqx-vl-20b) | [internlm2-wqx-vl-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-wqx-vl-20b/summary) | 2024-06-04 |


## MD5 Check

### MD5 checksums of the LLM weight files

```
md5sum ./*
5209adfd6ef7d1724848ff0372362568 ./model-00001-of-00004.safetensors
e37ee2eafecfed543d10dca75998204e ./model-00002-of-00004.safetensors
ea3da8035b0c2a31c369dd463adf9b52 ./model-00003-of-00004.safetensors
f1ff218f801c69fd4c12c534b64e1b60 ./model-00004-of-00004.safetensors
```

### MD5 checksums of the MLLM weight files

```
md5sum ./*
158657dbae9bc369d67cf4bfbdfaaf71 ./pytorch_model-00001-of-00005.bin
c21db8ac1315c10df768f6c3ae3f2825 ./pytorch_model-00002-of-00005.bin
ebc4b0b70e8e9f1adc0b728558d650fb ./pytorch_model-00003-of-00005.bin
eaa393a66dc632d0a6f0f7d815c439bb ./pytorch_model-00004-of-00005.bin
7e6e3237d99a7e8bd7ca9ba10747bfdb ./pytorch_model-00005-of-00005.bin

./clip_l_560_pro7b/*
97b05f40ee9826eda467489eed65f85c ./clip_l_560_pro7b/pytorch_model.bin
```
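If you prefer a scripted check over eyeballing `md5sum` output, a small Python helper along the following lines works (illustrative, not part of the repository; shown for the LLM shards, with the hashes copied from the block above):

```python
import hashlib
from pathlib import Path

EXPECTED = {
    "model-00001-of-00004.safetensors": "5209adfd6ef7d1724848ff0372362568",
    "model-00002-of-00004.safetensors": "e37ee2eafecfed543d10dca75998204e",
    "model-00003-of-00004.safetensors": "ea3da8035b0c2a31c369dd463adf9b52",
    "model-00004-of-00004.safetensors": "f1ff218f801c69fd4c12c534b64e1b60",
}

def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    # Stream the file in 1 MiB chunks so multi-GB shards never sit in memory.
    h = hashlib.md5()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

for name, expected in EXPECTED.items():
    actual = md5_of(Path(name))
    print("OK      " if actual == expected else "MISMATCH", name, actual)
```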
# Quick Start

### Quick start with the **InternLM2-WQX-20B** language model

Inference with the transformers backend:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-wqx-20b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "internlm/internlm2-wqx-20b",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
).to(device).eval()

query = "已知圆柱和圆锥的底面半径相等,侧面积相等,且它们的高均为$ \\sqrt { 3 }$,则圆锥的体积为( ).\nA. $ 2 \\sqrt { 3 } \\pi$\nB. $ 3 \\sqrt { 3 } \\pi$\nC. $ 6 \\sqrt { 3 } \\pi$\nD. $ 9 \\sqrt { 3 } \\pi$"

inputs = tokenizer(query, return_tensors="pt")
inputs = inputs["input_ids"].to(device)

gen_kwargs = {"max_length": 1024, "do_sample": False}

outputs = model.generate(inputs, **gen_kwargs)
outputs = outputs[0].cpu().tolist()[len(inputs[0]):]

response = tokenizer.decode(outputs, skip_special_tokens=True)
print(response)
```

Inference with the vLLM backend:

```python
from vllm import LLM, SamplingParams

model_name = "internlm/internlm2-wqx-20b"
prompts = ["已知圆柱和圆锥的底面半径相等,侧面积相等,且它们的高均为$ \\sqrt { 3 }$,则圆锥的体积为( ).\nA. $ 2 \\sqrt { 3 } \\pi$\nB. $ 3 \\sqrt { 3 } \\pi$\nC. $ 6 \\sqrt { 3 } \\pi$\nD. $ 9 \\sqrt { 3 } \\pi$"]
sampling_params = SamplingParams(temperature=0.0, max_tokens=1024)

llm = LLM(
    model=model_name,
    trust_remote_code=True,
    enforce_eager=True,
)

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, \nGenerated text: {generated_text!r}")
```

### Web UI for the **InternLM2-WQX-20B** language model

Inference with the transformers backend:

```
python web_ui_wqx.py -m internlm/internlm2-wqx-20b
```

### Quick start with the **InternLM2-WQX-VL-20B** vision-language model

Inference with the transformers backend:

```python
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

from infer_wqx_vl import process_query_and_image, HD_transform

model_path = "internlm/internlm2-wqx-vl-20b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, torch_dtype=torch.bfloat16, trust_remote_code=True).cuda().eval()
model.cuda().half()
model.tokenizer = tokenizer

image_url = "https://ks-1302698447.cos.ap-shanghai.myqcloud.com/img/phymerge.png"
query = "体育课上两位同学在室内羽毛球场进行羽毛球比赛,羽毛球在空中上升的运动轨迹如图中虚线所示,考虑空气阻力,羽毛球加速度方向示意图可能正确的是(\u3000\u3000) \nA: \nB: \nC: \nD: "

response = requests.get(image_url)
image = Image.open(BytesIO(response.content))
embeds, im_mask = process_query_and_image(query, image, model, HD_transform)

outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask,
                         temperature=0.0, max_new_tokens=256, num_beams=1,
                         do_sample=False, repetition_penalty=1.0)
output_token = outputs[0]
output_text = model.tokenizer.decode(output_token, skip_special_tokens=True)
print(output_text)
# 斜向下
# 答案是:C
```

For questions whose answer options are themselves images, we merge the option images into a single image and add option markers so the language model can handle multi-image questions; a sketch of this merging step is shown below. The example above already uses a stitched image; for the full image preprocessing, see the multimodal tools in [GAOKAO-Eval](https://github.com/open-compass/GAOKAO-Eval).
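As a rough illustration of that merging step (not the repository's actual preprocessing; the function name, per-option filenames, and label layout are all assumptions), the option images could be stitched side by side on a white canvas with a text label above each one:

```python
from PIL import Image, ImageDraw

def merge_options(paths, label_height=40):
    """Stitch option images A-D horizontally and draw a label above each
    one. Hypothetical helper, shown only to illustrate the idea."""
    images = [Image.open(p).convert("RGB") for p in paths]
    height = max(im.height for im in images) + label_height
    width = sum(im.width for im in images)
    canvas = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(canvas)
    x = 0
    for label, im in zip("ABCD", images):
        draw.text((x + im.width // 2, 10), label, fill="black")
        canvas.paste(im, (x, label_height))
        x += im.width
    return canvas

# Hypothetical per-option files, e.g. cropped from the exam paper.
merged = merge_options(["opt_a.png", "opt_b.png", "opt_c.png", "opt_d.png"])
merged.save("phymerge.png")
```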
"--vl_checkpoint_path", type=str, default=DEFAULT_VL_CKPT_PATH, 25 | help="Checkpoint name or path, default to %(default)r") 26 | parser.add_argument("--share", action="store_true", default=False, 27 | help="Create a publicly shareable link for the interface.") 28 | parser.add_argument("--inbrowser", action="store_true", default=False, 29 | help="Automatically launch the interface in a new tab on the default browser.") 30 | parser.add_argument("--server-port", type=int, default=10086, 31 | help="Demo server port.") 32 | parser.add_argument("--server-name", type=str, default="0.0.0.0", 33 | help="Demo server name.") 34 | parser.add_argument("--cache_dir", default="data/img_cache", type=str, 35 | help="Directory to save image cache.") 36 | 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def load_vl_model_tokenizer(args): 42 | tokenizer = AutoTokenizer.from_pretrained( 43 | args.vl_checkpoint_path, trust_remote_code=True, resume_download=True, 44 | ) 45 | 46 | model = AutoModel.from_pretrained( 47 | args.vl_checkpoint_path, torch_dtype=torch.bfloat16, trust_remote_code=True, resume_download=True 48 | ).cuda().eval() 49 | model.cuda().half() 50 | model.tokenizer = tokenizer 51 | 52 | return model, tokenizer 53 | 54 | 55 | def load_model_tokenizer(args): 56 | tokenizer = AutoTokenizer.from_pretrained( 57 | args.checkpoint_path, trust_remote_code=True, resume_download=True, 58 | ) 59 | 60 | model = AutoModelForCausalLM.from_pretrained( 61 | args.checkpoint_path, 62 | trust_remote_code=True, 63 | resume_download=True, 64 | ).eval() 65 | 66 | return model, tokenizer 67 | 68 | 69 | def postprocess(self, y): 70 | """Override Chatbot.postprocess""" 71 | if y is None: 72 | return [] 73 | for i, (message, response) in enumerate(y): 74 | y[i] = ( 75 | None if message is None else mdtex2html.convert((message)), 76 | None if response is None else mdtex2html.convert(response), 77 | ) 78 | return y 79 | 80 | 81 | gr.Chatbot.postprocess = postprocess 82 | 83 | 84 | def parse_text(text): 85 | """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/""" 86 | lines = text.split("\n") 87 | lines = [line for line in lines if line != ""] 88 | count = 0 89 | for i, line in enumerate(lines): 90 | if "```" in line: 91 | count += 1 92 | items = line.split('`') 93 | if count % 2 == 1: 94 | lines[i] = f'
'
 95 |             else:
 96 |                 lines[i] = f'
' 97 | else: 98 | if i > 0: 99 | if count % 2 == 1: 100 | line = line.replace("`", "\`") 101 | line = line.replace("<", "<") 102 | line = line.replace(">", ">") 103 | line = line.replace(" ", " ") 104 | line = line.replace("*", "*") 105 | line = line.replace("_", "_") 106 | line = line.replace("-", "-") 107 | line = line.replace(".", ".") 108 | line = line.replace("!", "!") 109 | line = line.replace("(", "(") 110 | line = line.replace(")", ")") 111 | line = line.replace("$", "$") 112 | lines[i] = "
" + line 113 | text = "".join(lines) 114 | return text 115 | 116 | 117 | def gc(): 118 | import gc 119 | gc.collect() 120 | if torch.cuda.is_available(): 121 | torch.cuda.empty_cache() 122 | 123 | def launch_demo(args, model): 124 | 125 | def predict(input, image_path, chatbot): 126 | 127 | if image_path is not None: 128 | input += '' 129 | if input == '' and image_path is None: 130 | return [(input, "文本与图片为空!请重试。")] 131 | chatbot.append((parse_text(input), "")) 132 | 133 | query = input 134 | 135 | if os.path.exists(image_path): 136 | image = Image.open(image_path) 137 | else: 138 | response = requests.get(image_path) 139 | image = Image.open(BytesIO(response.content)) 140 | with torch.cuda.amp.autocast(): 141 | embeds, im_mask = process_query_and_image(query, image, model, HD_transform) 142 | 143 | outputs = model.generate(inputs_embeds=embeds, im_mask=im_mask, 144 | temperature=0.0, max_new_tokens=256, num_beams=1, 145 | do_sample=False, repetition_penalty=1.0) 146 | output_token = outputs[0] 147 | output_text = model.tokenizer.decode(output_token, add_special_tokens=False) 148 | print(output_text,chatbot) 149 | output_text=output_text.replace("", "").replace("", "") 150 | 151 | chatbot[-1] = (parse_text(query), parse_text(output_text)) 152 | return chatbot 153 | 154 | def stop_generate(): 155 | global stop_gen 156 | stop_gen = True 157 | 158 | def reset_user_input(): 159 | return gr.update(value='') 160 | 161 | def reset_state(): 162 | stop_generate() 163 | gc() 164 | return None, [] 165 | 166 | examples = [ 167 | [r"体育课上两位同学在室内羽毛球场进行羽毛球比赛,羽毛球在空中上升的运动轨迹如图中虚线所示,考虑空气阻力,羽毛球加速度方向示意图可能正确的是(\u3000\u3000) \nA: \nB: \nC: \nD: ,对图片进行描述然后再回答", "https://ks-1302698447.cos.ap-shanghai.myqcloud.com/img/phymerge.png"] 168 | ] 169 | 170 | with gr.Blocks() as demo: 171 | gr.Markdown("""\ 172 |
        gr.Markdown("""<center><font size=8>InternLM2-WQX-VL</font></center>""")
        gr.Markdown(
            """\
<center><font size=3>本WebUI基于InternLM2-WQX-VL打造,是InternLM团队推出的文曲星系列模型。</font></center>""")
        gr.Markdown("""\
<center><font size=4>
InternLM2-WQX-20b <a href="https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-wqx-20b/summary">🤖</a> |
<a href="https://huggingface.co/internlm/internlm2-wqx-20b">🤗</a>&nbsp |
InternLM2-WQX-VL-20b <a href="https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-wqx-vl-20b/summary">🤖</a> |
<a href="https://huggingface.co/internlm/internlm2-wqx-vl-20b">🤗</a>&nbsp |
&nbsp<a href="https://github.com/InternLM/InternLM-WQX">💻 Github</a></font></center>""")

        chatbot = gr.Chatbot()
        with gr.Row():
            user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
            image_path = gr.Image(type="filepath", label="Image Prompt", value=None)
        with gr.Row():
            submit_btn = gr.Button("Submit(发送)", variant="primary")

        submit_btn.click(predict, [user_input, image_path, chatbot], [chatbot],
                         show_progress=True)
        submit_btn.click(reset_user_input, [], [user_input])

        image_path.clear(reset_state, outputs=[image_path, chatbot], show_progress=True)

        gr.Examples(examples=examples, inputs=[user_input, image_path])

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    args = _get_args()

    model_vl, _ = load_vl_model_tokenizer(args)

    launch_demo(args, model_vl)


if __name__ == '__main__':
    main()


--------------------------------------------------------------------------------
/web_ui_wqx.py:
--------------------------------------------------------------------------------
import queue
import threading
from argparse import ArgumentParser

import gradio as gr
import mdtex2html
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

try:
    from transformers.generation.streamers import BaseStreamer
except:  # noqa # pylint: disable=bare-except
    BaseStreamer = None

DEFAULT_MODEL_PATH = 'internlm/internlm2-wqx-20b'
stop_gen = False


def get_args():
    parser = ArgumentParser()
    parser.add_argument("-m", "--model_path", type=str, default=DEFAULT_MODEL_PATH,
                        help="Model name or path, default to %(default)r")
    parser.add_argument("--cpu_only", action="store_true", help="Run demo with CPU only")

    parser.add_argument("--share", action="store_true", default=False,
                        help="Create a publicly shareable link for the interface.")
    parser.add_argument("--inbrowser", action="store_true", default=False,
                        help="Automatically launch the interface in a new tab on the default browser.")
    parser.add_argument("--server_port", type=int, default=7860,
                        help="Demo server port.")
    parser.add_argument("--server_name", type=str, default="127.0.0.1",
                        help="Demo server name.")

    args = parser.parse_args()
    return args


def load_model_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path, trust_remote_code=True, resume_download=True,
    )

    if args.cpu_only:
        device_map = "cpu"
    else:
        device_map = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        device_map=device_map,
        trust_remote_code=True,
        resume_download=True,
    ).eval()

    return model, tokenizer


def postprocess(self, y):
    """Override Chatbot.postprocess to render messages with mdtex2html."""
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert(message),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess


def parse_text(text):
    """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
    lines = text.split("\n")
    lines = [line for line in lines if line != ""]
    count = 0
    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                lines[i] = f'<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    line = line.replace("`", "\`")
                    line = line.replace("<", "&lt;")
                    line = line.replace(">", "&gt;")
                    line = line.replace(" ", "&nbsp;")
                    line = line.replace("*", "&ast;")
                    line = line.replace("_", "&lowbar;")
                    line = line.replace("-", "&#45;")
                    line = line.replace(".", "&#46;")
                    line = line.replace("!", "&#33;")
                    line = line.replace("(", "&#40;")
                    line = line.replace(")", "&#41;")
                    line = line.replace("$", "&#36;")
                lines[i] = "<br>" + line
    text = "".join(lines)
    return text


def gc():
    import gc
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def stream_chat(
    model,
    tokenizer,
    query: str,
    max_new_tokens: int = 2048,
    do_sample: bool = True,
    temperature: float = 0.8,
    top_p: float = 0.8,
    **kwargs,
):

    if BaseStreamer is None:
        raise ModuleNotFoundError(
            "The version of `transformers` is too low. Please make sure "
            "that you have installed `transformers>=4.28.0`."
        )
    response_queue = queue.Queue(maxsize=20)

    class ChatStreamer(BaseStreamer):
        # Decodes tokens one by one and pushes the growing partial response
        # into a queue that the Gradio generator below consumes.
        def __init__(self, tokenizer) -> None:
            super().__init__()
            self.tokenizer = tokenizer
            self.queue = response_queue
            self.query = query
            self.response = ""
            self.received_inputs = False
            self.queue.put((self.query, self.response))

        def put(self, value):
            if len(value.shape) > 1 and value.shape[0] > 1:
                raise ValueError("ChatStreamer only supports batch size 1")
            elif len(value.shape) > 1:
                value = value[0]

            if not self.received_inputs:
                # The first received value is input_ids, ignore here
                self.received_inputs = True
                return

            token = self.tokenizer.decode([value[-1]], skip_special_tokens=True)
            if token.strip() != "<|im_end|>":
                self.response = self.response + token
                response = (self.query, self.response)
                self.queue.put(response)

        def end(self):
            self.queue.put(None)

    def stream_producer():
        # Runs model.generate in a worker thread; the streamer feeds the queue.
        inputs = tokenizer([query], return_tensors="pt")
        inputs = {k: v.to(next(model.parameters()).device) for k, v in inputs.items() if torch.is_tensor(v)}
        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids(["<|im_end|>"])[0]]
        outputs = model.generate(
            **inputs,
            streamer=ChatStreamer(tokenizer=tokenizer),
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs, skip_special_tokens=True)
        response = response.split("<|im_end|>")[0]
        return response, []

    def consumer():
        producer = threading.Thread(target=stream_producer)
        producer.start()
        while True:
            res = response_queue.get()
            if res is None:
                return
            yield res

    return consumer()


def launch_demo(args, model, tokenizer):

    def predict(input, chatbot):
        global stop_gen
        stop_gen = False
        chatbot.append((parse_text(input), ""))
        for query, response in stream_chat(model, tokenizer, input):
            if stop_gen:
                chatbot.clear()
                return chatbot
            chatbot[-1] = (parse_text(query), parse_text(response))
            yield chatbot

    def stop_generate():
        global stop_gen
        stop_gen = True

    def reset_user_input():
        return gr.update(value='')

    def reset_state():
        stop_generate()
        gc()
        return []

    with gr.Blocks() as demo:
        gr.Markdown("""<center><font size=8>InternLM2-WQX</font></center>""")
        gr.Markdown(
            """\
<center><font size=3>本WebUI基于InternLM2-WQX打造,是InternLM团队推出的文曲星系列模型。</font></center>""")
        gr.Markdown("""\
<center><font size=4>
InternLM2-WQX-20b <a href="https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-wqx-20b/summary">🤖</a> |
<a href="https://huggingface.co/internlm/internlm2-wqx-20b">🤗</a>&nbsp |
InternLM2-WQX-VL-20b <a href="https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm2-wqx-vl-20b/summary">🤖</a> |
<a href="https://huggingface.co/internlm/internlm2-wqx-vl-20b">🤗</a>&nbsp |
&nbsp<a href="https://github.com/InternLM/InternLM-WQX">💻 Github</a></font></center>""")

        chatbot = gr.Chatbot()
        with gr.Row():
            user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10)
        with gr.Row():
            submit_btn = gr.Button("Submit(发送)", variant="primary")
            clear_btn = gr.Button("Stop(停止生成)", variant="stop")

        submit_btn.click(predict, [user_input, chatbot], [chatbot],
                         show_progress=True)
        submit_btn.click(reset_user_input, [], [user_input])

        clear_btn.click(reset_state, outputs=[chatbot], show_progress=True)

    demo.queue().launch(
        share=args.share,
        inbrowser=args.inbrowser,
        server_port=args.server_port,
        server_name=args.server_name,
    )


def main():
    args = get_args()

    model, tokenizer = load_model_tokenizer(args)

    launch_demo(args, model, tokenizer)


if __name__ == '__main__':
    main()


--------------------------------------------------------------------------------