├── README.md
├── ShareGPT4V_colab.ipynb
└── ShareGPT4V_8bit_colab.ipynb
/README.md:
--------------------------------------------------------------------------------
1 | 🐣 Please follow me for new updates https://twitter.com/camenduru
2 | 🔥 Please join our discord server https://discord.gg/k5BwmmvJJU
3 | 🥳 Please join my patreon community https://patreon.com/camenduru
4 |
5 | # 🚦 WIP 🚦
6 |
7 | ## 🦒 Colab
8 |
9 | | Colab | Info |
10 | | --- | --- |
11 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_colab.ipynb) | ShareGPT4V_colab (7B 16bit Pro Colab 😐) |
12 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_8bit_colab.ipynb) | ShareGPT4V_8bit_colab (7B 8bit Free Colab T4) |
13 |
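The two notebooks are identical except for the quantization flag passed to LLaVA's `load_pretrained_model` (its fourth positional argument is `load_8bit`). A minimal sketch of the difference, using the `4bit/ShareGPT4V-7B-5GB` checkpoint the notebooks load:

```python
from llava.model.builder import load_pretrained_model

# 16-bit variant (Pro Colab): load_8bit=False, load_4bit=False
tokenizer, model, image_processor, context_len = load_pretrained_model(
    "4bit/ShareGPT4V-7B-5GB", None, "llava-v1.5-7b", False, False)

# 8-bit variant (free T4): load_8bit=True quantizes the weights via bitsandbytes
tokenizer, model, image_processor, context_len = load_pretrained_model(
    "4bit/ShareGPT4V-7B-5GB", None, "llava-v1.5-7b", True, False)
```
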
14 | ## 🦆 Kaggle
15 |
16 | | Kaggle | Info |
17 | | --- | --- |
18 | | [![Open In Kaggle](https://camenduru.github.io/the-badge/badge/kaggle.svg)](https://kaggle.com/camenduru/sharegpt4v) | sharegpt4v_kaggle (7B 16bit Free Kaggle T4) |
19 |
20 | ## Main Repo
21 | https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V
22 | https://github.com/haotian-liu/LLaVA
23 |
24 | ## Paper
25 | https://arxiv.org/abs/2311.12793
26 |
27 | ## Page
28 | https://sharegpt4v.github.io/
29 |
30 | ## Output
31 | 
32 |
33 |
--------------------------------------------------------------------------------
/ShareGPT4V_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_colab.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!git clone -b dev https://github.com/camenduru/InternLM-XComposer\n",
22 | "!pip install -q https://github.com/camenduru/wheels/releases/download/colab/llava-ShareGPT4V-1.1.3-py3-none-any.whl gradio\n",
23 | "%cd /content/InternLM-XComposer/projects/ShareGPT4V\n",
24 | "\n",
25 | "import hashlib\n",
26 | "import json\n",
27 | "import os\n",
28 | "import time\n",
29 | "from threading import Thread\n",
30 | "\n",
31 | "import gradio as gr\n",
32 | "import torch\n",
33 | "from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)\n",
34 | "from llava.conversation import (SeparatorStyle, conv_templates, default_conversation)\n",
35 | "from llava.mm_utils import (KeywordsStoppingCriteria, load_image_from_base64, process_images, tokenizer_image_token)\n",
36 | "from llava.model.builder import load_pretrained_model\n",
37 | "from transformers import TextIteratorStreamer\n",
38 | "\n",
39 | "print(gr.__version__)\n",
40 | "\n",
41 | "block_css = \"\"\"\n",
42 | "\n",
43 | "#buttons button {\n",
44 | " min-width: min(120px,100%);\n",
45 | "}\n",
46 | "\"\"\"\n",
47 | "title_markdown = (\"\"\"\n",
48 | "# 🐬 ShareGPT4V: Improving Large Multi-modal Models with Better Captions\n",
49 | "### 🔊 Notice: The demo of Share-Captioner will soon be supported. Stay tune for updates!\n",
50 | "[[Project Page](https://sharegpt4v.github.io/)] [[Code](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V)] | 📚 [[Paper](https://arxiv.org/pdf/2311.12793.pdf)]\n",
51 | "\"\"\")\n",
52 | "tos_markdown = (\"\"\"\n",
53 | "### Terms of use\n",
54 | "By using this service, users are required to agree to the following terms:\n",
55 | "The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\n",
56 | "For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.\n",
57 | "\"\"\")\n",
58 | "learn_more_markdown = (\"\"\"\n",
59 | "### License\n",
60 | "The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.\n",
61 | "\"\"\")\n",
62 | "ack_markdown = (\"\"\"\n",
63 | "### Acknowledgement\n",
64 | "The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are very grateful to LLaVA for their open source contributions to the community!\n",
65 | "\"\"\")\n",
66 | "\n",
67 | "\n",
68 | "def regenerate(state, image_process_mode):\n",
69 | " state.messages[-1][-1] = None\n",
70 | " prev_human_msg = state.messages[-2]\n",
71 | " if type(prev_human_msg[1]) in (tuple, list):\n",
72 | " prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)\n",
73 | " state.skip_next = False\n",
74 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
75 | "\n",
76 | "\n",
77 | "def clear_history():\n",
78 | " state = default_conversation.copy()\n",
79 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
80 | "\n",
81 | "\n",
82 | "def add_text(state, text, image, image_process_mode):\n",
83 | " if len(text) <= 0 and image is None:\n",
84 | " state.skip_next = True\n",
85 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
86 | "\n",
87 | " text = text[:1536] # Hard cut-off\n",
88 | " if image is not None:\n",
89 | " text = text[:1200] # Hard cut-off for images\n",
90 | " if '' not in text:\n",
91 | " # text = '' + text\n",
92 | " text = text + '\\n'\n",
93 | " text = (text, image, image_process_mode)\n",
94 | " if len(state.get_images(return_pil=True)) > 0:\n",
95 | " state = default_conversation.copy()\n",
96 | " state.append_message(state.roles[0], text)\n",
97 | " state.append_message(state.roles[1], None)\n",
98 | " state.skip_next = False\n",
99 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
100 | "\n",
101 | "\n",
102 | "def load_demo():\n",
103 | " state = default_conversation.copy()\n",
104 | " return state\n",
105 | "\n",
106 | "@torch.inference_mode()\n",
107 | "def get_response(params):\n",
108 | " prompt = params[\"prompt\"]\n",
109 | " ori_prompt = prompt\n",
110 | " images = params.get(\"images\", None)\n",
111 | " num_image_tokens = 0\n",
112 | " if images is not None and len(images) > 0:\n",
113 | " if len(images) > 0:\n",
114 | " if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):\n",
115 | " raise ValueError(\n",
116 | " \"Number of images does not match number of tokens in prompt\")\n",
117 | "\n",
118 | " images = [load_image_from_base64(image) for image in images]\n",
119 | " images = process_images(images, image_processor, model.config)\n",
120 | "\n",
121 | " if type(images) is list:\n",
122 | " images = [image.to(model.device, dtype=torch.float16)\n",
123 | " for image in images]\n",
124 | " else:\n",
125 | " images = images.to(model.device, dtype=torch.float16)\n",
126 | "\n",
127 | " replace_token = DEFAULT_IMAGE_TOKEN\n",
128 | " if getattr(model.config, 'mm_use_im_start_end', False):\n",
129 | " replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n",
130 | " prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)\n",
131 | "\n",
132 | " num_image_tokens = prompt.count(\n",
133 | " replace_token) * model.get_vision_tower().num_patches\n",
134 | " else:\n",
135 | " images = None\n",
136 | " image_args = {\"images\": images}\n",
137 | " else:\n",
138 | " images = None\n",
139 | " image_args = {}\n",
140 | "\n",
141 | " temperature = float(params.get(\"temperature\", 1.0))\n",
142 | " top_p = float(params.get(\"top_p\", 1.0))\n",
143 | " max_context_length = getattr(\n",
144 | " model.config, 'max_position_embeddings', 2048)\n",
145 | " max_new_tokens = min(int(params.get(\"max_new_tokens\", 256)), 1024)\n",
146 | " stop_str = params.get(\"stop\", None)\n",
147 | " do_sample = True if temperature > 0.001 else False\n",
148 | "\n",
149 | " input_ids = tokenizer_image_token(\n",
150 | " prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)\n",
151 | " keywords = [stop_str]\n",
152 | " stopping_criteria = KeywordsStoppingCriteria(\n",
153 | " keywords, tokenizer, input_ids)\n",
154 | " streamer = TextIteratorStreamer(\n",
155 | " tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)\n",
156 | "\n",
157 | " max_new_tokens = min(max_new_tokens, max_context_length -\n",
158 | " input_ids.shape[-1] - num_image_tokens)\n",
159 | "\n",
160 | " if max_new_tokens < 1:\n",
161 | " yield json.dumps({\"text\": ori_prompt + \"Exceeds max token length. Please start a new conversation, thanks.\", \"error_code\": 0}).encode() + b\"\\0\"\n",
162 | " return\n",
163 | "\n",
164 | " # local inference\n",
165 | " thread = Thread(target=model.generate, kwargs=dict(\n",
166 | " inputs=input_ids,\n",
167 | " do_sample=do_sample,\n",
168 | " temperature=temperature,\n",
169 | " top_p=top_p,\n",
170 | " max_new_tokens=max_new_tokens,\n",
171 | " streamer=streamer,\n",
172 | " stopping_criteria=[stopping_criteria],\n",
173 | " use_cache=True,\n",
174 | " **image_args\n",
175 | " ))\n",
176 | " thread.start()\n",
177 | "\n",
178 | " generated_text = ori_prompt\n",
179 | " for new_text in streamer:\n",
180 | " generated_text += new_text\n",
181 | " if generated_text.endswith(stop_str):\n",
182 | " generated_text = generated_text[:-len(stop_str)]\n",
183 | " yield json.dumps({\"text\": generated_text, \"error_code\": 0}).encode()\n",
184 | "\n",
185 | "\n",
186 | "def http_bot(state, temperature, top_p, max_new_tokens):\n",
187 | " if state.skip_next:\n",
188 | " # This generate call is skipped due to invalid inputs\n",
189 | " yield (state, state.to_gradio_chatbot())\n",
190 | " return\n",
191 | "\n",
192 | " if len(state.messages) == state.offset + 2:\n",
193 | " # First round of conversation\n",
194 | " if \"llava\" in model_name.lower():\n",
195 | " if 'llama-2' in model_name.lower():\n",
196 | " template_name = \"llava_llama_2\"\n",
197 | " elif \"v1\" in model_name.lower():\n",
198 | " if 'mmtag' in model_name.lower():\n",
199 | " template_name = \"v1_mmtag\"\n",
200 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n",
201 | " template_name = \"v1_mmtag\"\n",
202 | " else:\n",
203 | " template_name = \"llava_v1\"\n",
204 | " elif \"mpt\" in model_name.lower():\n",
205 | " template_name = \"mpt\"\n",
206 | " else:\n",
207 | " if 'mmtag' in model_name.lower():\n",
208 | " template_name = \"v0_mmtag\"\n",
209 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n",
210 | " template_name = \"v0_mmtag\"\n",
211 | " else:\n",
212 | " template_name = \"llava_v0\"\n",
213 | " elif \"mpt\" in model_name:\n",
214 | " template_name = \"mpt_text\"\n",
215 | " elif \"llama-2\" in model_name:\n",
216 | " template_name = \"llama_2\"\n",
217 | " else:\n",
218 | " template_name = \"vicuna_v1\"\n",
219 | " new_state = conv_templates[template_name].copy()\n",
220 | " new_state.append_message(new_state.roles[0], state.messages[-2][1])\n",
221 | " new_state.append_message(new_state.roles[1], None)\n",
222 | " state = new_state\n",
223 | "\n",
224 | " # Construct prompt\n",
225 | " prompt = state.get_prompt()\n",
226 | "\n",
227 | " all_images = state.get_images(return_pil=True)\n",
228 | " all_image_hash = [hashlib.md5(image.tobytes()).hexdigest()\n",
229 | " for image in all_images]\n",
230 | "\n",
231 | " # Make requests\n",
232 | " pload = {\n",
233 | " \"model\": model_name,\n",
234 | " \"prompt\": prompt,\n",
235 | " \"temperature\": float(temperature),\n",
236 | " \"top_p\": float(top_p),\n",
237 | " \"max_new_tokens\": min(int(max_new_tokens), 1536),\n",
238 | " \"stop\": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,\n",
239 | " \"images\": f'List of {len(state.get_images())} images: {all_image_hash}',\n",
240 | " }\n",
241 | "\n",
242 | " pload['images'] = state.get_images()\n",
243 | "\n",
244 | " state.messages[-1][-1] = \"▌\"\n",
245 | " yield (state, state.to_gradio_chatbot())\n",
246 | "\n",
247 | " # for stream\n",
248 | " output = get_response(pload)\n",
249 | " for chunk in output:\n",
250 | " if chunk:\n",
251 | " data = json.loads(chunk.decode())\n",
252 | " if data[\"error_code\"] == 0:\n",
253 | " output = data[\"text\"][len(prompt):].strip()\n",
254 | " state.messages[-1][-1] = output + \"▌\"\n",
255 | " yield (state, state.to_gradio_chatbot())\n",
256 | " else:\n",
257 | " output = data[\"text\"] + \\\n",
258 | " f\" (error_code: {data['error_code']})\"\n",
259 | " state.messages[-1][-1] = output\n",
260 | " yield (state, state.to_gradio_chatbot())\n",
261 | " return\n",
262 | " time.sleep(0.03)\n",
263 | "\n",
264 | " state.messages[-1][-1] = state.messages[-1][-1][:-1]\n",
265 | " yield (state, state.to_gradio_chatbot())\n",
266 | "\n",
267 | "\n",
268 | "def build_demo():\n",
269 | " textbox = gr.Textbox(\n",
270 | " show_label=False, placeholder=\"Enter text and press ENTER\", container=False)\n",
271 | " with gr.Blocks(title=\"ShareGPT4V\", theme=gr.themes.Default(), css=block_css) as demo:\n",
272 | " state = gr.State()\n",
273 | " gr.Markdown(title_markdown)\n",
274 | "\n",
275 | " with gr.Row():\n",
276 | " with gr.Column(scale=5):\n",
277 | " with gr.Row(elem_id=\"Model ID\"):\n",
278 | " gr.Dropdown(\n",
279 | " choices=['ShareGPT4V-7B'],\n",
280 | " value='ShareGPT4V-7B',\n",
281 | " interactive=True,\n",
282 | " label='Model ID',\n",
283 | " container=False)\n",
284 | " imagebox = gr.Image(type=\"pil\")\n",
285 | " image_process_mode = gr.Radio(\n",
286 | " [\"Crop\", \"Resize\", \"Pad\", \"Default\"],\n",
287 | " value=\"Default\",\n",
288 | " label=\"Preprocess for non-square image\", visible=False)\n",
289 | "\n",
290 | " cur_dir = \"/content/InternLM-XComposer/projects/ShareGPT4V\"\n",
291 | " gr.Examples(examples=[\n",
292 | " [f\"{cur_dir}/examples/breaking_bad.png\",\n",
293 | " \"What is the most common catchphrase of the character on the right?\"],\n",
294 | " [f\"{cur_dir}/examples/photo.png\",\n",
295 | " \"From a photography perspective, analyze what makes this picture beautiful?\"],\n",
296 | " ], inputs=[imagebox, textbox])\n",
297 | "\n",
298 | " with gr.Accordion(\"Parameters\", open=False) as _:\n",
299 | " temperature = gr.Slider(\n",
300 | " minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label=\"Temperature\",)\n",
301 | " top_p = gr.Slider(\n",
302 | " minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label=\"Top P\",)\n",
303 | " max_output_tokens = gr.Slider(\n",
304 | " minimum=0, maximum=1024, value=512, step=64, interactive=True, label=\"Max output tokens\",)\n",
305 | "\n",
306 | " with gr.Column(scale=8):\n",
307 | " chatbot = gr.Chatbot(\n",
308 | " elem_id=\"chatbot\", label=\"ShareGPT4V Chatbot\", height=550)\n",
309 | " with gr.Row():\n",
310 | " with gr.Column(scale=8):\n",
311 | " textbox.render()\n",
312 | " with gr.Column(scale=1, min_width=50):\n",
313 | " submit_btn = gr.Button(value=\"Send\", variant=\"primary\")\n",
314 | " with gr.Row(elem_id=\"buttons\") as _:\n",
315 | " regenerate_btn = gr.Button(\n",
316 | " value=\"🔄 Regenerate\", interactive=True)\n",
317 | " clear_btn = gr.Button(value=\"🗑️ Clear\", interactive=True)\n",
318 | "\n",
319 | " gr.Markdown(tos_markdown)\n",
320 | " gr.Markdown(learn_more_markdown)\n",
321 | " gr.Markdown(ack_markdown)\n",
322 | "\n",
323 | " regenerate_btn.click(\n",
324 | " regenerate,\n",
325 | " [state, image_process_mode],\n",
326 | " [state, chatbot, textbox, imagebox],\n",
327 | " queue=False\n",
328 | " ).then(\n",
329 | " http_bot,\n",
330 | " [state, temperature, top_p, max_output_tokens],\n",
331 | " [state, chatbot]\n",
332 | " )\n",
333 | "\n",
334 | " clear_btn.click(\n",
335 | " clear_history,\n",
336 | " None,\n",
337 | " [state, chatbot, textbox, imagebox],\n",
338 | " queue=False\n",
339 | " )\n",
340 | "\n",
341 | " textbox.submit(\n",
342 | " add_text,\n",
343 | " [state, textbox, imagebox, image_process_mode],\n",
344 | " [state, chatbot, textbox, imagebox],\n",
345 | " queue=False\n",
346 | " ).then(\n",
347 | " http_bot,\n",
348 | " [state, temperature, top_p, max_output_tokens],\n",
349 | " [state, chatbot]\n",
350 | " )\n",
351 | "\n",
352 | " submit_btn.click(\n",
353 | " add_text,\n",
354 | " [state, textbox, imagebox, image_process_mode],\n",
355 | " [state, chatbot, textbox, imagebox],\n",
356 | " queue=False\n",
357 | " ).then(\n",
358 | " http_bot,\n",
359 | " [state, temperature, top_p, max_output_tokens],\n",
360 | " [state, chatbot]\n",
361 | " )\n",
362 | "\n",
363 | " demo.load(\n",
364 | " load_demo,\n",
365 | " None,\n",
366 | " [state],\n",
367 | " queue=False\n",
368 | " )\n",
369 | " return demo\n",
370 | "\n",
371 | "model_name = \"llava-v1.5-7b\"\n",
372 | "tokenizer, model, image_processor, context_len = load_pretrained_model(\"4bit/ShareGPT4V-7B-5GB\", None, \"llava-v1.5-7b\", False, False)\n",
373 | "demo = build_demo()\n",
374 | "demo.queue()\n",
375 | "demo.launch(share=True, inline=False, debug=True)"
376 | ]
377 | }
378 | ],
379 | "metadata": {
380 | "accelerator": "GPU",
381 | "colab": {
382 | "gpuType": "T4",
383 | "provenance": []
384 | },
385 | "kernelspec": {
386 | "display_name": "Python 3",
387 | "name": "python3"
388 | },
389 | "language_info": {
390 | "name": "python"
391 | }
392 | },
393 | "nbformat": 4,
394 | "nbformat_minor": 0
395 | }
396 |
--------------------------------------------------------------------------------
/ShareGPT4V_8bit_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {
6 | "id": "view-in-github"
7 | },
8 | "source": [
9 | "[](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_8bit_colab.ipynb)"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": null,
15 | "metadata": {
16 | "id": "VjYy0F2gZIPR"
17 | },
18 | "outputs": [],
19 | "source": [
20 | "%cd /content\n",
21 | "!git clone -b dev https://github.com/camenduru/InternLM-XComposer\n",
22 | "!pip install -q https://github.com/camenduru/wheels/releases/download/colab/llava-ShareGPT4V-1.1.3-py3-none-any.whl gradio\n",
23 | "%cd /content/InternLM-XComposer/projects/ShareGPT4V\n",
24 | "\n",
25 | "import hashlib\n",
26 | "import json\n",
27 | "import os\n",
28 | "import time\n",
29 | "from threading import Thread\n",
30 | "\n",
31 | "import gradio as gr\n",
32 | "import torch\n",
33 | "from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)\n",
34 | "from llava.conversation import (SeparatorStyle, conv_templates, default_conversation)\n",
35 | "from llava.mm_utils import (KeywordsStoppingCriteria, load_image_from_base64, process_images, tokenizer_image_token)\n",
36 | "from llava.model.builder import load_pretrained_model\n",
37 | "from transformers import TextIteratorStreamer\n",
38 | "\n",
39 | "print(gr.__version__)\n",
40 | "\n",
41 | "block_css = \"\"\"\n",
42 | "\n",
43 | "#buttons button {\n",
44 | " min-width: min(120px,100%);\n",
45 | "}\n",
46 | "\"\"\"\n",
47 | "title_markdown = (\"\"\"\n",
48 | "# 🐬 ShareGPT4V: Improving Large Multi-modal Models with Better Captions\n",
49 | "### 🔊 Notice: The demo of Share-Captioner will soon be supported. Stay tune for updates!\n",
50 | "[[Project Page](https://sharegpt4v.github.io/)] [[Code](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V)] | 📚 [[Paper](https://arxiv.org/pdf/2311.12793.pdf)]\n",
51 | "\"\"\")\n",
52 | "tos_markdown = (\"\"\"\n",
53 | "### Terms of use\n",
54 | "By using this service, users are required to agree to the following terms:\n",
55 | "The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\n",
56 | "For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.\n",
57 | "\"\"\")\n",
58 | "learn_more_markdown = (\"\"\"\n",
59 | "### License\n",
60 | "The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.\n",
61 | "\"\"\")\n",
62 | "ack_markdown = (\"\"\"\n",
63 | "### Acknowledgement\n",
64 | "The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are very grateful to LLaVA for their open source contributions to the community!\n",
65 | "\"\"\")\n",
66 | "\n",
67 | "\n",
68 | "def regenerate(state, image_process_mode):\n",
69 | " state.messages[-1][-1] = None\n",
70 | " prev_human_msg = state.messages[-2]\n",
71 | " if type(prev_human_msg[1]) in (tuple, list):\n",
72 | " prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)\n",
73 | " state.skip_next = False\n",
74 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
75 | "\n",
76 | "\n",
77 | "def clear_history():\n",
78 | " state = default_conversation.copy()\n",
79 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
80 | "\n",
81 | "\n",
82 | "def add_text(state, text, image, image_process_mode):\n",
83 | " if len(text) <= 0 and image is None:\n",
84 | " state.skip_next = True\n",
85 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
86 | "\n",
87 | " text = text[:1536] # Hard cut-off\n",
88 | " if image is not None:\n",
89 | " text = text[:1200] # Hard cut-off for images\n",
90 | " if '' not in text:\n",
91 | " # text = '' + text\n",
92 | " text = text + '\\n'\n",
93 | " text = (text, image, image_process_mode)\n",
94 | " if len(state.get_images(return_pil=True)) > 0:\n",
95 | " state = default_conversation.copy()\n",
96 | " state.append_message(state.roles[0], text)\n",
97 | " state.append_message(state.roles[1], None)\n",
98 | " state.skip_next = False\n",
99 | " return (state, state.to_gradio_chatbot(), \"\", None)\n",
100 | "\n",
101 | "\n",
102 | "def load_demo():\n",
103 | " state = default_conversation.copy()\n",
104 | " return state\n",
105 | "\n",
106 | "@torch.inference_mode()\n",
107 | "def get_response(params):\n",
108 | " prompt = params[\"prompt\"]\n",
109 | " ori_prompt = prompt\n",
110 | " images = params.get(\"images\", None)\n",
111 | " num_image_tokens = 0\n",
112 | " if images is not None and len(images) > 0:\n",
113 | " if len(images) > 0:\n",
114 | " if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):\n",
115 | " raise ValueError(\n",
116 | " \"Number of images does not match number of tokens in prompt\")\n",
117 | "\n",
118 | " images = [load_image_from_base64(image) for image in images]\n",
119 | " images = process_images(images, image_processor, model.config)\n",
120 | "\n",
121 | " if type(images) is list:\n",
122 | " images = [image.to(model.device, dtype=torch.float16)\n",
123 | " for image in images]\n",
124 | " else:\n",
125 | " images = images.to(model.device, dtype=torch.float16)\n",
126 | "\n",
127 | " replace_token = DEFAULT_IMAGE_TOKEN\n",
128 | " if getattr(model.config, 'mm_use_im_start_end', False):\n",
129 | " replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n",
130 | " prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)\n",
131 | "\n",
132 | " num_image_tokens = prompt.count(\n",
133 | " replace_token) * model.get_vision_tower().num_patches\n",
134 | " else:\n",
135 | " images = None\n",
136 | " image_args = {\"images\": images}\n",
137 | " else:\n",
138 | " images = None\n",
139 | " image_args = {}\n",
140 | "\n",
141 | " temperature = float(params.get(\"temperature\", 1.0))\n",
142 | " top_p = float(params.get(\"top_p\", 1.0))\n",
143 | " max_context_length = getattr(\n",
144 | " model.config, 'max_position_embeddings', 2048)\n",
145 | " max_new_tokens = min(int(params.get(\"max_new_tokens\", 256)), 1024)\n",
146 | " stop_str = params.get(\"stop\", None)\n",
147 | " do_sample = True if temperature > 0.001 else False\n",
148 | "\n",
149 | " input_ids = tokenizer_image_token(\n",
150 | " prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)\n",
151 | " keywords = [stop_str]\n",
152 | " stopping_criteria = KeywordsStoppingCriteria(\n",
153 | " keywords, tokenizer, input_ids)\n",
154 | " streamer = TextIteratorStreamer(\n",
155 | " tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)\n",
156 | "\n",
157 | " max_new_tokens = min(max_new_tokens, max_context_length -\n",
158 | " input_ids.shape[-1] - num_image_tokens)\n",
159 | "\n",
160 | " if max_new_tokens < 1:\n",
161 | " yield json.dumps({\"text\": ori_prompt + \"Exceeds max token length. Please start a new conversation, thanks.\", \"error_code\": 0}).encode() + b\"\\0\"\n",
162 | " return\n",
163 | "\n",
164 | " # local inference\n",
165 | " thread = Thread(target=model.generate, kwargs=dict(\n",
166 | " inputs=input_ids,\n",
167 | " do_sample=do_sample,\n",
168 | " temperature=temperature,\n",
169 | " top_p=top_p,\n",
170 | " max_new_tokens=max_new_tokens,\n",
171 | " streamer=streamer,\n",
172 | " stopping_criteria=[stopping_criteria],\n",
173 | " use_cache=True,\n",
174 | " **image_args\n",
175 | " ))\n",
176 | " thread.start()\n",
177 | "\n",
178 | " generated_text = ori_prompt\n",
179 | " for new_text in streamer:\n",
180 | " generated_text += new_text\n",
181 | " if generated_text.endswith(stop_str):\n",
182 | " generated_text = generated_text[:-len(stop_str)]\n",
183 | " yield json.dumps({\"text\": generated_text, \"error_code\": 0}).encode()\n",
184 | "\n",
185 | "\n",
186 | "def http_bot(state, temperature, top_p, max_new_tokens):\n",
187 | " if state.skip_next:\n",
188 | " # This generate call is skipped due to invalid inputs\n",
189 | " yield (state, state.to_gradio_chatbot())\n",
190 | " return\n",
191 | "\n",
192 | " if len(state.messages) == state.offset + 2:\n",
193 | " # First round of conversation\n",
194 | " if \"llava\" in model_name.lower():\n",
195 | " if 'llama-2' in model_name.lower():\n",
196 | " template_name = \"llava_llama_2\"\n",
197 | " elif \"v1\" in model_name.lower():\n",
198 | " if 'mmtag' in model_name.lower():\n",
199 | " template_name = \"v1_mmtag\"\n",
200 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n",
201 | " template_name = \"v1_mmtag\"\n",
202 | " else:\n",
203 | " template_name = \"llava_v1\"\n",
204 | " elif \"mpt\" in model_name.lower():\n",
205 | " template_name = \"mpt\"\n",
206 | " else:\n",
207 | " if 'mmtag' in model_name.lower():\n",
208 | " template_name = \"v0_mmtag\"\n",
209 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n",
210 | " template_name = \"v0_mmtag\"\n",
211 | " else:\n",
212 | " template_name = \"llava_v0\"\n",
213 | " elif \"mpt\" in model_name:\n",
214 | " template_name = \"mpt_text\"\n",
215 | " elif \"llama-2\" in model_name:\n",
216 | " template_name = \"llama_2\"\n",
217 | " else:\n",
218 | " template_name = \"vicuna_v1\"\n",
219 | " new_state = conv_templates[template_name].copy()\n",
220 | " new_state.append_message(new_state.roles[0], state.messages[-2][1])\n",
221 | " new_state.append_message(new_state.roles[1], None)\n",
222 | " state = new_state\n",
223 | "\n",
224 | " # Construct prompt\n",
225 | " prompt = state.get_prompt()\n",
226 | "\n",
227 | " all_images = state.get_images(return_pil=True)\n",
228 | " all_image_hash = [hashlib.md5(image.tobytes()).hexdigest()\n",
229 | " for image in all_images]\n",
230 | "\n",
231 | " # Make requests\n",
232 | " pload = {\n",
233 | " \"model\": model_name,\n",
234 | " \"prompt\": prompt,\n",
235 | " \"temperature\": float(temperature),\n",
236 | " \"top_p\": float(top_p),\n",
237 | " \"max_new_tokens\": min(int(max_new_tokens), 1536),\n",
238 | " \"stop\": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,\n",
239 | " \"images\": f'List of {len(state.get_images())} images: {all_image_hash}',\n",
240 | " }\n",
241 | "\n",
242 | " pload['images'] = state.get_images()\n",
243 | "\n",
244 | " state.messages[-1][-1] = \"▌\"\n",
245 | " yield (state, state.to_gradio_chatbot())\n",
246 | "\n",
247 | " # for stream\n",
248 | " output = get_response(pload)\n",
249 | " for chunk in output:\n",
250 | " if chunk:\n",
251 | " data = json.loads(chunk.decode())\n",
252 | " if data[\"error_code\"] == 0:\n",
253 | " output = data[\"text\"][len(prompt):].strip()\n",
254 | " state.messages[-1][-1] = output + \"▌\"\n",
255 | " yield (state, state.to_gradio_chatbot())\n",
256 | " else:\n",
257 | " output = data[\"text\"] + \\\n",
258 | " f\" (error_code: {data['error_code']})\"\n",
259 | " state.messages[-1][-1] = output\n",
260 | " yield (state, state.to_gradio_chatbot())\n",
261 | " return\n",
262 | " time.sleep(0.03)\n",
263 | "\n",
264 | " state.messages[-1][-1] = state.messages[-1][-1][:-1]\n",
265 | " yield (state, state.to_gradio_chatbot())\n",
266 | "\n",
267 | "\n",
268 | "def build_demo():\n",
269 | " textbox = gr.Textbox(\n",
270 | " show_label=False, placeholder=\"Enter text and press ENTER\", container=False)\n",
271 | " with gr.Blocks(title=\"ShareGPT4V\", theme=gr.themes.Default(), css=block_css) as demo:\n",
272 | " state = gr.State()\n",
273 | " gr.Markdown(title_markdown)\n",
274 | "\n",
275 | " with gr.Row():\n",
276 | " with gr.Column(scale=5):\n",
277 | " with gr.Row(elem_id=\"Model ID\"):\n",
278 | " gr.Dropdown(\n",
279 | " choices=['ShareGPT4V-7B'],\n",
280 | " value='ShareGPT4V-7B',\n",
281 | " interactive=True,\n",
282 | " label='Model ID',\n",
283 | " container=False)\n",
284 | " imagebox = gr.Image(type=\"pil\")\n",
285 | " image_process_mode = gr.Radio(\n",
286 | " [\"Crop\", \"Resize\", \"Pad\", \"Default\"],\n",
287 | " value=\"Default\",\n",
288 | " label=\"Preprocess for non-square image\", visible=False)\n",
289 | "\n",
290 | " cur_dir = \"/content/InternLM-XComposer/projects/ShareGPT4V\"\n",
291 | " gr.Examples(examples=[\n",
292 | " [f\"{cur_dir}/examples/breaking_bad.png\",\n",
293 | " \"What is the most common catchphrase of the character on the right?\"],\n",
294 | " [f\"{cur_dir}/examples/photo.png\",\n",
295 | " \"From a photography perspective, analyze what makes this picture beautiful?\"],\n",
296 | " ], inputs=[imagebox, textbox])\n",
297 | "\n",
298 | " with gr.Accordion(\"Parameters\", open=False) as _:\n",
299 | " temperature = gr.Slider(\n",
300 | " minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label=\"Temperature\",)\n",
301 | " top_p = gr.Slider(\n",
302 | " minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label=\"Top P\",)\n",
303 | " max_output_tokens = gr.Slider(\n",
304 | " minimum=0, maximum=1024, value=512, step=64, interactive=True, label=\"Max output tokens\",)\n",
305 | "\n",
306 | " with gr.Column(scale=8):\n",
307 | " chatbot = gr.Chatbot(\n",
308 | " elem_id=\"chatbot\", label=\"ShareGPT4V Chatbot\", height=550)\n",
309 | " with gr.Row():\n",
310 | " with gr.Column(scale=8):\n",
311 | " textbox.render()\n",
312 | " with gr.Column(scale=1, min_width=50):\n",
313 | " submit_btn = gr.Button(value=\"Send\", variant=\"primary\")\n",
314 | " with gr.Row(elem_id=\"buttons\") as _:\n",
315 | " regenerate_btn = gr.Button(\n",
316 | " value=\"🔄 Regenerate\", interactive=True)\n",
317 | " clear_btn = gr.Button(value=\"🗑️ Clear\", interactive=True)\n",
318 | "\n",
319 | " gr.Markdown(tos_markdown)\n",
320 | " gr.Markdown(learn_more_markdown)\n",
321 | " gr.Markdown(ack_markdown)\n",
322 | "\n",
323 | " regenerate_btn.click(\n",
324 | " regenerate,\n",
325 | " [state, image_process_mode],\n",
326 | " [state, chatbot, textbox, imagebox],\n",
327 | " queue=False\n",
328 | " ).then(\n",
329 | " http_bot,\n",
330 | " [state, temperature, top_p, max_output_tokens],\n",
331 | " [state, chatbot]\n",
332 | " )\n",
333 | "\n",
334 | " clear_btn.click(\n",
335 | " clear_history,\n",
336 | " None,\n",
337 | " [state, chatbot, textbox, imagebox],\n",
338 | " queue=False\n",
339 | " )\n",
340 | "\n",
341 | " textbox.submit(\n",
342 | " add_text,\n",
343 | " [state, textbox, imagebox, image_process_mode],\n",
344 | " [state, chatbot, textbox, imagebox],\n",
345 | " queue=False\n",
346 | " ).then(\n",
347 | " http_bot,\n",
348 | " [state, temperature, top_p, max_output_tokens],\n",
349 | " [state, chatbot]\n",
350 | " )\n",
351 | "\n",
352 | " submit_btn.click(\n",
353 | " add_text,\n",
354 | " [state, textbox, imagebox, image_process_mode],\n",
355 | " [state, chatbot, textbox, imagebox],\n",
356 | " queue=False\n",
357 | " ).then(\n",
358 | " http_bot,\n",
359 | " [state, temperature, top_p, max_output_tokens],\n",
360 | " [state, chatbot]\n",
361 | " )\n",
362 | "\n",
363 | " demo.load(\n",
364 | " load_demo,\n",
365 | " None,\n",
366 | " [state],\n",
367 | " queue=False\n",
368 | " )\n",
369 | " return demo\n",
370 | "\n",
371 | "model_name = \"llava-v1.5-7b\"\n",
372 | "tokenizer, model, image_processor, context_len = load_pretrained_model(\"4bit/ShareGPT4V-7B-5GB\", None, \"llava-v1.5-7b\", True, False)\n",
373 | "demo = build_demo()\n",
374 | "demo.queue()\n",
375 | "demo.launch(share=True, inline=False, debug=True)"
376 | ]
377 | }
378 | ],
379 | "metadata": {
380 | "accelerator": "GPU",
381 | "colab": {
382 | "gpuType": "T4",
383 | "provenance": []
384 | },
385 | "kernelspec": {
386 | "display_name": "Python 3",
387 | "name": "python3"
388 | },
389 | "language_info": {
390 | "name": "python"
391 | }
392 | },
393 | "nbformat": 4,
394 | "nbformat_minor": 0
395 | }
396 |
--------------------------------------------------------------------------------