├── README.md
├── ShareGPT4V_colab.ipynb
└── ShareGPT4V_8bit_colab.ipynb

/README.md:
--------------------------------------------------------------------------------
1 | 🐣 Please follow me for new updates https://twitter.com/camenduru
2 | 🔥 Please join our discord server https://discord.gg/k5BwmmvJJU
3 | 🥳 Please join my patreon community https://patreon.com/camenduru
4 | 
5 | # 🚦 WIP 🚦
6 | 
7 | ## 🦒 Colab
8 | 
9 | | Colab | Info |
10 | | --- | --- |
11 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_colab.ipynb) | ShareGPT4V_colab (7B 16bit Pro Colab 😐) |
12 | | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_8bit_colab.ipynb) | ShareGPT4V_8bit_colab (7B 8bit Free Colab T4) |
13 | 
14 | ## 🦆 Kaggle
15 | 
16 | | Kaggle | Info |
17 | | --- | --- |
18 | | [![open_in_kaggle_small](https://user-images.githubusercontent.com/54370274/228924833-17316feb-d0fe-4249-90ba-682930ba11e5.svg)](https://kaggle.com/camenduru/sharegpt4v) | sharegpt4v_kaggle (7B 16bit Free Kaggle T4) |
19 | 
20 | ## Main Repo
21 | https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V
22 | https://github.com/haotian-liu/LLaVA
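
## 🐍 Local Inference (sketch)

The notebooks below boil down to loading the ShareGPT4V weights through LLaVA's `load_pretrained_model` and generating a reply for an image + prompt pair. Here is a minimal, non-Gradio sketch of that same path, assuming the `llava-ShareGPT4V-1.1.3` wheel from the install cell and the `4bit/ShareGPT4V-7B-5GB` checkpoint used in the notebooks; `example.png` and the prompt are illustrative, and the sampling parameters mirror the demo defaults:

```python
import torch
from PIL import Image
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

# Same checkpoint and arguments as ShareGPT4V_colab.ipynb; the fourth
# argument toggles 8-bit loading (the 8bit notebook passes True there).
tokenizer, model, image_processor, context_len = load_pretrained_model(
    "4bit/ShareGPT4V-7B-5GB", None, "llava-v1.5-7b", False, False)

# Preprocess one image. process_images can return a list for some configs;
# for the default preprocessing it is a single batched tensor.
image = Image.open("example.png")  # hypothetical input image
images = process_images([image], image_processor, model.config)
images = images.to(model.device, dtype=torch.float16)

# Build a llava_v1 conversation containing the image token, as the demo does.
conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe this image.")
conv.append_message(conv.roles[1], None)
input_ids = tokenizer_image_token(
    conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX,
    return_tensors="pt").unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        inputs=input_ids, images=images, do_sample=True,
        temperature=0.2, top_p=0.7, max_new_tokens=512, use_cache=True)

# This llava version echoes the prompt tokens, so drop them before decoding.
answer = tokenizer.decode(output_ids[0, input_ids.shape[1]:],
                          skip_special_tokens=True).strip()
print(answer)
```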
23 | 
24 | ## Paper
25 | https://arxiv.org/abs/2311.12793
26 | 
27 | ## Page
28 | https://sharegpt4v.github.io/
29 | 
30 | ## Output
31 | ![Screenshot 2023-11-23 114102](https://github.com/camenduru/ShareGPT4V-colab/assets/54370274/3acfa3c6-19c2-40dd-9252-f807c0a08a16)
32 | 
33 | 
--------------------------------------------------------------------------------
/ShareGPT4V_colab.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |   "cells": [
3 |     {
4 |       "cell_type": "markdown",
5 |       "metadata": {
6 |         "id": "view-in-github"
7 |       },
8 |       "source": [
9 |         "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_colab.ipynb)"
10 |       ]
11 |     },
12 |     {
13 |       "cell_type": "code",
14 |       "execution_count": null,
15 |       "metadata": {
16 |         "id": "VjYy0F2gZIPR"
17 |       },
18 |       "outputs": [],
19 |       "source": [
20 |         "%cd /content\n",
21 |         "!git clone -b dev https://github.com/camenduru/InternLM-XComposer\n",
22 |         "!pip install -q https://github.com/camenduru/wheels/releases/download/colab/llava-ShareGPT4V-1.1.3-py3-none-any.whl gradio\n",
23 |         "%cd /content/InternLM-XComposer/projects/ShareGPT4V\n",
24 |         "\n",
25 |         "import hashlib\n",
26 |         "import json\n",
27 |         "import os\n",
28 |         "import time\n",
29 |         "from threading import Thread\n",
30 |         "\n",
31 |         "import gradio as gr\n",
32 |         "import torch\n",
33 |         "from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)\n",
34 |         "from llava.conversation import (SeparatorStyle, conv_templates, default_conversation)\n",
35 |         "from llava.mm_utils import (KeywordsStoppingCriteria, load_image_from_base64, process_images, tokenizer_image_token)\n",
36 |         "from llava.model.builder import load_pretrained_model\n",
37 |         "from transformers import TextIteratorStreamer\n",
38 |         "\n",
39 |         "print(gr.__version__)\n",
40 |         "\n",
41 |         "block_css = \"\"\"\n",
42 |         "\n",
43 |         "#buttons button {\n",
44 |         "    min-width: min(120px,100%);\n",
45 |         "}\n",
46 |         "\"\"\"\n",
47 |         "title_markdown = (\"\"\"\n",
48 |         "# 🐬 ShareGPT4V: Improving Large Multi-modal Models with Better Captions\n",
49 |         "### 🔊 Notice: The demo of Share-Captioner will soon be supported. Stay tuned for updates!\n",
50 |         "[[Project Page](https://sharegpt4v.github.io/)] [[Code](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V)] | 📚 [[Paper](https://arxiv.org/pdf/2311.12793.pdf)]\n",
51 |         "\"\"\")\n",
52 |         "tos_markdown = (\"\"\"\n",
53 |         "### Terms of use\n",
54 |         "By using this service, users are required to agree to the following terms:\n",
55 |         "The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\n",
56 |         "For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.\n",
57 |         "\"\"\")\n",
58 |         "learn_more_markdown = (\"\"\"\n",
59 |         "### License\n",
60 |         "The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.\n",
61 |         "\"\"\")\n",
62 |         "ack_markdown = (\"\"\"\n",
63 |         "### Acknowledgement\n",
64 |         "The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are very grateful to LLaVA for their open source contributions to the community!\n",
65 |         "\"\"\")\n",
66 |         "\n",
67 |         "\n",
68 |         "def regenerate(state, image_process_mode):\n",
69 |         "    state.messages[-1][-1] = None\n",
70 |         "    prev_human_msg = state.messages[-2]\n",
71 |         "    if type(prev_human_msg[1]) in (tuple, list):\n",
72 |         "        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)\n",
73 |         "    state.skip_next = False\n",
74 |         "    return (state, state.to_gradio_chatbot(), \"\", None)\n",
75 |         "\n",
76 |         "\n",
77 |         "def clear_history():\n",
78 |         "    state = default_conversation.copy()\n",
79 |         "    return (state, state.to_gradio_chatbot(), \"\", None)\n",
80 |         "\n",
81 |         "\n",
82 |         "def add_text(state, text, image, image_process_mode):\n",
83 |         "    if len(text) <= 0 and image is None:\n",
84 |         "        state.skip_next = True\n",
85 |         "        return (state, state.to_gradio_chatbot(), \"\", None)\n",
86 |         "\n",
87 |         "    text = text[:1536]  # Hard cut-off\n",
88 |         "    if image is not None:\n",
89 |         "        text = text[:1200]  # Hard cut-off for images\n",
90 |         "        if '<image>' not in text:\n",
91 |         "            # text = '<image>' + text\n",
92 |         "            text = text + '\\n<image>'\n",
93 |         "        text = (text, image, image_process_mode)\n",
94 |         "        if len(state.get_images(return_pil=True)) > 0:\n",
95 |         "            state = default_conversation.copy()\n",
96 |         "    state.append_message(state.roles[0], text)\n",
97 |         "    state.append_message(state.roles[1], None)\n",
98 |         "    state.skip_next = False\n",
99 |         "    return (state, state.to_gradio_chatbot(), \"\", None)\n",
100 |         "\n",
101 |         "\n",
102 |         "def load_demo():\n",
103 |         "    state = default_conversation.copy()\n",
104 |         "    return state\n",
105 |         "\n",
106 |         "@torch.inference_mode()\n",
107 |         "def get_response(params):\n",
108 |         "    prompt = params[\"prompt\"]\n",
109 |         "    ori_prompt = prompt\n",
110 |         "    images = params.get(\"images\", None)\n",
111 |         "    num_image_tokens = 0\n",
112 |         "    if images is not None and len(images) > 0:\n",
113 |         "        if len(images) > 0:\n",
114 |         "            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):\n",
115 |         "                raise ValueError(\n",
116 |         "                    \"Number of images does not match number of <image> tokens in prompt\")\n",
117 |         "\n",
118 |         "            images = [load_image_from_base64(image) for image in images]\n",
119 |         "            images = process_images(images, image_processor, model.config)\n",
120 |         "\n",
121 |         "            if type(images) is list:\n",
122 |         "                images = [image.to(model.device, dtype=torch.float16)\n",
123 |         "                          for image in images]\n",
124 |         "            else:\n",
125 |         "                images = images.to(model.device, dtype=torch.float16)\n",
126 |         "\n",
127 |         "            replace_token = DEFAULT_IMAGE_TOKEN\n",
128 |         "            if getattr(model.config, 'mm_use_im_start_end', False):\n",
129 |         "                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n",
130 |         "            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)\n",
131 |         "\n",
132 |         "            num_image_tokens = prompt.count(\n",
133 |         "                replace_token) * model.get_vision_tower().num_patches\n",
134 |         "        else:\n",
135 |         "            images = None\n",
136 |         "        image_args = {\"images\": images}\n",
137 |         "    else:\n",
138 |         "        images = None\n",
139 |         "        image_args = {}\n",
140 |         "\n",
141 |         "    temperature = float(params.get(\"temperature\", 1.0))\n",
142 |         "    top_p = float(params.get(\"top_p\", 1.0))\n",
143 |         "    max_context_length = getattr(\n",
144 |         "        model.config, 'max_position_embeddings', 2048)\n",
2048)\n", 145 | " max_new_tokens = min(int(params.get(\"max_new_tokens\", 256)), 1024)\n", 146 | " stop_str = params.get(\"stop\", None)\n", 147 | " do_sample = True if temperature > 0.001 else False\n", 148 | "\n", 149 | " input_ids = tokenizer_image_token(\n", 150 | " prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)\n", 151 | " keywords = [stop_str]\n", 152 | " stopping_criteria = KeywordsStoppingCriteria(\n", 153 | " keywords, tokenizer, input_ids)\n", 154 | " streamer = TextIteratorStreamer(\n", 155 | " tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)\n", 156 | "\n", 157 | " max_new_tokens = min(max_new_tokens, max_context_length -\n", 158 | " input_ids.shape[-1] - num_image_tokens)\n", 159 | "\n", 160 | " if max_new_tokens < 1:\n", 161 | " yield json.dumps({\"text\": ori_prompt + \"Exceeds max token length. Please start a new conversation, thanks.\", \"error_code\": 0}).encode() + b\"\\0\"\n", 162 | " return\n", 163 | "\n", 164 | " # local inference\n", 165 | " thread = Thread(target=model.generate, kwargs=dict(\n", 166 | " inputs=input_ids,\n", 167 | " do_sample=do_sample,\n", 168 | " temperature=temperature,\n", 169 | " top_p=top_p,\n", 170 | " max_new_tokens=max_new_tokens,\n", 171 | " streamer=streamer,\n", 172 | " stopping_criteria=[stopping_criteria],\n", 173 | " use_cache=True,\n", 174 | " **image_args\n", 175 | " ))\n", 176 | " thread.start()\n", 177 | "\n", 178 | " generated_text = ori_prompt\n", 179 | " for new_text in streamer:\n", 180 | " generated_text += new_text\n", 181 | " if generated_text.endswith(stop_str):\n", 182 | " generated_text = generated_text[:-len(stop_str)]\n", 183 | " yield json.dumps({\"text\": generated_text, \"error_code\": 0}).encode()\n", 184 | "\n", 185 | "\n", 186 | "def http_bot(state, temperature, top_p, max_new_tokens):\n", 187 | " if state.skip_next:\n", 188 | " # This generate call is skipped due to invalid inputs\n", 189 | " yield (state, state.to_gradio_chatbot())\n", 190 | " return\n", 191 | "\n", 192 | " if len(state.messages) == state.offset + 2:\n", 193 | " # First round of conversation\n", 194 | " if \"llava\" in model_name.lower():\n", 195 | " if 'llama-2' in model_name.lower():\n", 196 | " template_name = \"llava_llama_2\"\n", 197 | " elif \"v1\" in model_name.lower():\n", 198 | " if 'mmtag' in model_name.lower():\n", 199 | " template_name = \"v1_mmtag\"\n", 200 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n", 201 | " template_name = \"v1_mmtag\"\n", 202 | " else:\n", 203 | " template_name = \"llava_v1\"\n", 204 | " elif \"mpt\" in model_name.lower():\n", 205 | " template_name = \"mpt\"\n", 206 | " else:\n", 207 | " if 'mmtag' in model_name.lower():\n", 208 | " template_name = \"v0_mmtag\"\n", 209 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n", 210 | " template_name = \"v0_mmtag\"\n", 211 | " else:\n", 212 | " template_name = \"llava_v0\"\n", 213 | " elif \"mpt\" in model_name:\n", 214 | " template_name = \"mpt_text\"\n", 215 | " elif \"llama-2\" in model_name:\n", 216 | " template_name = \"llama_2\"\n", 217 | " else:\n", 218 | " template_name = \"vicuna_v1\"\n", 219 | " new_state = conv_templates[template_name].copy()\n", 220 | " new_state.append_message(new_state.roles[0], state.messages[-2][1])\n", 221 | " new_state.append_message(new_state.roles[1], None)\n", 222 | " state = new_state\n", 223 | "\n", 224 | " # Construct prompt\n", 225 | " prompt = state.get_prompt()\n", 226 | "\n", 227 | " all_images = 
228 |         "    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest()\n",
229 |         "                      for image in all_images]\n",
230 |         "\n",
231 |         "    # Make requests\n",
232 |         "    pload = {\n",
233 |         "        \"model\": model_name,\n",
234 |         "        \"prompt\": prompt,\n",
235 |         "        \"temperature\": float(temperature),\n",
236 |         "        \"top_p\": float(top_p),\n",
237 |         "        \"max_new_tokens\": min(int(max_new_tokens), 1536),\n",
238 |         "        \"stop\": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,\n",
239 |         "        \"images\": f'List of {len(state.get_images())} images: {all_image_hash}',\n",
240 |         "    }\n",
241 |         "\n",
242 |         "    pload['images'] = state.get_images()\n",
243 |         "\n",
244 |         "    state.messages[-1][-1] = \"▌\"\n",
245 |         "    yield (state, state.to_gradio_chatbot())\n",
246 |         "\n",
247 |         "    # for stream\n",
248 |         "    output = get_response(pload)\n",
249 |         "    for chunk in output:\n",
250 |         "        if chunk:\n",
251 |         "            data = json.loads(chunk.decode())\n",
252 |         "            if data[\"error_code\"] == 0:\n",
253 |         "                output = data[\"text\"][len(prompt):].strip()\n",
254 |         "                state.messages[-1][-1] = output + \"▌\"\n",
255 |         "                yield (state, state.to_gradio_chatbot())\n",
256 |         "            else:\n",
257 |         "                output = data[\"text\"] + \\\n",
258 |         "                    f\" (error_code: {data['error_code']})\"\n",
259 |         "                state.messages[-1][-1] = output\n",
260 |         "                yield (state, state.to_gradio_chatbot())\n",
261 |         "                return\n",
262 |         "            time.sleep(0.03)\n",
263 |         "\n",
264 |         "    state.messages[-1][-1] = state.messages[-1][-1][:-1]\n",
265 |         "    yield (state, state.to_gradio_chatbot())\n",
266 |         "\n",
267 |         "\n",
268 |         "def build_demo():\n",
269 |         "    textbox = gr.Textbox(\n",
270 |         "        show_label=False, placeholder=\"Enter text and press ENTER\", container=False)\n",
271 |         "    with gr.Blocks(title=\"ShareGPT4V\", theme=gr.themes.Default(), css=block_css) as demo:\n",
272 |         "        state = gr.State()\n",
273 |         "        gr.Markdown(title_markdown)\n",
274 |         "\n",
275 |         "        with gr.Row():\n",
276 |         "            with gr.Column(scale=5):\n",
277 |         "                with gr.Row(elem_id=\"Model ID\"):\n",
278 |         "                    gr.Dropdown(\n",
279 |         "                        choices=['ShareGPT4V-7B'],\n",
280 |         "                        value='ShareGPT4V-7B',\n",
281 |         "                        interactive=True,\n",
282 |         "                        label='Model ID',\n",
283 |         "                        container=False)\n",
284 |         "                imagebox = gr.Image(type=\"pil\")\n",
285 |         "                image_process_mode = gr.Radio(\n",
286 |         "                    [\"Crop\", \"Resize\", \"Pad\", \"Default\"],\n",
287 |         "                    value=\"Default\",\n",
288 |         "                    label=\"Preprocess for non-square image\", visible=False)\n",
289 |         "\n",
290 |         "                cur_dir = \"/content/InternLM-XComposer/projects/ShareGPT4V\"\n",
291 |         "                gr.Examples(examples=[\n",
292 |         "                    [f\"{cur_dir}/examples/breaking_bad.png\",\n",
293 |         "                     \"What is the most common catchphrase of the character on the right?\"],\n",
294 |         "                    [f\"{cur_dir}/examples/photo.png\",\n",
295 |         "                     \"From a photography perspective, analyze what makes this picture beautiful?\"],\n",
296 |         "                ], inputs=[imagebox, textbox])\n",
297 |         "\n",
298 |         "                with gr.Accordion(\"Parameters\", open=False) as _:\n",
299 |         "                    temperature = gr.Slider(\n",
300 |         "                        minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label=\"Temperature\",)\n",
301 |         "                    top_p = gr.Slider(\n",
302 |         "                        minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label=\"Top P\",)\n",
303 |         "                    max_output_tokens = gr.Slider(\n",
304 |         "                        minimum=0, maximum=1024, value=512, step=64, interactive=True, label=\"Max output tokens\",)\n",
305 |         "\n",
306 |         "            with gr.Column(scale=8):\n",
307 |         "                chatbot = gr.Chatbot(\n",
308 |         "                    elem_id=\"chatbot\", label=\"ShareGPT4V Chatbot\", height=550)\n",
Chatbot\", height=550)\n", 309 | " with gr.Row():\n", 310 | " with gr.Column(scale=8):\n", 311 | " textbox.render()\n", 312 | " with gr.Column(scale=1, min_width=50):\n", 313 | " submit_btn = gr.Button(value=\"Send\", variant=\"primary\")\n", 314 | " with gr.Row(elem_id=\"buttons\") as _:\n", 315 | " regenerate_btn = gr.Button(\n", 316 | " value=\"🔄 Regenerate\", interactive=True)\n", 317 | " clear_btn = gr.Button(value=\"🗑️ Clear\", interactive=True)\n", 318 | "\n", 319 | " gr.Markdown(tos_markdown)\n", 320 | " gr.Markdown(learn_more_markdown)\n", 321 | " gr.Markdown(ack_markdown)\n", 322 | "\n", 323 | " regenerate_btn.click(\n", 324 | " regenerate,\n", 325 | " [state, image_process_mode],\n", 326 | " [state, chatbot, textbox, imagebox],\n", 327 | " queue=False\n", 328 | " ).then(\n", 329 | " http_bot,\n", 330 | " [state, temperature, top_p, max_output_tokens],\n", 331 | " [state, chatbot]\n", 332 | " )\n", 333 | "\n", 334 | " clear_btn.click(\n", 335 | " clear_history,\n", 336 | " None,\n", 337 | " [state, chatbot, textbox, imagebox],\n", 338 | " queue=False\n", 339 | " )\n", 340 | "\n", 341 | " textbox.submit(\n", 342 | " add_text,\n", 343 | " [state, textbox, imagebox, image_process_mode],\n", 344 | " [state, chatbot, textbox, imagebox],\n", 345 | " queue=False\n", 346 | " ).then(\n", 347 | " http_bot,\n", 348 | " [state, temperature, top_p, max_output_tokens],\n", 349 | " [state, chatbot]\n", 350 | " )\n", 351 | "\n", 352 | " submit_btn.click(\n", 353 | " add_text,\n", 354 | " [state, textbox, imagebox, image_process_mode],\n", 355 | " [state, chatbot, textbox, imagebox],\n", 356 | " queue=False\n", 357 | " ).then(\n", 358 | " http_bot,\n", 359 | " [state, temperature, top_p, max_output_tokens],\n", 360 | " [state, chatbot]\n", 361 | " )\n", 362 | "\n", 363 | " demo.load(\n", 364 | " load_demo,\n", 365 | " None,\n", 366 | " [state],\n", 367 | " queue=False\n", 368 | " )\n", 369 | " return demo\n", 370 | "\n", 371 | "model_name = \"llava-v1.5-7b\"\n", 372 | "tokenizer, model, image_processor, context_len = load_pretrained_model(\"4bit/ShareGPT4V-7B-5GB\", None, \"llava-v1.5-7b\", False, False)\n", 373 | "demo = build_demo()\n", 374 | "demo.queue()\n", 375 | "demo.launch(share=True, inline=False, debug=True)" 376 | ] 377 | } 378 | ], 379 | "metadata": { 380 | "accelerator": "GPU", 381 | "colab": { 382 | "gpuType": "T4", 383 | "provenance": [] 384 | }, 385 | "kernelspec": { 386 | "display_name": "Python 3", 387 | "name": "python3" 388 | }, 389 | "language_info": { 390 | "name": "python" 391 | } 392 | }, 393 | "nbformat": 4, 394 | "nbformat_minor": 0 395 | } 396 | -------------------------------------------------------------------------------- /ShareGPT4V_8bit_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github" 7 | }, 8 | "source": [ 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/ShareGPT4V-colab/blob/main/ShareGPT4V_8bit_colab.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": null, 15 | "metadata": { 16 | "id": "VjYy0F2gZIPR" 17 | }, 18 | "outputs": [], 19 | "source": [ 20 | "%cd /content\n", 21 | "!git clone -b dev https://github.com/camenduru/InternLM-XComposer\n", 22 | "!pip install -q https://github.com/camenduru/wheels/releases/download/colab/llava-ShareGPT4V-1.1.3-py3-none-any.whl gradio\n", 23 | "%cd 
24 |         "\n",
25 |         "import hashlib\n",
26 |         "import json\n",
27 |         "import os\n",
28 |         "import time\n",
29 |         "from threading import Thread\n",
30 |         "\n",
31 |         "import gradio as gr\n",
32 |         "import torch\n",
33 |         "from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)\n",
34 |         "from llava.conversation import (SeparatorStyle, conv_templates, default_conversation)\n",
35 |         "from llava.mm_utils import (KeywordsStoppingCriteria, load_image_from_base64, process_images, tokenizer_image_token)\n",
36 |         "from llava.model.builder import load_pretrained_model\n",
37 |         "from transformers import TextIteratorStreamer\n",
38 |         "\n",
39 |         "print(gr.__version__)\n",
40 |         "\n",
41 |         "block_css = \"\"\"\n",
42 |         "\n",
43 |         "#buttons button {\n",
44 |         "    min-width: min(120px,100%);\n",
45 |         "}\n",
46 |         "\"\"\"\n",
47 |         "title_markdown = (\"\"\"\n",
48 |         "# 🐬 ShareGPT4V: Improving Large Multi-modal Models with Better Captions\n",
49 |         "### 🔊 Notice: The demo of Share-Captioner will soon be supported. Stay tuned for updates!\n",
50 |         "[[Project Page](https://sharegpt4v.github.io/)] [[Code](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V)] | 📚 [[Paper](https://arxiv.org/pdf/2311.12793.pdf)]\n",
51 |         "\"\"\")\n",
52 |         "tos_markdown = (\"\"\"\n",
53 |         "### Terms of use\n",
54 |         "By using this service, users are required to agree to the following terms:\n",
55 |         "The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes.\n",
56 |         "For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.\n",
57 |         "\"\"\")\n",
58 |         "learn_more_markdown = (\"\"\"\n",
59 |         "### License\n",
60 |         "The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.\n",
61 |         "\"\"\")\n",
62 |         "ack_markdown = (\"\"\"\n",
63 |         "### Acknowledgement\n",
64 |         "The template for this web demo is from [LLaVA](https://github.com/haotian-liu/LLaVA), and we are very grateful to LLaVA for their open source contributions to the community!\n",
65 |         "\"\"\")\n",
66 |         "\n",
67 |         "\n",
68 |         "def regenerate(state, image_process_mode):\n",
69 |         "    state.messages[-1][-1] = None\n",
70 |         "    prev_human_msg = state.messages[-2]\n",
71 |         "    if type(prev_human_msg[1]) in (tuple, list):\n",
72 |         "        prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)\n",
73 |         "    state.skip_next = False\n",
74 |         "    return (state, state.to_gradio_chatbot(), \"\", None)\n",
75 |         "\n",
76 |         "\n",
77 |         "def clear_history():\n",
78 |         "    state = default_conversation.copy()\n",
79 |         "    return (state, state.to_gradio_chatbot(), \"\", None)\n",
80 |         "\n",
81 |         "\n",
82 |         "def add_text(state, text, image, image_process_mode):\n",
83 |         "    if len(text) <= 0 and image is None:\n",
84 |         "        state.skip_next = True\n",
85 |         "        return (state, state.to_gradio_chatbot(), \"\", None)\n",
86 |         "\n",
87 |         "    text = text[:1536]  # Hard cut-off\n",
88 |         "    if image is not None:\n",
89 |         "        text = text[:1200]  # Hard cut-off for images\n",
90 |         "        if '<image>' not in text:\n",
91 |         "            # text = '<image>' + text\n",
92 |         "            text = text + '\\n<image>'\n",
93 |         "        text = (text, image, image_process_mode)\n",
94 |         "        if len(state.get_images(return_pil=True)) > 0:\n",
95 |         "            state = default_conversation.copy()\n",
96 |         "    state.append_message(state.roles[0], text)\n",
97 |         "    state.append_message(state.roles[1], None)\n",
98 |         "    state.skip_next = False\n",
99 |         "    return (state, state.to_gradio_chatbot(), \"\", None)\n",
100 |         "\n",
101 |         "\n",
102 |         "def load_demo():\n",
103 |         "    state = default_conversation.copy()\n",
104 |         "    return state\n",
105 |         "\n",
106 |         "@torch.inference_mode()\n",
107 |         "def get_response(params):\n",
108 |         "    prompt = params[\"prompt\"]\n",
109 |         "    ori_prompt = prompt\n",
110 |         "    images = params.get(\"images\", None)\n",
111 |         "    num_image_tokens = 0\n",
112 |         "    if images is not None and len(images) > 0:\n",
113 |         "        if len(images) > 0:\n",
114 |         "            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):\n",
115 |         "                raise ValueError(\n",
116 |         "                    \"Number of images does not match number of <image> tokens in prompt\")\n",
117 |         "\n",
118 |         "            images = [load_image_from_base64(image) for image in images]\n",
119 |         "            images = process_images(images, image_processor, model.config)\n",
120 |         "\n",
121 |         "            if type(images) is list:\n",
122 |         "                images = [image.to(model.device, dtype=torch.float16)\n",
123 |         "                          for image in images]\n",
124 |         "            else:\n",
125 |         "                images = images.to(model.device, dtype=torch.float16)\n",
126 |         "\n",
127 |         "            replace_token = DEFAULT_IMAGE_TOKEN\n",
128 |         "            if getattr(model.config, 'mm_use_im_start_end', False):\n",
129 |         "                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN\n",
130 |         "            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)\n",
131 |         "\n",
132 |         "            num_image_tokens = prompt.count(\n",
133 |         "                replace_token) * model.get_vision_tower().num_patches\n",
134 |         "        else:\n",
135 |         "            images = None\n",
136 |         "        image_args = {\"images\": images}\n",
137 |         "    else:\n",
138 |         "        images = None\n",
139 |         "        image_args = {}\n",
140 |         "\n",
141 |         "    temperature = float(params.get(\"temperature\", 1.0))\n",
142 |         "    top_p = float(params.get(\"top_p\", 1.0))\n",
143 |         "    max_context_length = getattr(\n",
144 |         "        model.config, 'max_position_embeddings', 2048)\n",
2048)\n", 145 | " max_new_tokens = min(int(params.get(\"max_new_tokens\", 256)), 1024)\n", 146 | " stop_str = params.get(\"stop\", None)\n", 147 | " do_sample = True if temperature > 0.001 else False\n", 148 | "\n", 149 | " input_ids = tokenizer_image_token(\n", 150 | " prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)\n", 151 | " keywords = [stop_str]\n", 152 | " stopping_criteria = KeywordsStoppingCriteria(\n", 153 | " keywords, tokenizer, input_ids)\n", 154 | " streamer = TextIteratorStreamer(\n", 155 | " tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)\n", 156 | "\n", 157 | " max_new_tokens = min(max_new_tokens, max_context_length -\n", 158 | " input_ids.shape[-1] - num_image_tokens)\n", 159 | "\n", 160 | " if max_new_tokens < 1:\n", 161 | " yield json.dumps({\"text\": ori_prompt + \"Exceeds max token length. Please start a new conversation, thanks.\", \"error_code\": 0}).encode() + b\"\\0\"\n", 162 | " return\n", 163 | "\n", 164 | " # local inference\n", 165 | " thread = Thread(target=model.generate, kwargs=dict(\n", 166 | " inputs=input_ids,\n", 167 | " do_sample=do_sample,\n", 168 | " temperature=temperature,\n", 169 | " top_p=top_p,\n", 170 | " max_new_tokens=max_new_tokens,\n", 171 | " streamer=streamer,\n", 172 | " stopping_criteria=[stopping_criteria],\n", 173 | " use_cache=True,\n", 174 | " **image_args\n", 175 | " ))\n", 176 | " thread.start()\n", 177 | "\n", 178 | " generated_text = ori_prompt\n", 179 | " for new_text in streamer:\n", 180 | " generated_text += new_text\n", 181 | " if generated_text.endswith(stop_str):\n", 182 | " generated_text = generated_text[:-len(stop_str)]\n", 183 | " yield json.dumps({\"text\": generated_text, \"error_code\": 0}).encode()\n", 184 | "\n", 185 | "\n", 186 | "def http_bot(state, temperature, top_p, max_new_tokens):\n", 187 | " if state.skip_next:\n", 188 | " # This generate call is skipped due to invalid inputs\n", 189 | " yield (state, state.to_gradio_chatbot())\n", 190 | " return\n", 191 | "\n", 192 | " if len(state.messages) == state.offset + 2:\n", 193 | " # First round of conversation\n", 194 | " if \"llava\" in model_name.lower():\n", 195 | " if 'llama-2' in model_name.lower():\n", 196 | " template_name = \"llava_llama_2\"\n", 197 | " elif \"v1\" in model_name.lower():\n", 198 | " if 'mmtag' in model_name.lower():\n", 199 | " template_name = \"v1_mmtag\"\n", 200 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n", 201 | " template_name = \"v1_mmtag\"\n", 202 | " else:\n", 203 | " template_name = \"llava_v1\"\n", 204 | " elif \"mpt\" in model_name.lower():\n", 205 | " template_name = \"mpt\"\n", 206 | " else:\n", 207 | " if 'mmtag' in model_name.lower():\n", 208 | " template_name = \"v0_mmtag\"\n", 209 | " elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():\n", 210 | " template_name = \"v0_mmtag\"\n", 211 | " else:\n", 212 | " template_name = \"llava_v0\"\n", 213 | " elif \"mpt\" in model_name:\n", 214 | " template_name = \"mpt_text\"\n", 215 | " elif \"llama-2\" in model_name:\n", 216 | " template_name = \"llama_2\"\n", 217 | " else:\n", 218 | " template_name = \"vicuna_v1\"\n", 219 | " new_state = conv_templates[template_name].copy()\n", 220 | " new_state.append_message(new_state.roles[0], state.messages[-2][1])\n", 221 | " new_state.append_message(new_state.roles[1], None)\n", 222 | " state = new_state\n", 223 | "\n", 224 | " # Construct prompt\n", 225 | " prompt = state.get_prompt()\n", 226 | "\n", 227 | " all_images = 
228 |         "    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest()\n",
229 |         "                      for image in all_images]\n",
230 |         "\n",
231 |         "    # Make requests\n",
232 |         "    pload = {\n",
233 |         "        \"model\": model_name,\n",
234 |         "        \"prompt\": prompt,\n",
235 |         "        \"temperature\": float(temperature),\n",
236 |         "        \"top_p\": float(top_p),\n",
237 |         "        \"max_new_tokens\": min(int(max_new_tokens), 1536),\n",
238 |         "        \"stop\": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,\n",
239 |         "        \"images\": f'List of {len(state.get_images())} images: {all_image_hash}',\n",
240 |         "    }\n",
241 |         "\n",
242 |         "    pload['images'] = state.get_images()\n",
243 |         "\n",
244 |         "    state.messages[-1][-1] = \"▌\"\n",
245 |         "    yield (state, state.to_gradio_chatbot())\n",
246 |         "\n",
247 |         "    # for stream\n",
248 |         "    output = get_response(pload)\n",
249 |         "    for chunk in output:\n",
250 |         "        if chunk:\n",
251 |         "            data = json.loads(chunk.decode())\n",
252 |         "            if data[\"error_code\"] == 0:\n",
253 |         "                output = data[\"text\"][len(prompt):].strip()\n",
254 |         "                state.messages[-1][-1] = output + \"▌\"\n",
255 |         "                yield (state, state.to_gradio_chatbot())\n",
256 |         "            else:\n",
257 |         "                output = data[\"text\"] + \\\n",
258 |         "                    f\" (error_code: {data['error_code']})\"\n",
259 |         "                state.messages[-1][-1] = output\n",
260 |         "                yield (state, state.to_gradio_chatbot())\n",
261 |         "                return\n",
262 |         "            time.sleep(0.03)\n",
263 |         "\n",
264 |         "    state.messages[-1][-1] = state.messages[-1][-1][:-1]\n",
265 |         "    yield (state, state.to_gradio_chatbot())\n",
266 |         "\n",
267 |         "\n",
268 |         "def build_demo():\n",
269 |         "    textbox = gr.Textbox(\n",
270 |         "        show_label=False, placeholder=\"Enter text and press ENTER\", container=False)\n",
271 |         "    with gr.Blocks(title=\"ShareGPT4V\", theme=gr.themes.Default(), css=block_css) as demo:\n",
272 |         "        state = gr.State()\n",
273 |         "        gr.Markdown(title_markdown)\n",
274 |         "\n",
275 |         "        with gr.Row():\n",
276 |         "            with gr.Column(scale=5):\n",
277 |         "                with gr.Row(elem_id=\"Model ID\"):\n",
278 |         "                    gr.Dropdown(\n",
279 |         "                        choices=['ShareGPT4V-7B'],\n",
280 |         "                        value='ShareGPT4V-7B',\n",
281 |         "                        interactive=True,\n",
282 |         "                        label='Model ID',\n",
283 |         "                        container=False)\n",
284 |         "                imagebox = gr.Image(type=\"pil\")\n",
285 |         "                image_process_mode = gr.Radio(\n",
286 |         "                    [\"Crop\", \"Resize\", \"Pad\", \"Default\"],\n",
287 |         "                    value=\"Default\",\n",
288 |         "                    label=\"Preprocess for non-square image\", visible=False)\n",
289 |         "\n",
290 |         "                cur_dir = \"/content/InternLM-XComposer/projects/ShareGPT4V\"\n",
291 |         "                gr.Examples(examples=[\n",
292 |         "                    [f\"{cur_dir}/examples/breaking_bad.png\",\n",
293 |         "                     \"What is the most common catchphrase of the character on the right?\"],\n",
294 |         "                    [f\"{cur_dir}/examples/photo.png\",\n",
295 |         "                     \"From a photography perspective, analyze what makes this picture beautiful?\"],\n",
296 |         "                ], inputs=[imagebox, textbox])\n",
297 |         "\n",
298 |         "                with gr.Accordion(\"Parameters\", open=False) as _:\n",
299 |         "                    temperature = gr.Slider(\n",
300 |         "                        minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label=\"Temperature\",)\n",
301 |         "                    top_p = gr.Slider(\n",
302 |         "                        minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label=\"Top P\",)\n",
303 |         "                    max_output_tokens = gr.Slider(\n",
304 |         "                        minimum=0, maximum=1024, value=512, step=64, interactive=True, label=\"Max output tokens\",)\n",
305 |         "\n",
306 |         "            with gr.Column(scale=8):\n",
307 |         "                chatbot = gr.Chatbot(\n",
308 |         "                    elem_id=\"chatbot\", label=\"ShareGPT4V Chatbot\", height=550)\n",
Chatbot\", height=550)\n", 309 | " with gr.Row():\n", 310 | " with gr.Column(scale=8):\n", 311 | " textbox.render()\n", 312 | " with gr.Column(scale=1, min_width=50):\n", 313 | " submit_btn = gr.Button(value=\"Send\", variant=\"primary\")\n", 314 | " with gr.Row(elem_id=\"buttons\") as _:\n", 315 | " regenerate_btn = gr.Button(\n", 316 | " value=\"🔄 Regenerate\", interactive=True)\n", 317 | " clear_btn = gr.Button(value=\"🗑️ Clear\", interactive=True)\n", 318 | "\n", 319 | " gr.Markdown(tos_markdown)\n", 320 | " gr.Markdown(learn_more_markdown)\n", 321 | " gr.Markdown(ack_markdown)\n", 322 | "\n", 323 | " regenerate_btn.click(\n", 324 | " regenerate,\n", 325 | " [state, image_process_mode],\n", 326 | " [state, chatbot, textbox, imagebox],\n", 327 | " queue=False\n", 328 | " ).then(\n", 329 | " http_bot,\n", 330 | " [state, temperature, top_p, max_output_tokens],\n", 331 | " [state, chatbot]\n", 332 | " )\n", 333 | "\n", 334 | " clear_btn.click(\n", 335 | " clear_history,\n", 336 | " None,\n", 337 | " [state, chatbot, textbox, imagebox],\n", 338 | " queue=False\n", 339 | " )\n", 340 | "\n", 341 | " textbox.submit(\n", 342 | " add_text,\n", 343 | " [state, textbox, imagebox, image_process_mode],\n", 344 | " [state, chatbot, textbox, imagebox],\n", 345 | " queue=False\n", 346 | " ).then(\n", 347 | " http_bot,\n", 348 | " [state, temperature, top_p, max_output_tokens],\n", 349 | " [state, chatbot]\n", 350 | " )\n", 351 | "\n", 352 | " submit_btn.click(\n", 353 | " add_text,\n", 354 | " [state, textbox, imagebox, image_process_mode],\n", 355 | " [state, chatbot, textbox, imagebox],\n", 356 | " queue=False\n", 357 | " ).then(\n", 358 | " http_bot,\n", 359 | " [state, temperature, top_p, max_output_tokens],\n", 360 | " [state, chatbot]\n", 361 | " )\n", 362 | "\n", 363 | " demo.load(\n", 364 | " load_demo,\n", 365 | " None,\n", 366 | " [state],\n", 367 | " queue=False\n", 368 | " )\n", 369 | " return demo\n", 370 | "\n", 371 | "model_name = \"llava-v1.5-7b\"\n", 372 | "tokenizer, model, image_processor, context_len = load_pretrained_model(\"4bit/ShareGPT4V-7B-5GB\", None, \"llava-v1.5-7b\", True, False)\n", 373 | "demo = build_demo()\n", 374 | "demo.queue()\n", 375 | "demo.launch(share=True, inline=False, debug=True)" 376 | ] 377 | } 378 | ], 379 | "metadata": { 380 | "accelerator": "GPU", 381 | "colab": { 382 | "gpuType": "T4", 383 | "provenance": [] 384 | }, 385 | "kernelspec": { 386 | "display_name": "Python 3", 387 | "name": "python3" 388 | }, 389 | "language_info": { 390 | "name": "python" 391 | } 392 | }, 393 | "nbformat": 4, 394 | "nbformat_minor": 0 395 | } 396 | --------------------------------------------------------------------------------