├── .gitattributes ├── .gitignore ├── ChatCaptioner ├── README.md ├── caption.ipynb ├── chatcaptioner │ ├── __init__.py │ ├── blip2.py │ ├── chat.py │ ├── clip.py │ └── utils.py ├── demo.py ├── demo_pic │ ├── CuteCloud_1366x768.jpg │ ├── demo1.gif │ ├── demo2.gif │ └── overview.png ├── draw_quali.ipynb ├── environment.yml ├── main_caption.py ├── not_sure.ipynb ├── obj_cover.ipynb ├── question_analysis.ipynb ├── visualization.ipynb └── yes_no.ipynb ├── LICENSE.md ├── README.md └── Video_ChatCaptioner ├── README.md ├── chatcaptioner ├── __init__.py ├── blip2.py ├── clip.py ├── utils.py ├── video_chat.py └── video_reader.py ├── demo_pic ├── dance.gif ├── overview.png └── skating.gif ├── environment.yml ├── generate_caption_msvd.py ├── generate_caption_webvid.py ├── run_msvd.sh └── run_webvid.sh /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb filter=strip-notebook-output 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .idea 3 | .DS_Store 4 | .ipynb_checkpoints 5 | datasets 6 | datasets_archive 7 | experiments 8 | __pycache__ 9 | 10 | -------------------------------------------------------------------------------- /ChatCaptioner/README.md: -------------------------------------------------------------------------------- 1 | # ChatGPT Asks, BLIP-2 Answers: Automatic Questioning Towards Enriched Visual Descriptions 2 | 3 | Official repository of **ChatCaptioner**. 4 | We discover the powerful questioning ability of LLMs and their great potential for acquiring information effectively. 5 | As an exploration, we introduce ChatCaptioner for image captioning. 6 | ChatCaptioner enriches the image captions of BLIP-2 by 7 | prompting ChatGPT to keep asking BLIP-2 informative questions 8 | and summarizing the conversation at the end as the final caption. 9 | 10 | See our paper [ChatGPT Asks, BLIP-2 Answers: Automatic Questioning Towards Enriched Visual Descriptions](https://arxiv.org/abs/2303.06594). 11 | 12 | ## Demo 13 | ![demo1](demo_pic/demo1.gif) 14 | ![demo2](demo_pic/demo2.gif) 15 | 16 | 17 | ## System Architecture 18 | ![overview](demo_pic/overview.png) 19 | 20 | 21 | 22 | ## Installation 23 | Note that you need a GPU with 24GB of memory to run ChatCaptioner due to the size of BLIP-2. 24 | 25 | To start, git clone this repository first. 26 | 27 | To install and activate the environment, run the following commands: 28 | 29 | ``` 30 | conda env create -f environment.yml 31 | conda activate chatcap 32 | ``` 33 | 34 | Set the environment variable OPENAI_API_KEY to your OpenAI API key. 35 | 36 | ``` 37 | export OPENAI_API_KEY=Your_OpenAI_Key 38 | ``` 39 | You can add it to .bashrc so you don't need to set it manually every time. 40 | 41 | 42 | As many scripts here are Jupyter notebooks, don't forget to add the environment to Jupyter's kernel list. 43 | To do so, run 44 | 45 | ``` 46 | python -m ipykernel install --user --name=chatcap 47 | ``` 48 | 49 | 50 | Download our dataset samples from [here](https://drive.google.com/file/d/19yQP9lepLeS2_vSHnYPeOdfQz8OI1e6V/view?usp=share_link) and extract the zip file to the root folder. 51 | After the extraction, the data folder should look like this. You can skip this step if you only want to run demo.py with your own images. 52 | 53 | ``` 54 | . 
55 | ├── chatcaptioner 56 | ├── datasets 57 | │ ├── artemis 58 | │ ├── coco_val 59 | │ └── cc_val 60 | │ ├── annotation.yaml 61 | │ └── img 62 | │ ├── annotation.yaml 63 | │ ├── 85.jpg 64 | │ ... 65 | ├── caption.ipynb 66 | ... 67 | ``` 68 | 69 | 70 | 71 | 72 | ## Usage 73 | To play with ChatCaptioner with a given image, run the following command. It will use GPU 0. 74 | ``` 75 | python demo.py 76 | ``` 77 | 78 | To play with ChatCaptioner with a few dataset samples, check the jupyter script 'caption.ipynb'. 79 | 80 | To caption all the images in the datasets, run 'main_caption.py'. 81 | Using --exp_tag to tag your runs and using --datasets to specify the datasets you want to caption. 82 | 83 | ``` 84 | # caption all the sampled images in the datasets 'cc_val' and 'artemis' using GPU-0 and save results to experiments/test 85 | python main_caption.py --exp_tag test --datasets cc_val artemis --device_id 0 86 | ``` 87 | 88 | Datasets available are 'artemis', 'cc_val', 'coco_val', 'pascal' 89 | 90 | + [Artemis](https://www.artemisdataset.org/) 91 | + [MSCOCO](https://cocodataset.org/#home) 92 | + [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/) 93 | + [Pascal VOC](http://host.robots.ox.ac.uk/pascal/VOC/voc2010/) 94 | 95 | ## Visualization 96 | 97 | To visualize the caption results, check the jupyter script 'visualization.ipynb'. 98 | 99 | 100 | ## Acknowledgement 101 | 102 | + [ChatGPT](https://openai.com/blog/chatgpt/) 103 | + [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) 104 | 105 | 106 | Please cite ChatCaptioner from the following bibtex 107 | 108 | ``` 109 | @article{zhu2023chatgpt, 110 | title={ChatGPT Asks, BLIP-2 Answers: Automatic Questioning Towards Enriched Visual Descriptions}, 111 | author={Zhu, Deyao and Chen, Jun and Haydarov, Kilichbek and Shen, Xiaoqian and Zhang, Wenxuan and Elhoseiny, Mohamed}, 112 | journal={arXiv preprint arXiv:2303.06594}, 113 | year={2023} 114 | } 115 | ``` 116 | -------------------------------------------------------------------------------- /ChatCaptioner/caption.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "9c785ab9-f2c8-47e3-95d9-49d0f07ddced", 7 | "metadata": { 8 | "pycharm": { 9 | "name": "#%%\n" 10 | }, 11 | "tags": [] 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "%load_ext autoreload\n", 16 | "%autoreload 2" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "f2fd5a82-3a53-43fa-862e-abc208b0e8b2", 23 | "metadata": { 24 | "pycharm": { 25 | "name": "#%%\n" 26 | }, 27 | "tags": [] 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "import yaml\n", 33 | "import torch\n", 34 | "\n", 35 | "from chatcaptioner.chat import set_openai_key, caption_images, get_instructions\n", 36 | "from chatcaptioner.blip2 import Blip2\n", 37 | "from chatcaptioner.utils import RandomSampledDataset, plot_img, print_info" 38 | ] 39 | }, 40 | { 41 | "cell_type": "markdown", 42 | "id": "75fa951a-bc3c-49e5-8654-70beeba10bad", 43 | "metadata": { 44 | "pycharm": { 45 | "name": "#%% md\n" 46 | } 47 | }, 48 | "source": [ 49 | "## Set OpenAI" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "84c5b12e-0c40-49b8-95c4-34b88a93240e", 56 | "metadata": { 57 | "pycharm": { 58 | "name": "#%%\n" 59 | }, 60 | "tags": [] 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "openai_key = 
os.environ[\"OPENAI_API_KEY\"]\n", 65 | "set_openai_key(openai_key)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "id": "4abe1586-de9f-47b5-943f-9bf5dc11a51e", 71 | "metadata": { 72 | "pycharm": { 73 | "name": "#%% md\n" 74 | } 75 | }, 76 | "source": [ 77 | "## Load BLIP-2" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "fbb5177b-8d76-425a-89f6-7b8c2fb182e1", 84 | "metadata": { 85 | "pycharm": { 86 | "name": "#%%\n" 87 | }, 88 | "tags": [] 89 | }, 90 | "outputs": [], 91 | "source": [ 92 | "blip2s = {\n", 93 | " 'FlanT5 XXL': Blip2('FlanT5 XXL', device_id=0, bit8=True), # load BLIP-2 FlanT5 XXL to GPU0. Too large, need 8 bit. About 20GB GPU Memory\n", 94 | " # 'OPT2.7B COCO': Blip2('OPT2.7B COCO', device_id=1, bit8=False), # load BLIP-2 OPT2.7B COCO to GPU1. About 10GB GPU Memory\n", 95 | " # 'OPT6.7B COCO': Blip2('OPT6.7B COCO', device_id=2, bit8=True), # load BLIP-2 OPT6.7B COCO to GPU2. Too large, need 8 bit.\n", 96 | "}\n", 97 | "blip2s_q = {}" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "1bb237d2-5fcf-4a99-8fdc-7b4de4cd3e94", 104 | "metadata": { 105 | "tags": [], 106 | "pycharm": { 107 | "name": "#%%\n" 108 | } 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "# blip2s_q = {\n", 113 | "# 'FlanT5 XXL': Blip2('FlanT5 XXL', device_id=0, bit8=True), # load BLIP-2 FlanT5 XXL to GPU0. Too large, need 8 bit. About 20GB GPU Memory\n", 114 | "# # 'OPT2.7B': Blip2('OPT2.7B', device_id=1, bit8=False), # load BLIP-2 OPT2.7B COCO to GPU1. About 10GB GPU Memory\n", 115 | "# # 'OPT6.7B': Blip2('OPT6.7B', device_id=2, bit8=True), # load BLIP-2 OPT6.7B COCO to GPU2. Too large, need 8 bit.\n", 116 | "# }\n", 117 | "# blip2s = {'FlanT5 XXL': blip2s_q['FlanT5 XXL']}" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "id": "3ab6f595-3257-4143-b2d7-5a90011acd06", 123 | "metadata": { 124 | "pycharm": { 125 | "name": "#%% md\n" 126 | } 127 | }, 128 | "source": [ 129 | "## Test Setting. Change it Accordingly" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": null, 135 | "id": "4386922c", 136 | "metadata": { 137 | "jupyter": { 138 | "outputs_hidden": false 139 | }, 140 | "pycharm": { 141 | "name": "#%%\n" 142 | } 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "# set the dataset to test\n", 147 | "dataset_name = 'cc_val' # current options: 'artemis', 'cc_val', 'coco_val'\n", 148 | "# set the number of images you want to test\n", 149 | "n_test_img = 3\n", 150 | "# set the number of chat rounds between GPT3 and BLIP-2\n", 151 | "n_rounds = 10\n", 152 | "# set the number of visible chat rounds to BLIP-2. 
<0 means all the chat histories are visible.\n", 153 | "n_blip2_context = 1\n", 154 | "# if print the chat out in the testing\n", 155 | "print_chat = True\n", 156 | "# set the question model\n", 157 | "question_model_tag = 'gpt-3.5-turbo'" 158 | ] 159 | }, 160 | { 161 | "cell_type": "markdown", 162 | "id": "533e9ac1", 163 | "metadata": { 164 | "pycharm": { 165 | "name": "#%% md\n" 166 | } 167 | }, 168 | "source": [ 169 | "## Load Dataset & Prepare Foloder to Save Results" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": null, 175 | "id": "f6031873", 176 | "metadata": { 177 | "jupyter": { 178 | "outputs_hidden": false 179 | }, 180 | "pycharm": { 181 | "name": "#%%\n" 182 | } 183 | }, 184 | "outputs": [], 185 | "source": [ 186 | "# load the dataset\n", 187 | "DATA_ROOT = 'datasets/'\n", 188 | "dataset = RandomSampledDataset(DATA_ROOT, dataset_name)\n", 189 | "\n", 190 | "# preparing the folder to save results\n", 191 | "SAVE_PATH = 'experiments/0307_{}/{}'.format(question_model_tag, dataset_name)\n", 192 | "if not os.path.exists(SAVE_PATH):\n", 193 | " os.makedirs(os.path.join(SAVE_PATH, 'caption_result'))\n", 194 | "with open(os.path.join(SAVE_PATH, 'instruction.yaml'), 'w') as f:\n", 195 | " yaml.dump(get_instructions(), f)" 196 | ] 197 | }, 198 | { 199 | "cell_type": "markdown", 200 | "id": "257771df", 201 | "metadata": { 202 | "pycharm": { 203 | "name": "#%% md\n" 204 | } 205 | }, 206 | "source": [ 207 | "## Start Caption" 208 | ] 209 | }, 210 | { 211 | "cell_type": "code", 212 | "execution_count": null, 213 | "id": "48b33172", 214 | "metadata": { 215 | "jupyter": { 216 | "outputs_hidden": false 217 | }, 218 | "pycharm": { 219 | "name": "#%%\n" 220 | } 221 | }, 222 | "outputs": [], 223 | "source": [ 224 | "sample_img_ids = dataset.random_img_ids(n_test_img)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "3c3db5c8-690e-43d1-85d6-9b702691ce4b", 231 | "metadata": { 232 | "pycharm": { 233 | "name": "#%%\n" 234 | } 235 | }, 236 | "outputs": [], 237 | "source": [ 238 | "sample_img_ids = ['11627']\n", 239 | "if question_model_tag in blip2s_q:\n", 240 | " question_model = blip2s_q[question_model_tag]\n", 241 | "else:\n", 242 | " question_model = question_model_tag\n", 243 | "caption_images(blip2s, \n", 244 | " dataset, \n", 245 | " sample_img_ids, \n", 246 | " save_path=SAVE_PATH, \n", 247 | " n_rounds=n_rounds, \n", 248 | " n_blip2_context=n_blip2_context,\n", 249 | " model=question_model,\n", 250 | " print_mode='chat')" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "id": "85e801c1-bb13-4a35-baf5-fabc1833d778", 257 | "metadata": { 258 | "pycharm": { 259 | "name": "#%%\n" 260 | } 261 | }, 262 | "outputs": [], 263 | "source": [] 264 | } 265 | ], 266 | "metadata": { 267 | "kernelspec": { 268 | "display_name": "chatae", 269 | "language": "python", 270 | "name": "chatae" 271 | }, 272 | "language_info": { 273 | "codemirror_mode": { 274 | "name": "ipython", 275 | "version": 3 276 | }, 277 | "file_extension": ".py", 278 | "mimetype": "text/x-python", 279 | "name": "python", 280 | "nbconvert_exporter": "python", 281 | "pygments_lexer": "ipython3", 282 | "version": "3.9.16" 283 | } 284 | }, 285 | "nbformat": 4, 286 | "nbformat_minor": 5 287 | } -------------------------------------------------------------------------------- /ChatCaptioner/chatcaptioner/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/ChatCaptioner/chatcaptioner/__init__.py -------------------------------------------------------------------------------- /ChatCaptioner/chatcaptioner/blip2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 3 | 4 | 5 | BLIP2DICT = { 6 | 'FlanT5 XXL': 'Salesforce/blip2-flan-t5-xxl', 7 | 'FlanT5 XL COCO': 'Salesforce/blip2-flan-t5-xl-coco', 8 | 'OPT6.7B COCO': 'Salesforce/blip2-opt-6.7b-coco', 9 | 'OPT2.7B COCO': 'Salesforce/blip2-opt-2.7b-coco', 10 | 'FlanT5 XL': 'Salesforce/blip2-flan-t5-xl', 11 | 'OPT6.7B': 'Salesforce/blip2-opt-6.7b', 12 | 'OPT2.7B': 'Salesforce/blip2-opt-2.7b', 13 | } 14 | 15 | 16 | class Blip2(): 17 | def __init__(self, model, device_id, bit8=True): 18 | # load BLIP-2 to a single gpu 19 | self.tag = model 20 | self.bit8 = bit8 21 | self.device = 'cuda:{}'.format(device_id) 22 | 23 | dtype = {'load_in_8bit': True} if self.bit8 else {'torch_dtype': torch.float16} 24 | self.blip2_processor = Blip2Processor.from_pretrained(BLIP2DICT[self.tag]) 25 | self.blip2 = Blip2ForConditionalGeneration.from_pretrained(BLIP2DICT[self.tag], device_map={'': device_id}, **dtype) 26 | 27 | def ask(self, raw_image, question): 28 | inputs = self.blip2_processor(raw_image, question, return_tensors="pt").to(self.device, torch.float16) 29 | out = self.blip2.generate(**inputs) 30 | answer = self.blip2_processor.decode(out[0], skip_special_tokens=True) 31 | return answer 32 | 33 | def caption(self, raw_image): 34 | # starndard way to caption an image in the blip2 paper 35 | caption = self.ask(raw_image, 'a photo of') 36 | caption = caption.replace('\n', ' ').strip() # trim caption 37 | return caption 38 | 39 | def call_llm(self, prompts): 40 | prompts_temp = self.blip2_processor(None, prompts, return_tensors="pt") 41 | input_ids = prompts_temp['input_ids'].to(self.device) 42 | attention_mask = prompts_temp['attention_mask'].to(self.device, torch.float16) 43 | 44 | prompts_embeds = self.blip2.language_model.get_input_embeddings()(input_ids) 45 | 46 | outputs = self.blip2.language_model.generate( 47 | inputs_embeds=prompts_embeds, 48 | attention_mask=attention_mask) 49 | 50 | outputs = self.blip2_processor.decode(outputs[0], skip_special_tokens=True) 51 | 52 | return outputs 53 | -------------------------------------------------------------------------------- /ChatCaptioner/chatcaptioner/chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from tqdm import tqdm 4 | import torch 5 | import openai 6 | from tenacity import ( 7 | retry, 8 | stop_after_attempt, 9 | wait_random_exponential, 10 | ) # for exponential backoff 11 | import gradio as gr 12 | 13 | from chatcaptioner.blip2 import Blip2 14 | from chatcaptioner.utils import print_info, plot_img 15 | 16 | 17 | QUESTION_INSTRUCTION = \ 18 | "I have an image. " \ 19 | "Ask me questions about the content of this image. " \ 20 | "Carefully asking me informative questions to maximize your information about this image content. " \ 21 | "Each time ask one question only without giving an answer. " \ 22 | "Avoid asking yes/no questions." \ 23 | "I'll put my answer beginning with \"Answer:\"." \ 24 | 25 | 26 | SUB_QUESTION_INSTRUCTION = \ 27 | "Next Question. Avoid asking yes/no questions. 
\n" \ 28 | "Question: " 29 | 30 | 31 | SUMMARY_INSTRUCTION = \ 32 | 'Now summarize the information you get in a few sentences. ' \ 33 | 'Ignore the questions with answers no or not sure. ' \ 34 | 'Don\'t add information. Don\'t miss information. \n' \ 35 | 'Summary: ' 36 | 37 | 38 | ANSWER_INSTRUCTION = 'Answer given questions. If you are not sure about the answer, say you don\'t know honestly. Don\'t imagine any contents that are not in the image.' 39 | 40 | 41 | SUB_ANSWER_INSTRUCTION = 'Answer: ' # template following blip2 huggingface demo 42 | 43 | 44 | FIRST_QUESTION = 'Describe this image in detail.' 45 | 46 | 47 | VALID_CHATGPT_MODELS = ['gpt-3.5-turbo'] 48 | VALID_GPT3_MODELS = ['text-davinci-003', 'text-davinci-002', 'davinci'] 49 | 50 | 51 | 52 | def get_instructions(): 53 | instructions_dict = { 54 | 'question': QUESTION_INSTRUCTION, 55 | 'sub_question': SUB_QUESTION_INSTRUCTION, 56 | 'summary': SUMMARY_INSTRUCTION, 57 | 'answer': ANSWER_INSTRUCTION, 58 | 'sub_answer': SUB_ANSWER_INSTRUCTION, 59 | 'first_question': FIRST_QUESTION 60 | } 61 | return instructions_dict 62 | 63 | 64 | 65 | def set_openai_key(key): 66 | openai.api_key = key 67 | 68 | 69 | def get_chat_log(questions, answers, last_n=-1): 70 | n_addition_q = len(questions) - len(answers) 71 | assert (n_addition_q) in [0, 1] 72 | template = 'Question: {} \nAnswer: {} \n' 73 | chat_log = '' 74 | if last_n > 0: 75 | answers = answers[-last_n:] 76 | questions = questions[-(last_n+n_addition_q):] 77 | elif last_n == 0: 78 | answers = [] 79 | questions = questions[-1:] if n_addition_q else [] 80 | 81 | 82 | for i in range(len(answers)): 83 | chat_log = chat_log + template.format(questions[i], answers[i]) 84 | if n_addition_q: 85 | chat_log = chat_log + 'Question: {}'.format(questions[-1]) 86 | else: 87 | chat_log = chat_log[:-2] # remove the last '/n' 88 | return chat_log 89 | 90 | 91 | def prepare_gpt_prompt(task_prompt, questions, answers, sub_prompt): 92 | gpt_prompt = '\n'.join([task_prompt, 93 | get_chat_log(questions, answers), 94 | sub_prompt]) 95 | return gpt_prompt 96 | 97 | 98 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) 99 | def call_gpt3(gpt3_prompt, max_tokens=40, model="text-davinci-003"): # 'text-curie-001' does work at all to ask questions 100 | response = openai.Completion.create(model=model, prompt=gpt3_prompt, max_tokens=max_tokens) # temperature=0.6, 101 | reply = response['choices'][0]['text'] 102 | total_tokens = response['usage']['total_tokens'] 103 | return reply, total_tokens 104 | 105 | 106 | def prepare_chatgpt_message(task_prompt, questions, answers, sub_prompt): 107 | messages = [{"role": "system", "content": task_prompt}] 108 | 109 | assert len(questions) == len(answers) 110 | for q, a in zip(questions, answers): 111 | messages.append({'role': 'assistant', 'content': 'Question: {}'.format(q)}) 112 | messages.append({'role': 'user', 'content': 'Answer: {}'.format(a)}) 113 | messages.append({"role": "system", "content": sub_prompt}) 114 | 115 | return messages 116 | 117 | 118 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) 119 | def call_chatgpt(chatgpt_messages, max_tokens=40, model="gpt-3.5-turbo"): 120 | response = openai.ChatCompletion.create(model=model, messages=chatgpt_messages, temperature=0.6, max_tokens=max_tokens) 121 | reply = response['choices'][0]['message']['content'] 122 | total_tokens = response['usage']['total_tokens'] 123 | return reply, total_tokens 124 | 125 | 126 | class AskQuestions(): 127 | 128 | def 
__init__(self, img, blip2, model, max_gpt_token=30, n_blip2_context=-1): 129 | self.img = img 130 | self.blip2 = blip2 131 | self.model = model 132 | self.max_gpt_token = max_gpt_token 133 | self.n_blip2_context = n_blip2_context 134 | 135 | self.questions = [] 136 | self.answers = [] 137 | self.total_tokens = 0 138 | 139 | def reset(self, img): 140 | self.img = img 141 | self.questions = [] 142 | self.answers = [] 143 | self.total_tokens = 0 144 | 145 | def ask_question(self): 146 | if len(self.questions) == 0: 147 | # first question is given by human to request a general discription 148 | question = FIRST_QUESTION 149 | else: 150 | if self.model in VALID_CHATGPT_MODELS: 151 | chatgpt_messages = prepare_chatgpt_message( 152 | QUESTION_INSTRUCTION, 153 | self.questions, self.answers, 154 | SUB_QUESTION_INSTRUCTION 155 | ) 156 | question, n_tokens = call_chatgpt(chatgpt_messages, model=self.model, max_tokens=self.max_gpt_token) 157 | elif self.model in VALID_GPT3_MODELS: 158 | # prepare the context for GPT3 159 | gpt3_prompt = prepare_gpt_prompt( 160 | QUESTION_INSTRUCTION, 161 | self.questions, self.answers, 162 | SUB_QUESTION_INSTRUCTION 163 | ) 164 | 165 | question, n_tokens = call_gpt3(gpt3_prompt, model=self.model, max_tokens=self.max_gpt_token) 166 | elif isinstance(self.model, Blip2): 167 | # prepare the context for other LLM 168 | gpt_prompt = prepare_gpt_prompt( 169 | QUESTION_INSTRUCTION, 170 | self.questions, self.answers, 171 | SUB_QUESTION_INSTRUCTION 172 | ) 173 | n_tokens = 0 # local model. no token cost on OpenAI API. 174 | question = self.model.call_llm(gpt_prompt) 175 | else: 176 | raise ValueError('{} is not a valid question model'.format(self.model)) 177 | 178 | self.total_tokens = self.total_tokens + n_tokens 179 | 180 | return question 181 | 182 | def question_trim(self, question): 183 | question = question.split('Question: ')[-1].replace('\n', ' ').strip() 184 | if 'Answer:' in question: # Some models make up an answer after asking. remove it 185 | q, a = question.split('Answer:')[:2] 186 | if len(q) == 0: # some not so clever models will put the question after 'Answer:'. 
187 | question = a.strip() 188 | else: 189 | question = q.strip() 190 | return question 191 | 192 | def answer_question(self): 193 | # prepare the context for blip2 194 | blip2_prompt = '\n'.join([ANSWER_INSTRUCTION, 195 | get_chat_log(self.questions, self.answers, last_n=self.n_blip2_context), 196 | SUB_ANSWER_INSTRUCTION]) 197 | 198 | answer = self.blip2.ask(self.img, blip2_prompt) 199 | return answer 200 | 201 | def answer_trim(self, answer): 202 | answer = answer.split('Question:')[0].replace('\n', ' ').strip() 203 | return answer 204 | 205 | def chatting(self, n_rounds, print_mode): 206 | if print_mode == 'chat': 207 | print('--------Chat Starts----------') 208 | 209 | for i in tqdm(range(n_rounds), desc='Chat Rounds', disable=print_mode != 'bar'): 210 | question = self.ask_question() 211 | # print('Raw: {}'.format(question)) 212 | question = self.question_trim(question) 213 | self.questions.append(question) 214 | 215 | if print_mode == 'chat': 216 | print('GPT-3: {}'.format(question)) 217 | elif print_mode == 'gradio': 218 | gr_chatbot = gr_chatbot + [[question, None]] 219 | 220 | answer = self.answer_question() 221 | answer = self.answer_trim(answer) 222 | self.answers.append(answer) 223 | 224 | if print_mode == 'chat': 225 | print('BLIP-2: {}'.format(answer)) 226 | elif print_mode == 'gradio': 227 | self.gr_chatbot[-1][1] = answer 228 | 229 | if print_mode == 'chat': 230 | print('--------Chat Ends----------') 231 | 232 | return self.questions, self.answers, self.total_tokens 233 | 234 | 235 | def summarize_chat(questions, answers, model, max_gpt_token=100): 236 | if model in VALID_GPT3_MODELS: 237 | summary_prompt = prepare_gpt_prompt( 238 | QUESTION_INSTRUCTION, 239 | questions, answers, 240 | SUMMARY_INSTRUCTION) 241 | 242 | summary, n_tokens = call_gpt3(summary_prompt, model=model, max_tokens=max_gpt_token) 243 | elif model in VALID_CHATGPT_MODELS: 244 | summary_prompt = prepare_chatgpt_message( 245 | QUESTION_INSTRUCTION, 246 | questions, answers, 247 | SUMMARY_INSTRUCTION 248 | ) 249 | summary, n_tokens = call_chatgpt(summary_prompt, model=model, max_tokens=max_gpt_token) 250 | elif isinstance(model, Blip2): 251 | summary_prompt = prepare_gpt_prompt( 252 | QUESTION_INSTRUCTION, 253 | questions, answers, 254 | SUMMARY_INSTRUCTION 255 | ) 256 | n_tokens = 0 # local model. no token cost on OpenAI API. 
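# For reference, the summary_prompt assembled above by prepare_gpt_prompt is a plain-text
# prompt of the shape sketched below; the question/answer pair is purely hypothetical and
# only illustrates the layout produced by get_chat_log:
#
#   I have an image. Ask me questions about the content of this image. ... (QUESTION_INSTRUCTION)
#   Question: Describe this image in detail.
#   Answer: a dog running on a beach
#   Now summarize the information you get in a few sentences. ... (SUMMARY_INSTRUCTION)
#   Summary: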
257 | summary = model.call_llm(summary_prompt) 258 | else: 259 | raise ValueError('{} is not a valid question model'.format(model)) 260 | 261 | summary = summary.replace('\n', ' ').strip() 262 | return summary, summary_prompt, n_tokens 263 | 264 | 265 | def caption_image(blip2, image, model, n_rounds=10, n_blip2_context=-1, print_mode='no'): 266 | if model == 'gpt3': 267 | model = 'text-davinci-003' 268 | elif model == 'chatgpt': 269 | model = 'gpt-3.5-turbo' 270 | 271 | results = {} 272 | chat = AskQuestions(image, 273 | blip2, 274 | n_blip2_context=n_blip2_context, 275 | model=model) 276 | 277 | questions, answers, n_token_chat = chat.chatting(n_rounds, print_mode=print_mode) 278 | 279 | summary, summary_prompt, n_token_sum = summarize_chat(questions, answers, model=model) 280 | results['ChatCaptioner'] = {'caption': summary, 'chat': summary_prompt, 'n_token': n_token_chat + n_token_sum} 281 | results['BLIP2+OurPrompt'] = {'caption': answers[0]} 282 | 283 | # Default BLIP2 caption 284 | caption = blip2.caption(image) 285 | results['BLIP2'] = {'caption': caption} 286 | 287 | return results 288 | 289 | 290 | def caption_images(blip2s, dataset, img_ids, model, save_path='', n_rounds=10, n_blip2_context=-1, print_mode='no'): 291 | """ 292 | Caption images with a set of blip2 models 293 | 294 | Args: 295 | blip2s (dict): A dict of blip2 models. Key is the blip2 model name 296 | dataset: the dataset used to caption 297 | img_ids (list): a list of image ids in the dataset used to caption 298 | model (str or Blip2): the model name used to ask questions. Valid values are 'gpt3', 'chatgpt', and their concrete model names 299 | including 'text-davinci-003', 'davinci', and 'gpt-3.5-turbo'. 300 | If passing a Blip2 instance, its backend LLM will be used. 301 | save_path (str): the path to save caption results. If it is empty, results are not saved. 302 | n_rounds (int): the number of chat rounds 303 | n_blip2_context (int): how many previous QA rounds blip2 can see. A negative value means blip2 can see all 304 | print_mode (str): print mode. 'chat' for printing everything. 'bar' for printing everything but the chat process. 'no' for no printing 305 | """ 
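# A minimal usage sketch for this function, mirroring how caption.ipynb and main_caption.py
# drive it. The tag 'test', dataset 'cc_val', and save path are hypothetical examples; it
# assumes OPENAI_API_KEY is set, a sufficiently large GPU is available, and the caller has
# already created save_path/caption_result (as main_caption.py does):
#
#   import os
#   from chatcaptioner.chat import set_openai_key, caption_images
#   from chatcaptioner.blip2 import Blip2
#   from chatcaptioner.utils import RandomSampledDataset
#
#   set_openai_key(os.environ['OPENAI_API_KEY'])
#   blip2s = {'FlanT5 XXL': Blip2('FlanT5 XXL', device_id=0, bit8=True)}
#   dataset = RandomSampledDataset('datasets/', 'cc_val')
#   img_ids = dataset.random_img_ids(3)
#   caption_images(blip2s, dataset, img_ids, model='gpt-3.5-turbo',
#                  save_path='experiments/test/cc_val',
#                  n_rounds=10, n_blip2_context=1, print_mode='chat')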
306 | if model == 'gpt3': 307 | model = 'text-davinci-003' 308 | elif model == 'chatgpt': 309 | model = 'gpt-3.5-turbo' 310 | 311 | for img_id in tqdm(img_ids, disable=print_mode!='no'): 312 | caption_path = os.path.join(save_path, 'caption_result', '{}.yaml'.format(img_id)) 313 | if os.path.exists(caption_path): 314 | continue 315 | if print_mode != 'no': 316 | print('Image ID {}'.format(img_id)) 317 | 318 | image, gt_captions = dataset.fetch_img(img_id) 319 | info = {'setting': 320 | {'dataset': dataset.name, 321 | 'id': img_id, 322 | 'GT': {'caption': [caption.replace('\n', ' ').strip() for caption in gt_captions]}, 323 | 'n_rounds': n_rounds 324 | } 325 | } 326 | 327 | for blip2_tag, blip2 in blip2s.items(): 328 | info[blip2_tag] = caption_image(blip2, 329 | image, 330 | n_rounds=n_rounds, 331 | n_blip2_context=n_blip2_context, 332 | model=model, 333 | print_mode=print_mode) 334 | 335 | if print_mode != 'no': 336 | print_info(info) 337 | plot_img(image) 338 | 339 | if save_path: 340 | with open(caption_path, 'w') as f: 341 | yaml.dump(info, f) -------------------------------------------------------------------------------- /ChatCaptioner/chatcaptioner/clip.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | class ClipScore(): 4 | def __init__(self, device='cuda:0'): 5 | # load open clip to device 6 | self.device = device 7 | clip, _, self.clip_preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k') 8 | self.clip = clip.to(self.device) 9 | self.clip_tokenizer = open_clip.get_tokenizer('ViT-H-14') 10 | 11 | 12 | def clip_IT_score(self, image, texts): 13 | ''' 14 | compute the average similarity score of a given image and a list of texts 15 | ''' 16 | if isinstance(texts, str): 17 | texts = [texts] 18 | image = self.clip_preprocess(image)[None].to(self.device) 19 | texts = self.clip_tokenizer(texts).to(self.device) 20 | with torch.no_grad(): 21 | image_f = self.clip.encode_image(image).float() 22 | texts_f = self.clip.encode_text(texts).float() 23 | image_f /= image_f.norm(dim=-1, keepdim=True) 24 | texts_f /= texts_f.norm(dim=-1, keepdim=True) 25 | similarity = (image_f.cpu().numpy() @ texts_f.cpu().numpy().T).mean() 26 | similarity = round(float(similarity), 3) 27 | return similarity -------------------------------------------------------------------------------- /ChatCaptioner/chatcaptioner/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import deepcopy 4 | import yaml 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | from pycocotools.coco import COCO 8 | 9 | 10 | 11 | class COCOHelper(): 12 | def __init__(self, coco_path, coco_ann_path): 13 | # todo: make it work for the test set. 
test set doesn't contain annotation 14 | self.coco_path = coco_path 15 | self.coco_ann = COCO(annotation_file=coco_ann_path) 16 | self.coco_ids = self.coco_ann.getImgIds() 17 | # self.split = split 18 | 19 | def random_img_ids(self, n): 20 | sample_img_ids = random.sample(self.coco_ids, n) 21 | return sample_img_ids 22 | 23 | def fetch_coco_img(self, image_id, split='val'): 24 | img_name = '%012d.jpg' % image_id 25 | img_path = os.path.join(self.coco_path, img_name) 26 | raw_image = Image.open(img_path).convert('RGB') 27 | 28 | ann_ids = self.coco_ann.getAnnIds(imgIds=[image_id], iscrowd=None) 29 | anns = self.coco_ann.loadAnns(ann_ids) 30 | captions = [ann['caption'] for ann in anns] 31 | return raw_image, captions 32 | 33 | 34 | class RandomSampledDataset(): 35 | def __init__(self, datasets_root, dataset_name): 36 | self.name = dataset_name 37 | self.dataset_path = os.path.join(datasets_root, dataset_name) 38 | self._ids = [file_name.split('.jpg')[0] for file_name in os.listdir(os.path.join(self.dataset_path, 'img'))] 39 | 40 | 41 | ann_path = os.path.join(datasets_root, dataset_name, 'annotation.yaml') 42 | if os.path.exists(ann_path): 43 | with open(ann_path, 'r') as f: 44 | self.ann = yaml.safe_load(f) 45 | if isinstance(list(self.ann.keys())[0], int): 46 | self.ann = {str(image_id): captions for image_id, captions in self.ann.items()} 47 | else: 48 | self.ann = None 49 | 50 | @property 51 | def ids(self): 52 | return deepcopy(self._ids) 53 | 54 | def random_img_ids(self, n): 55 | sample_img_ids = random.sample(self._ids, n) 56 | return sample_img_ids 57 | 58 | def fetch_img(self, image_id): 59 | img_path = os.path.join(self.dataset_path, 'img', '{}.jpg'.format(image_id)) 60 | raw_image = Image.open(img_path).convert('RGB') 61 | 62 | if self.ann: 63 | captions = self.ann[image_id] 64 | 65 | if isinstance(captions, str): 66 | captions = [captions] 67 | else: 68 | captions = [] 69 | 70 | return raw_image, captions 71 | 72 | 73 | class SimPairDataset(): 74 | def __init__(self, datasets_root, dataset_name): 75 | self.name = dataset_name 76 | self.dataset_path = os.path.join(datasets_root, dataset_name) 77 | 78 | ann_path = os.path.join(datasets_root, dataset_name, 'sim_retrieve.yaml') 79 | if os.path.exists(ann_path): 80 | with open(ann_path, 'r') as f: 81 | self.ann = yaml.safe_load(f) 82 | if isinstance(list(self.ann.keys())[0], int): 83 | self.ann = {str(image_id): captions for image_id, captions in self.ann.items()} 84 | else: 85 | self.ann = None 86 | self._ids = list(self.ann.keys()) 87 | 88 | @property 89 | def ids(self): 90 | return deepcopy(self._ids) 91 | 92 | def fetch_img_pairs(self, pair_id): 93 | image_ids = list(self.ann[pair_id].keys()) 94 | fetched = [] 95 | for image_id in image_ids: 96 | img_path = os.path.join(self.dataset_path, 'img', '{}.jpg'.format(image_id)) 97 | raw_image = Image.open(img_path).convert('RGB') 98 | if self.ann: 99 | captions = self.ann[pair_id][image_id] 100 | 101 | if isinstance(captions, str): 102 | captions = [captions] 103 | else: 104 | captions = [] 105 | fetched.append((image_id, raw_image, captions)) 106 | return fetched 107 | 108 | 109 | def extractQA_chatgpt(messages): 110 | questions = [] 111 | answers = [] 112 | for message in messages: 113 | if 'Question: ' in message['content']: 114 | questions.append(message['content'].split('Question: ')[1]) 115 | if 'Answer: ' in message['content']: 116 | answers.append(message['content'].split('Answer: ')[1]) 117 | return questions, answers 118 | 119 | 120 | def print_info(info, key='caption', 
variants=['BLIP2', 'BLIP2+OurPrompt', 'ChatCaptioner']): 121 | img_id = info['setting']['id'] 122 | if 'GT' in info['setting']: 123 | gt_captions = info['setting']['GT']['caption'] 124 | if isinstance(gt_captions, str) and len(gt_captions): 125 | gt_captions = [gt_captions] 126 | 127 | else: 128 | gt_captions = [] 129 | 130 | print('Image ID {}'.format(img_id)) 131 | for blip2_tag in info: 132 | if blip2_tag in ['GT', 'id', 'setting']: continue 133 | for variant in variants: 134 | if key not in info[blip2_tag][variant]: 135 | continue 136 | print('-------------------') 137 | print('{} {}:'.format(blip2_tag, variant)) 138 | if key == 'chat' and isinstance(info[blip2_tag][variant][key], list): 139 | for message in info[blip2_tag][variant][key]: 140 | print(message['content']) 141 | else: 142 | print(info[blip2_tag][variant][key]) 143 | if key == 'chat': 144 | print(info[blip2_tag][variant]['caption']) 145 | print('===================') 146 | if key == 'caption' and len(gt_captions): 147 | print('GT:') 148 | [print(cap) for cap in gt_captions] 149 | 150 | 151 | def plot_img(img): 152 | plt.imshow(img) 153 | plt.axis('off') 154 | plt.show() 155 | 156 | 157 | 158 | # =================================== 159 | # Deprecated Zone 160 | # =================================== 161 | 162 | def visualize_old(file_path): 163 | # out of date 164 | with open(file_path, 'r') as f: 165 | info = yaml.safe_load(f) 166 | print('COCO Val Image ID {}'.format(info['id'])) 167 | print('-------------------') 168 | print('Ours: {}'.format(info['ours']['clip_score'])) 169 | print(info['ours']['chat']) 170 | print('-------------------') 171 | print('BLIP2: {}'.format(info['blip2']['clip_score'])) 172 | print(info['blip2']['caption']) 173 | print('-------------------') 174 | print('GT: {}'.format(info['gt']['clip_score'])) 175 | [print(cap) for cap in info['gt']['caption']] 176 | image, _ = fetch_coco_img(info['id']) 177 | plot_img(image) 178 | 179 | 180 | def print_info_old(info, key='caption', variants=['BLIP2', 'BLIP2+OurPrompt', 'ChatCaptioner']): 181 | if 'id' in info: 182 | img_id = info['id'] 183 | else: 184 | img_id = info['setting']['id'] 185 | if 'GT' in info: 186 | gt_captions = info['GT']['caption'] 187 | elif 'GT' in info['setting']: 188 | gt_captions = info['setting']['GT']['caption'] 189 | else: 190 | gt_captions = [] 191 | 192 | print('Image ID {}'.format(img_id)) 193 | for blip2_tag in info: 194 | if blip2_tag in ['GT', 'id', 'setting']: continue 195 | for variant in variants: 196 | if key not in info[blip2_tag][variant]: 197 | continue 198 | print('-------------------') 199 | print('{} {}:'.format(blip2_tag, variant)) 200 | print(info[blip2_tag][variant][key]) 201 | print('===================') 202 | if key == 'caption' and len(gt_captions): 203 | print('GT:') 204 | [print(cap) for cap in gt_captions] -------------------------------------------------------------------------------- /ChatCaptioner/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gradio as gr 3 | 4 | from chatcaptioner.chat import set_openai_key, summarize_chat, AskQuestions 5 | from chatcaptioner.blip2 import Blip2 6 | 7 | 8 | openai_key = os.environ["OPENAI_API_KEY"] 9 | set_openai_key(openai_key) 10 | 11 | 12 | blip2 = Blip2('FlanT5 XXL', device_id=0, bit8=True) 13 | chat = AskQuestions(None, blip2, 'gpt-3.5-turbo', n_blip2_context=1) 14 | 15 | 16 | def gradio_reset(gr_img): 17 | chat.reset(gr_img) 18 | return None 19 | 20 | 21 | def gradio_ask(chatbot): 22 | question = 
chat.ask_question() 23 | question = chat.question_trim(question) 24 | chat.questions.append(question) 25 | chatbot = chatbot + [[None, question]] 26 | return chatbot 27 | 28 | 29 | def gradio_answer(chatbot): 30 | answer = chat.answer_question() 31 | answer = chat.answer_trim(answer) 32 | chat.answers.append(answer) 33 | chatbot = chatbot + [[answer, None]] 34 | return chatbot 35 | 36 | 37 | def gradio_summarize(chatbot): 38 | summary, summary_prompt, n_token_sum = summarize_chat(chat.questions, chat.answers, model='gpt-3.5-turbo') 39 | chatbot = chatbot + [[None, 'Final Caption: ' + summary]] 40 | return chatbot 41 | 42 | 43 | with gr.Blocks() as demo: 44 | gr.Markdown("## ChatCaptioner Demo") 45 | with gr.Row(): 46 | with gr.Column(): 47 | image = gr.Image(type="pil") 48 | start = gr.Button("Start Chat") 49 | chatbot = gr.Chatbot() 50 | 51 | start.click(gradio_reset, image, chatbot) \ 52 | .then(gradio_ask, chatbot, chatbot) \ 53 | .then(gradio_answer, chatbot, chatbot) \ 54 | .then(gradio_ask, chatbot, chatbot) \ 55 | .then(gradio_answer, chatbot, chatbot) \ 56 | .then(gradio_ask, chatbot, chatbot) \ 57 | .then(gradio_answer, chatbot, chatbot) \ 58 | .then(gradio_ask, chatbot, chatbot) \ 59 | .then(gradio_answer, chatbot, chatbot) \ 60 | .then(gradio_ask, chatbot, chatbot) \ 61 | .then(gradio_answer, chatbot, chatbot) \ 62 | .then(gradio_ask, chatbot, chatbot) \ 63 | .then(gradio_answer, chatbot, chatbot) \ 64 | .then(gradio_ask, chatbot, chatbot) \ 65 | .then(gradio_answer, chatbot, chatbot) \ 66 | .then(gradio_ask, chatbot, chatbot) \ 67 | .then(gradio_answer, chatbot, chatbot) \ 68 | .then(gradio_ask, chatbot, chatbot) \ 69 | .then(gradio_answer, chatbot, chatbot) \ 70 | .then(gradio_ask, chatbot, chatbot) \ 71 | .then(gradio_answer, chatbot, chatbot) \ 72 | .then(gradio_summarize, chatbot, chatbot) 73 | 74 | demo.launch() -------------------------------------------------------------------------------- /ChatCaptioner/demo_pic/CuteCloud_1366x768.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/ChatCaptioner/demo_pic/CuteCloud_1366x768.jpg -------------------------------------------------------------------------------- /ChatCaptioner/demo_pic/demo1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/ChatCaptioner/demo_pic/demo1.gif -------------------------------------------------------------------------------- /ChatCaptioner/demo_pic/demo2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/ChatCaptioner/demo_pic/demo2.gif -------------------------------------------------------------------------------- /ChatCaptioner/demo_pic/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/ChatCaptioner/demo_pic/overview.png -------------------------------------------------------------------------------- /ChatCaptioner/draw_quali.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "a6381f4c-3db3-4cd2-92b2-07a2dbb8c0bb", 7 | 
"metadata": { 8 | "tags": [], 9 | "pycharm": { 10 | "name": "#%%\n" 11 | } 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "%load_ext autoreload\n", 16 | "%autoreload 2" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "d535f155-9f76-4b16-a3c2-585e8adf00a9", 23 | "metadata": { 24 | "tags": [], 25 | "pycharm": { 26 | "name": "#%%\n" 27 | } 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "from glob import glob\n", 33 | "import yaml\n", 34 | "\n", 35 | "import matplotlib\n", 36 | "import matplotlib.pyplot as plt\n", 37 | "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n", 38 | "\n", 39 | "from chatcaptioner.utils import print_info, plot_img, extractQA_chatgpt, RandomSampledDataset" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "id": "2e9e7b2d-b8e2-4216-ae25-0e388d6631ac", 46 | "metadata": { 47 | "tags": [], 48 | "pycharm": { 49 | "name": "#%%\n" 50 | } 51 | }, 52 | "outputs": [], 53 | "source": [ 54 | "def split_sentence(sentence, max_len=38):\n", 55 | " if len(sentence) < max_len:\n", 56 | " return sentence, 1\n", 57 | " words = sentence.split(' ')\n", 58 | " sub_sentence_list = []\n", 59 | " init = ''\n", 60 | " \n", 61 | " for word in words:\n", 62 | " tmp_init = init + ' ' + word\n", 63 | " if len(tmp_init) > max_len:\n", 64 | " sub_sentence_list.append(init)\n", 65 | " init = word\n", 66 | " else:\n", 67 | " init = tmp_init\n", 68 | " sub_sentence_list.append(init)\n", 69 | " \n", 70 | " return '\\n'.join(sub_sentence_list), len(sub_sentence_list)\n", 71 | " \n", 72 | "\n", 73 | "def plot_dialogue(lefts, rights, xs=[0.1, 0.7], init_y=1, y_gap=0.07, line_h=0.045):\n", 74 | " cdict = {'left': '#ecf5e6', 'right': '#e7f0fd'}\n", 75 | " \n", 76 | " def plot_text(x, y, s, pos):\n", 77 | " plt.text(\n", 78 | " x=x, y=y, s=s, \n", 79 | " horizontalalignment=pos,\n", 80 | " multialignment='left',\n", 81 | " verticalalignment='top',\n", 82 | " bbox=dict(boxstyle='round', \n", 83 | " fc=cdict[pos], \n", 84 | " ec=cdict[pos], \n", 85 | " ))\n", 86 | " \n", 87 | " cur_y = init_y\n", 88 | " for l, r in zip(lefts, rights):\n", 89 | " l, n_lines = split_sentence(l)\n", 90 | " plot_text(x=xs[0], y=cur_y, s=l, pos='left')\n", 91 | " cur_y -= y_gap + line_h * (n_lines-1)\n", 92 | " \n", 93 | " r, n_lines = split_sentence(r)\n", 94 | " plot_text(x=xs[1], y=cur_y, s=r, pos='right')\n", 95 | " cur_y -= y_gap + line_h * (n_lines-1)\n", 96 | " \n", 97 | " return cur_y\n", 98 | " \n", 99 | "def plot_summary(summary, x, y, max_len=43):\n", 100 | " summary, n_lines = split_sentence(summary, max_len)\n", 101 | " plt.text(\n", 102 | " x=x, y=y, s=summary, \n", 103 | " horizontalalignment='center',\n", 104 | " multialignment='left',\n", 105 | " verticalalignment='top',\n", 106 | " bbox=dict(boxstyle='round', \n", 107 | " fc='#ffe5b5', \n", 108 | " ec='#ffe5b5', \n", 109 | " ))\n", 110 | "\n", 111 | " \n", 112 | "def fancy_plot(img, questions, answers, summary, xs=[0, 1], init_y=1):\n", 113 | " ax = plt.gca()\n", 114 | " w, h = test_img.size\n", 115 | " img = img.resize([int(256/h*w), 256])\n", 116 | " # plt.xlim(*xs)\n", 117 | " \n", 118 | " imagebox = OffsetImage(img, zoom=0.5)\n", 119 | " ab = AnnotationBbox(imagebox, ((xs[1] + xs[0]) / 2, init_y), frameon=False, box_alignment=(0.5, 0))\n", 120 | " ax.add_artist(ab)\n", 121 | " \n", 122 | " y = init_y - 0.03\n", 123 | " y = plot_dialogue(questions, answers, xs=xs, init_y=y)\n", 124 | " \n", 125 | " y = y - 0.01\n", 126 | " plot_summary(summary, (xs[1] 
+ xs[0]) / 2, y)\n", 127 | " \n", 128 | " plt.axis('off')\n", 129 | " \n", 130 | " \n", 131 | " " 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "80751efa-cd79-481e-81ae-deda51298eee", 138 | "metadata": { 139 | "tags": [], 140 | "pycharm": { 141 | "name": "#%%\n" 142 | } 143 | }, 144 | "outputs": [], 145 | "source": [ 146 | "# specify SAVE_PATH to visualize the result you want\n", 147 | "SAVE_PATH = 'experiments/test/'\n", 148 | "DATA_ROOT = 'datasets/'" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": null, 154 | "id": "030253f3-6db4-4b76-b8f0-07bc412ce3d0", 155 | "metadata": { 156 | "tags": [], 157 | "pycharm": { 158 | "name": "#%%\n" 159 | } 160 | }, 161 | "outputs": [], 162 | "source": [ 163 | "datasets_list = os.listdir(SAVE_PATH)\n", 164 | "datasets_list = ['artemis', 'coco_val']\n", 165 | "for dataset_name in datasets_list:\n", 166 | " print('============================')\n", 167 | " print(' {} '.format(dataset_name))\n", 168 | " print('============================')\n", 169 | " fig_path = 'figs/testV4_chatgpt/{}'.format(dataset_name)\n", 170 | " os.makedirs(fig_path, exist_ok=True)\n", 171 | " \n", 172 | " dataset = RandomSampledDataset(DATA_ROOT, dataset_name)\n", 173 | " \n", 174 | " save_infos = glob(os.path.join(SAVE_PATH, dataset_name, 'caption_result', '*'))\n", 175 | " for info_file in save_infos:\n", 176 | " with open(info_file, 'r') as f:\n", 177 | " info = yaml.safe_load(f)\n", 178 | " \n", 179 | " \n", 180 | " img_id = info['id'] if 'id' in info else info['setting']['id']\n", 181 | " test_img, _ = dataset.fetch_img(img_id)\n", 182 | " \n", 183 | " questions, answers = extractQA_chatgpt(info['FlanT5 XXL']['ChatCaptioner']['chat'])\n", 184 | " summary = info['FlanT5 XXL']['ChatCaptioner']['caption']\n", 185 | " fancy_plot(test_img, questions, answers, summary)\n", 186 | " plt.gca().set_aspect(1.3)\n", 187 | " plt.savefig(os.path.join(fig_path, '{}.pdf'.format(img_id)), bbox_inches='tight')\n", 188 | " plt.close()" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": null, 194 | "id": "520912d5-eece-4184-a055-73349cfe419d", 195 | "metadata": { 196 | "pycharm": { 197 | "name": "#%%\n" 198 | } 199 | }, 200 | "outputs": [], 201 | "source": [] 202 | } 203 | ], 204 | "metadata": { 205 | "kernelspec": { 206 | "display_name": "chatae", 207 | "language": "python", 208 | "name": "chatae" 209 | }, 210 | "language_info": { 211 | "codemirror_mode": { 212 | "name": "ipython", 213 | "version": 3 214 | }, 215 | "file_extension": ".py", 216 | "mimetype": "text/x-python", 217 | "name": "python", 218 | "nbconvert_exporter": "python", 219 | "pygments_lexer": "ipython3", 220 | "version": "3.9.16" 221 | } 222 | }, 223 | "nbformat": 4, 224 | "nbformat_minor": 5 225 | } -------------------------------------------------------------------------------- /ChatCaptioner/environment.yml: -------------------------------------------------------------------------------- 1 | name: chatcap 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.9 7 | - pip 8 | - pytorch=1.12.1 9 | - pytorch-mutex=1.0=cuda 10 | - torchaudio=0.12.1 11 | - torchvision=0.13.1 12 | - pip: 13 | - accelerate==0.16.0 14 | - aiohttp==3.8.4 15 | - aiosignal==1.3.1 16 | - async-timeout==4.0.2 17 | - attrs==22.2.0 18 | - bitsandbytes==0.37.0 19 | - cchardet==2.1.7 20 | - chardet==5.1.0 21 | - contourpy==1.0.7 22 | - cycler==0.11.0 23 | - filelock==3.9.0 24 | - fonttools==4.38.0 25 | - frozenlist==1.3.3 26 | - 
huggingface-hub==0.12.1 27 | - importlib-resources==5.12.0 28 | - kiwisolver==1.4.4 29 | - matplotlib==3.7.0 30 | - multidict==6.0.4 31 | - openai==0.27.0 32 | - packaging==23.0 33 | - psutil==5.9.4 34 | - pycocotools==2.0.6 35 | - pyparsing==3.0.9 36 | - python-dateutil==2.8.2 37 | - pyyaml==6.0 38 | - regex==2022.10.31 39 | - tokenizers==0.13.2 40 | - tqdm==4.64.1 41 | - transformers==4.27.4 42 | - yarl==1.8.2 43 | - zipp==3.14.0 44 | - tenacity==8.2.2 45 | - pycocoevalcap 46 | - sentence-transformers 47 | - umap-learn 48 | - notebook 49 | - gradio 50 | -------------------------------------------------------------------------------- /ChatCaptioner/main_caption.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | import torch 5 | 6 | from chatcaptioner.chat import set_openai_key, caption_images, get_instructions 7 | from chatcaptioner.blip2 import Blip2 8 | from chatcaptioner.utils import RandomSampledDataset, plot_img, print_info 9 | 10 | 11 | def parse(): 12 | parser = argparse.ArgumentParser(description='Generating captions in test datasets.') 13 | parser.add_argument('--data_root', type=str, default='datasets/', 14 | help='root path to the datasets') 15 | parser.add_argument('--save_root', type=str, default='experiments/', 16 | help='root path for saving results') 17 | parser.add_argument('--exp_tag', type=str, required=True, 18 | help='tag for this experiment. caption results will be saved in save_root/exp_tag') 19 | parser.add_argument('--datasets', nargs='+', choices=['artemis', 'cc_val', 'coco_val', 'para_test', 'pascal'], default=['coco_val'], 20 | help='Names of the datasets to use in the experiment. Valid datasets include artemis, cc_val, coco_val. Default is coco_val') 21 | parser.add_argument('--n_rounds', type=int, default=10, 22 | help='Number of QA rounds between GPT3 and BLIP-2. Default is 10, which costs about 2k tokens in GPT3 API.') 23 | parser.add_argument('--n_blip2_context', type=int, default=1, 24 | help='Number of QA rounds visible to BLIP-2. Default is 1, which means BLIP-2 only remember one previous question. -1 means BLIP-2 can see all the QA rounds') 25 | parser.add_argument('--model', type=str, default='chatgpt', choices=['gpt3', 'chatgpt', 'text-davinci-003', 'text-davinci-002', 'davinci', 'gpt-3.5-turbo', 'FlanT5XXL', 'OPT'], 26 | help='model used to ask question. can be gpt3, chatgpt, or its concrete tags in openai system') 27 | parser.add_argument('--device_id', type=int, default=0, 28 | help='Which GPU to use.') 29 | 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def main(args): 35 | # Set OpenAI 36 | openai_key = os.environ["OPENAI_API_KEY"] 37 | set_openai_key(openai_key) 38 | 39 | # Load BLIP-2 40 | blip2s = { 41 | 'FlanT5 XXL': Blip2('FlanT5 XXL', device_id=args.device_id, bit8=True), # load BLIP-2 FlanT5 XXL to GPU0. Too large, need 8 bit. 
About 20GB GPU Memory 42 | } 43 | 44 | if args.model == 'FlanT5XXL': 45 | question_model = blip2s['FlanT5 XXL'] 46 | elif args.model == 'OPT': 47 | question_model = Blip2('OPT6.7B', device_id=2, bit8=True) 48 | else: 49 | question_model = args.model 50 | 51 | 52 | for dataset_name in args.datasets: 53 | # load the dataset 54 | dataset = RandomSampledDataset(args.data_root, dataset_name) 55 | # preparing the folder to save results 56 | save_path = os.path.join(args.save_root, args.exp_tag, dataset_name) 57 | if not os.path.exists(save_path): 58 | os.makedirs(os.path.join(save_path, 'caption_result')) 59 | with open(os.path.join(save_path, 'instruction.yaml'), 'w') as f: 60 | yaml.dump(get_instructions(), f) 61 | 62 | # start caption 63 | caption_images(blip2s, 64 | dataset, 65 | dataset.ids, 66 | save_path=save_path, 67 | n_rounds=args.n_rounds, 68 | n_blip2_context=args.n_blip2_context, 69 | model=question_model, 70 | print_mode='no') 71 | 72 | 73 | if __name__ == '__main__': 74 | args = parse() 75 | main(args) -------------------------------------------------------------------------------- /ChatCaptioner/not_sure.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "5c44a676-0233-4475-bb98-a81fac693899", 7 | "metadata": { 8 | "tags": [], 9 | "pycharm": { 10 | "name": "#%%\n" 11 | } 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "%load_ext autoreload\n", 16 | "%autoreload 2" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "03947bd6-59ca-4b9c-a30b-33dbeea8fc54", 23 | "metadata": { 24 | "tags": [], 25 | "pycharm": { 26 | "name": "#%%\n" 27 | } 28 | }, 29 | "outputs": [], 30 | "source": [ 31 | "import os\n", 32 | "from glob import glob\n", 33 | "import csv\n", 34 | "import yaml\n", 35 | "from chatcaptioner.chat import get_chat_log\n", 36 | "from chatcaptioner.blip2 import Blip2\n", 37 | "from chatcaptioner.utils import print_info, plot_img, extractQA_chatgpt, RandomSampledDataset" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "id": "27f4088d-3886-43e0-9dc9-0a8b9b946fae", 44 | "metadata": { 45 | "tags": [], 46 | "pycharm": { 47 | "name": "#%%\n" 48 | } 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "# specify SAVE_PATH to visualize the result you want\n", 53 | "SAVE_PATH = 'experiments/testV4_chatgpt/'\n", 54 | "DATA_ROOT = 'datasets/'" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "a16026ee-aa9b-436e-be6d-d4cc48172214", 61 | "metadata": { 62 | "tags": [], 63 | "pycharm": { 64 | "name": "#%%\n" 65 | } 66 | }, 67 | "outputs": [], 68 | "source": [ 69 | "blip2 = Blip2('FlanT5 XXL', device_id=0, bit8=True)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "eba38f95-3768-4b26-8881-e5693c65a13e", 76 | "metadata": { 77 | "jupyter": { 78 | "outputs_hidden": true 79 | }, 80 | "tags": [], 81 | "pycharm": { 82 | "name": "#%%\n" 83 | } 84 | }, 85 | "outputs": [], 86 | "source": [ 87 | "datasets_list = os.listdir(SAVE_PATH)\n", 88 | "datasets_list = ['cc_val']\n", 89 | "uncertainty_list = []\n", 90 | "\n", 91 | "for dataset_name in datasets_list:\n", 92 | " print('============================')\n", 93 | " print(' {} '.format(dataset_name))\n", 94 | " print('============================')\n", 95 | " dataset = RandomSampledDataset(DATA_ROOT, dataset_name)\n", 96 | " \n", 97 | " save_infos = glob(os.path.join(SAVE_PATH, 
dataset_name, 'caption_result', '*'))\n", 98 | " for info_file in save_infos:\n", 99 | " with open(info_file, 'r') as f:\n", 100 | " info = yaml.safe_load(f)\n", 101 | " img_id = info['id'] if 'id' in info else info['setting']['id']\n", 102 | " test_img, _ = dataset.fetch_img(img_id)\n", 103 | " \n", 104 | " chat = info['FlanT5 XXL']['ChatCaptioner']['chat']\n", 105 | " questions, answers = extractQA_chatgpt(chat)\n", 106 | " not_sure = False\n", 107 | " for q, a in zip(questions, answers):\n", 108 | " if 'sure' in a or 'know' in a:\n", 109 | " not_sure = True\n", 110 | " print('Question: {}'.format(q))\n", 111 | " print('Answer: {}'.format(a))\n", 112 | " uncertainty_list.append((img_id, q, a))\n", 113 | " if not_sure:\n", 114 | " plot_img(test_img)\n", 115 | " " 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "b717d01e-fbfc-44b3-8690-542d417a8c26", 122 | "metadata": { 123 | "tags": [], 124 | "pycharm": { 125 | "name": "#%%\n" 126 | } 127 | }, 128 | "outputs": [], 129 | "source": [ 130 | "uncertainty_dict = {}\n", 131 | "for img_id, q, a in uncertainty_list:\n", 132 | " if img_id not in uncertainty_dict:\n", 133 | " uncertainty_dict[img_id] = [q]\n", 134 | " else:\n", 135 | " uncertainty_dict[img_id].append(q)\n", 136 | "with open(os.path.join('not_sure.yaml'), 'w') as f:\n", 137 | " yaml.dump(uncertainty_dict, f)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": null, 143 | "id": "422228e3-700c-43cf-a857-a9a4871ce714", 144 | "metadata": { 145 | "tags": [], 146 | "pycharm": { 147 | "name": "#%%\n" 148 | } 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "uncertainty_dict.keys()" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "bb32c46f-1f94-4b01-8ea5-92f72619d975", 159 | "metadata": { 160 | "tags": [], 161 | "pycharm": { 162 | "name": "#%%\n" 163 | } 164 | }, 165 | "outputs": [], 166 | "source": [ 167 | "uncertainty_dict['13778']" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": null, 173 | "id": "9268f88d-bdeb-4e5e-9cf7-c5046e2cbf35", 174 | "metadata": { 175 | "tags": [], 176 | "pycharm": { 177 | "name": "#%%\n" 178 | } 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "info_file = os.path.join(SAVE_PATH, dataset_name, 'caption_result', '13276.yaml')\n", 183 | "with open(info_file, 'r') as f:\n", 184 | " info = yaml.safe_load(f)" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "id": "798dbf33-cf00-4ecd-b740-ba89332ce74d", 191 | "metadata": { 192 | "pycharm": { 193 | "name": "#%%\n" 194 | } 195 | }, 196 | "outputs": [], 197 | "source": [ 198 | "questions, orig_answers = extractQA_chatgpt(info['FlanT5 XXL']['ChatCaptioner']['chat'])" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "bfad489b-610a-4750-be7c-6d058eb3ab8e", 205 | "metadata": { 206 | "tags": [], 207 | "pycharm": { 208 | "name": "#%%\n" 209 | } 210 | }, 211 | "outputs": [], 212 | "source": [ 213 | "questions" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "2a294b74-8a12-4a80-a536-68442520dd80", 220 | "metadata": { 221 | "tags": [], 222 | "pycharm": { 223 | "name": "#%%\n" 224 | } 225 | }, 226 | "outputs": [], 227 | "source": [ 228 | "orig_answers" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "8780d875-3998-4337-bdab-b7769c43cde0", 235 | "metadata": { 236 | "tags": [], 237 | "pycharm": { 238 
| "name": "#%%\n" 239 | } 240 | }, 241 | "outputs": [], 242 | "source": [ 243 | "ANSWER_INSTRUCTION = 'Answer given questions. If you are not sure about the answer, say you don\\'t know honestly. Don\\'t imagine any contents that are not in the image.'\n", 244 | "ANSWER_INSTRUCTION = 'Answer given questions. Don\\'t imagine any contents that are not in the image.'\n", 245 | "SUB_ANSWER_INSTRUCTION = 'Answer: ' # template following blip2 huggingface demo" 246 | ] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": null, 251 | "id": "e8de5f91-1239-455e-826a-5314eb901fa5", 252 | "metadata": { 253 | "tags": [], 254 | "pycharm": { 255 | "name": "#%%\n" 256 | } 257 | }, 258 | "outputs": [], 259 | "source": [ 260 | "answers = []\n", 261 | "for i in range(len(questions)):\n", 262 | " print('Question: {}'.format(questions[i]))\n", 263 | " blip2_prompt = '\\n'.join([ANSWER_INSTRUCTION, \n", 264 | " get_chat_log(questions[:i+1], answers, last_n=1), \n", 265 | " SUB_ANSWER_INSTRUCTION]) \n", 266 | " answer = blip2.ask(test_img, blip2_prompt)\n", 267 | " answer = answer.split('Question:')[0].replace('\\n', ' ').strip()\n", 268 | " print('Answer: {}'.format(answer))\n", 269 | " answers.append(answer)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "id": "eb8750fb-3022-4f99-88ee-5fb385988a69", 276 | "metadata": { 277 | "pycharm": { 278 | "name": "#%%\n" 279 | } 280 | }, 281 | "outputs": [], 282 | "source": [ 283 | "results = {}\n", 284 | "\n", 285 | "# Open the CSV file for reading\n", 286 | "with open('h_uncertain.csv', 'r') as csvfile:\n", 287 | " # Create a CSV reader object\n", 288 | " csvreader = csv.DictReader(csvfile)\n", 289 | " \n", 290 | " # Iterate over each row in the CSV file\n", 291 | " for row in csvreader:\n", 292 | " # Access the values in the row by index\n", 293 | " img_id = row['Input.image_id']\n", 294 | " question = row['Input.question']\n", 295 | " tag = img_id + '_' + question\n", 296 | " answer = row['Answer.summary']\n", 297 | " \n", 298 | " if tag not in results:\n", 299 | " results[tag] = [answer]\n", 300 | " else:\n", 301 | " results[tag].append(answer)" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "id": "f9c5a762-4be8-4041-8287-5c652f8bb3c5", 308 | "metadata": { 309 | "tags": [], 310 | "pycharm": { 311 | "name": "#%%\n" 312 | } 313 | }, 314 | "outputs": [], 315 | "source": [ 316 | "len(results)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": null, 322 | "id": "2d5edbf0-d9a9-4638-85ec-db9906512a10", 323 | "metadata": { 324 | "tags": [], 325 | "pycharm": { 326 | "name": "#%%\n" 327 | } 328 | }, 329 | "outputs": [], 330 | "source": [ 331 | "uncertainQ = []\n", 332 | "certainQ = []\n", 333 | "for tag, answers in results.items():\n", 334 | " n_none = 0\n", 335 | " for answer in answers:\n", 336 | " if 'none' in answer.lower():\n", 337 | " n_none += 1\n", 338 | " if n_none >= 2:\n", 339 | " uncertainQ.append(tag)\n", 340 | " else:\n", 341 | " certainQ.append([tag, answers])" 342 | ] 343 | }, 344 | { 345 | "cell_type": "code", 346 | "execution_count": null, 347 | "id": "9ee3fc0f-1654-4287-ba12-60eb328b0295", 348 | "metadata": { 349 | "tags": [], 350 | "pycharm": { 351 | "name": "#%%\n" 352 | } 353 | }, 354 | "outputs": [], 355 | "source": [ 356 | "certain_img = {}\n", 357 | "for tag, h_answers in certainQ:\n", 358 | " img_id, question = tag.split('_')\n", 359 | " if img_id in certain_img:\n", 360 | " certain_img[img_id][question] = h_answers\n", 361 | 
" else:\n", 362 | " certain_img[img_id] = {question: h_answers}\n", 363 | " " 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "id": "a1b78acd-a1c7-47a2-a4a9-ad529f2dd352", 370 | "metadata": { 371 | "tags": [], 372 | "pycharm": { 373 | "name": "#%%\n" 374 | } 375 | }, 376 | "outputs": [], 377 | "source": [ 378 | "len(certainQ)" 379 | ] 380 | }, 381 | { 382 | "cell_type": "code", 383 | "execution_count": null, 384 | "id": "2c8aca85-2b44-4022-bc10-bd7f83eb9510", 385 | "metadata": { 386 | "tags": [], 387 | "pycharm": { 388 | "name": "#%%\n" 389 | } 390 | }, 391 | "outputs": [], 392 | "source": [ 393 | "ANSWER_INSTRUCTION = 'Answer given questions. Don\\'t imagine any contents that are not in the image.'\n", 394 | "SUB_ANSWER_INSTRUCTION = 'Answer: ' # template following blip2 huggingface demo\n", 395 | "\n", 396 | "for img_id, c_questions in certain_img.items():\n", 397 | " with open(os.path.join(SAVE_PATH, dataset_name, 'caption_result', '{}.yaml'.format(img_id)), 'r') as f:\n", 398 | " info = yaml.safe_load(f)\n", 399 | " test_img, _ = dataset.fetch_img(img_id)\n", 400 | "\n", 401 | " chat = info['FlanT5 XXL']['ChatCaptioner']['chat']\n", 402 | " questions, _ = extractQA_chatgpt(chat)\n", 403 | "\n", 404 | " answers = []\n", 405 | " for i in range(len(questions)):\n", 406 | " if questions[i] in c_questions:\n", 407 | " print('?????????????????')\n", 408 | " print('Question: {}'.format(questions[i]))\n", 409 | " blip2_prompt = '\\n'.join([ANSWER_INSTRUCTION, \n", 410 | " get_chat_log(questions[:i+1], answers, last_n=1), \n", 411 | " SUB_ANSWER_INSTRUCTION]) \n", 412 | " answer = blip2.ask(test_img, blip2_prompt)\n", 413 | " answer = answer.split('Question:')[0].replace('\\n', ' ').strip()\n", 414 | " answers.append(answer)\n", 415 | " print('Answer: {}'.format(answer))\n", 416 | " if questions[i] in c_questions:\n", 417 | " for h_answer in c_questions[questions[i]]:\n", 418 | " print('Human: {}'.format(h_answer))\n", 419 | " print('!!!!!!!!!!!!!!!!!!!!')\n", 420 | " plot_img(test_img)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "id": "662a9d64-638b-4861-90e0-fe1479b60a1f", 427 | "metadata": { 428 | "pycharm": { 429 | "name": "#%%\n" 430 | } 431 | }, 432 | "outputs": [], 433 | "source": [] 434 | } 435 | ], 436 | "metadata": { 437 | "kernelspec": { 438 | "display_name": "chatae", 439 | "language": "python", 440 | "name": "chatae" 441 | }, 442 | "language_info": { 443 | "codemirror_mode": { 444 | "name": "ipython", 445 | "version": 3 446 | }, 447 | "file_extension": ".py", 448 | "mimetype": "text/x-python", 449 | "name": "python", 450 | "nbconvert_exporter": "python", 451 | "pygments_lexer": "ipython3", 452 | "version": "3.9.16" 453 | } 454 | }, 455 | "nbformat": 4, 456 | "nbformat_minor": 5 457 | } -------------------------------------------------------------------------------- /ChatCaptioner/obj_cover.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "1f65513e-6d70-452f-8d75-d2cd2b51f8fc", 7 | "metadata": { 8 | "tags": [], 9 | "pycharm": { 10 | "name": "#%%\n" 11 | } 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "import os\n", 16 | "from glob import glob\n", 17 | "import yaml\n", 18 | "from tqdm import tqdm\n", 19 | "import nltk\n", 20 | "from nltk.corpus import wordnet\n", 21 | "from sentence_transformers import SentenceTransformer, util\n", 22 | "\n", 23 | "from 
chatcaptioner.utils import RandomSampledDataset, print_info, plot_img" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": null, 29 | "id": "11da03c3-c1c3-4614-b91a-693ed3ebc598", 30 | "metadata": { 31 | "tags": [], 32 | "pycharm": { 33 | "name": "#%%\n" 34 | } 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "def map_word_to_hypernym(word):\n", 39 | " synsets = wordnet.synsets(word)\n", 40 | " if len(synsets) == 0:\n", 41 | " return word\n", 42 | " else:\n", 43 | " synset = synsets[0] # Use first synset as default\n", 44 | " hypernyms = synset.hypernyms()\n", 45 | " if len(hypernyms) == 0:\n", 46 | " return word\n", 47 | " else:\n", 48 | " hypernym = hypernyms[0] # Use first hypernym as default\n", 49 | " return hypernym.lemmas()[0].name()" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "id": "19178916-28c2-46e9-9abe-a00377c8a8da", 56 | "metadata": { 57 | "tags": [], 58 | "pycharm": { 59 | "name": "#%%\n" 60 | } 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "def is_included(noun1, noun2):\n", 65 | " synsets1 = wordnet.synsets(noun1, pos=wordnet.NOUN)\n", 66 | " synsets2 = wordnet.synsets(noun2, pos=wordnet.NOUN)\n", 67 | " \n", 68 | " for synset1 in synsets1:\n", 69 | " for synset2 in synsets2:\n", 70 | " # Check for similarity score\n", 71 | " similarity_score = synset1.wup_similarity(synset2)\n", 72 | " if similarity_score is not None and similarity_score > 0.9:\n", 73 | " return True\n", 74 | " # Check for inclusion relationship\n", 75 | " if synset1 in synset2.closure(lambda s: s.hyponyms()) \\\n", 76 | " or synset2 in synset1.closure(lambda s: s.hyponyms()):\n", 77 | " return True\n", 78 | " return False" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "6e9d2628-cb6d-41da-b05e-a6e4d5ecbab9", 85 | "metadata": { 86 | "tags": [], 87 | "pycharm": { 88 | "name": "#%%\n" 89 | } 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "def extract_nouns(text):\n", 94 | " nouns = []\n", 95 | " sentences = nltk.sent_tokenize(text)\n", 96 | " for sentence in sentences:\n", 97 | " words = nltk.word_tokenize(sentence)\n", 98 | " tagged_words = nltk.pos_tag(words)\n", 99 | " for word, tag in tagged_words:\n", 100 | " if tag.startswith('N'): # Nouns start with 'N' in POS tag\n", 101 | " nouns.append(word)\n", 102 | " return nouns" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "9745507d-98ba-441a-956b-fa57a7ca3460", 109 | "metadata": { 110 | "tags": [], 111 | "pycharm": { 112 | "name": "#%%\n" 113 | } 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "sentence_model = SentenceTransformer('all-mpnet-base-v2')" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "id": "bfccb971-21a3-48d2-bb63-879616a83542", 124 | "metadata": { 125 | "tags": [], 126 | "pycharm": { 127 | "name": "#%%\n" 128 | } 129 | }, 130 | "outputs": [], 131 | "source": [ 132 | "DATA_ROOT = 'datasets'\n", 133 | "dataset = RandomSampledDataset(DATA_ROOT, 'pascal')" 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": null, 139 | "id": "ac646f4f-b347-40b8-bce7-8cfeaf9631c0", 140 | "metadata": { 141 | "tags": [], 142 | "pycharm": { 143 | "name": "#%%\n" 144 | } 145 | }, 146 | "outputs": [], 147 | "source": [ 148 | "# specify SAVE_PATH to visualize the result you want\n", 149 | "SAVE_PATH = 'experiments/test'" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": null, 155 | "id": 
"366836da-0b61-4dff-b8c9-688065484aa0", 156 | "metadata": { 157 | "tags": [], 158 | "pycharm": { 159 | "name": "#%%\n" 160 | } 161 | }, 162 | "outputs": [], 163 | "source": [ 164 | "def check_cover(gt_objs, cap_objs):\n", 165 | " covered = []\n", 166 | " for gt_obj in gt_objs:\n", 167 | " for obj in cap_objs:\n", 168 | " if obj == 'people':\n", 169 | " obj = 'person'\n", 170 | " if is_included(gt_obj, obj):\n", 171 | " covered.append(gt_obj)\n", 172 | " break\n", 173 | " return len(covered), len(gt_objs)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "22d347a0-b552-4510-bd28-42211921620e", 180 | "metadata": { 181 | "tags": [], 182 | "pycharm": { 183 | "name": "#%%\n" 184 | } 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "results_blip2 = []\n", 189 | "results_our = []\n", 190 | "\n", 191 | "save_infos = glob(os.path.join(SAVE_PATH, 'pascal', 'caption_result', '*'))\n", 192 | "for info_file in tqdm(save_infos):\n", 193 | " with open(info_file, 'r') as f:\n", 194 | " info = yaml.safe_load(f)\n", 195 | " img_id = info['id'] if 'id' in info else info['setting']['id']\n", 196 | " \n", 197 | " blip2 = info['FlanT5 XXL']['BLIP2+OurPrompt']['caption']\n", 198 | " blip2 = extract_nouns(blip2)\n", 199 | " \n", 200 | " our = info['FlanT5 XXL']['ChatCaptioner']['caption']\n", 201 | " our = extract_nouns(our)\n", 202 | " \n", 203 | " gt_objs = []\n", 204 | " gt_objs_tmp = info['setting']['GT']['caption'][0].split('_')\n", 205 | " \n", 206 | " for obj in gt_objs_tmp:\n", 207 | " if ' ' in obj: continue\n", 208 | " gt_objs.append(obj)\n", 209 | " \n", 210 | " results_blip2.append(check_cover(gt_objs, blip2))\n", 211 | " results_our.append(check_cover(gt_objs, our))\n", 212 | " \n" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": null, 218 | "id": "641b5266-ca6b-4f36-9c79-214c0691e1c1", 219 | "metadata": { 220 | "tags": [], 221 | "pycharm": { 222 | "name": "#%%\n" 223 | } 224 | }, 225 | "outputs": [], 226 | "source": [ 227 | "x, y = 0, 0\n", 228 | "for a, b in results_our:\n", 229 | " x += a\n", 230 | " y += b\n", 231 | "print(x, y)" 232 | ] 233 | }, 234 | { 235 | "cell_type": "code", 236 | "execution_count": null, 237 | "id": "5ba6cd2e-b7e7-495d-8f7e-6927eb38e723", 238 | "metadata": { 239 | "tags": [], 240 | "pycharm": { 241 | "name": "#%%\n" 242 | } 243 | }, 244 | "outputs": [], 245 | "source": [ 246 | "x, y = 0, 0\n", 247 | "for a, b in results_blip2:\n", 248 | " x += a\n", 249 | " y += b\n", 250 | "print(x, y)" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "id": "67e5e1fc-6d7e-4449-8623-e1d97ca8ef92", 257 | "metadata": { 258 | "tags": [], 259 | "pycharm": { 260 | "name": "#%%\n" 261 | } 262 | }, 263 | "outputs": [], 264 | "source": [ 265 | "with open(info_file, 'r') as f:\n", 266 | " info = yaml.safe_load(f)\n", 267 | "img_id = info['id'] if 'id' in info else info['setting']['id']\n", 268 | "\n", 269 | "blip2 = info['FlanT5 XXL']['BLIP2+OurPrompt']['caption']\n", 270 | "blip2 = extract_nouns(blip2)\n", 271 | "\n", 272 | "our = info['FlanT5 XXL']['ChatCaptioner']['caption']\n", 273 | "our = extract_nouns(our)\n", 274 | "\n", 275 | "gt_objs = []\n", 276 | "gt_objs_tmp = info['setting']['GT']['caption'][0].split('_')\n", 277 | " " 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "id": "8bb32496-6aee-4d71-b060-c068388e27ad", 284 | "metadata": { 285 | "tags": [], 286 | "pycharm": { 287 | "name": "#%%\n" 288 | } 289 | }, 290 | "outputs": [], 
291 | "source": [ 292 | "blip2" 293 | ] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "execution_count": null, 298 | "id": "b2de57fe-f287-4703-914f-5a8feb71c2e9", 299 | "metadata": { 300 | "tags": [], 301 | "pycharm": { 302 | "name": "#%%\n" 303 | } 304 | }, 305 | "outputs": [], 306 | "source": [ 307 | "our" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "id": "1003c60b-ad08-4d5b-925f-5b61ddb33a64", 314 | "metadata": { 315 | "tags": [], 316 | "pycharm": { 317 | "name": "#%%\n" 318 | } 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "gt_objs_tmp" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": null, 328 | "id": "fa36cf8a-8d35-43cc-ac98-db631154e2a7", 329 | "metadata": { 330 | "pycharm": { 331 | "name": "#%%\n" 332 | } 333 | }, 334 | "outputs": [], 335 | "source": [] 336 | } 337 | ], 338 | "metadata": { 339 | "kernelspec": { 340 | "display_name": "chatae", 341 | "language": "python", 342 | "name": "chatae" 343 | }, 344 | "language_info": { 345 | "codemirror_mode": { 346 | "name": "ipython", 347 | "version": 3 348 | }, 349 | "file_extension": ".py", 350 | "mimetype": "text/x-python", 351 | "name": "python", 352 | "nbconvert_exporter": "python", 353 | "pygments_lexer": "ipython3", 354 | "version": "3.9.16" 355 | } 356 | }, 357 | "nbformat": 4, 358 | "nbformat_minor": 5 359 | } -------------------------------------------------------------------------------- /ChatCaptioner/question_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "cbec04d8-10b6-4a31-bf18-dc143d2ea554", 7 | "metadata": { 8 | "pycharm": { 9 | "name": "#%%\n" 10 | }, 11 | "tags": [] 12 | }, 13 | "outputs": [], 14 | "source": [ 15 | "import os\n", 16 | "import numpy as np\n", 17 | "from tqdm import tqdm\n", 18 | "from glob import glob\n", 19 | "import yaml\n", 20 | "import matplotlib\n", 21 | "from matplotlib.colors import hsv_to_rgb\n", 22 | "import matplotlib.pyplot as plt\n", 23 | "import matplotlib.patheffects as PathEffects\n", 24 | "import umap\n", 25 | "from sentence_transformers import SentenceTransformer, util\n", 26 | "\n", 27 | "from chatcaptioner.utils import extractQA_chatgpt" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "id": "f0ba1933-b16f-49dc-bd75-9ebecaf6c1a1", 34 | "metadata": { 35 | "pycharm": { 36 | "name": "#%%\n" 37 | }, 38 | "tags": [] 39 | }, 40 | "outputs": [], 41 | "source": [ 42 | "# specify SAVE_PATH to visualize the result you want\n", 43 | "SAVE_PATH = 'experiments/test/'\n", 44 | "DATA_ROOT = 'datasets/'\n", 45 | "sentence_model = SentenceTransformer('all-mpnet-base-v2')" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "bf190fec-8a4d-41c5-8c90-fe8181fe6c1b", 52 | "metadata": { 53 | "pycharm": { 54 | "name": "#%%\n" 55 | }, 56 | "tags": [] 57 | }, 58 | "outputs": [], 59 | "source": [ 60 | "datasets_list = os.listdir(SAVE_PATH)\n", 61 | "datasets_list = ['cc_val']\n", 62 | "all_questions = []\n", 63 | "effect_q = []\n", 64 | "for dataset_name in datasets_list:\n", 65 | " print('============================')\n", 66 | " print(' {} '.format(dataset_name))\n", 67 | " print('============================')\n", 68 | " \n", 69 | " \n", 70 | " save_infos = glob(os.path.join(SAVE_PATH, dataset_name, 'caption_result', '*'))\n", 71 | " for info_file in save_infos:\n", 72 | " with open(info_file, 'r') as f:\n", 73 | " 
info = yaml.safe_load(f)\n", 74 | " chat = info['FlanT5 XXL']['ChatCaptioner']['chat']\n", 75 | " if isinstance(chat, str):\n", 76 | " questions = []\n", 77 | " sentences = info['FlanT5 XXL']['ChatCaptioner']['chat'].split('\\n')\n", 78 | " for sentence in sentences:\n", 79 | " if 'Question: Describe this image in details.' in sentence: continue\n", 80 | " if 'Question:' in sentence:\n", 81 | " questions.append(sentence.split('Question:')[-1].strip())\n", 82 | " effect_q.append(len(set(questions)))\n", 83 | " all_questions += questions\n", 84 | " else:\n", 85 | " questions, answers = extractQA_chatgpt(chat)\n", 86 | " effect_q.append(len(set(questions[1:])))\n", 87 | " all_questions += questions[1:]" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "id": "56dddea7", 94 | "metadata": { 95 | "pycharm": { 96 | "name": "#%%\n" 97 | } 98 | }, 99 | "outputs": [], 100 | "source": [ 101 | "print('Unique Q/ Total Q: {}/{}'.format(len(set(all_questions)), len(all_questions)))\n", 102 | "print('Average Unique Q Per Dialogue: {}'.format(sum(effect_q) / len(effect_q)))" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "id": "0180992d-ce17-4a99-8d78-cb611bc6a7a0", 109 | "metadata": { 110 | "pycharm": { 111 | "name": "#%%\n" 112 | }, 113 | "tags": [] 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "all_embs = []\n", 118 | "for question in tqdm(all_questions):\n", 119 | " all_embs.append(sentence_model.encode(question))\n", 120 | "all_embs = np.stack(all_embs)" 121 | ] 122 | }, 123 | { 124 | "cell_type": "code", 125 | "execution_count": null, 126 | "id": "808f3991-d24d-44ba-adcc-2185dc3c914f", 127 | "metadata": { 128 | "pycharm": { 129 | "name": "#%%\n" 130 | }, 131 | "tags": [] 132 | }, 133 | "outputs": [], 134 | "source": [ 135 | "fit = umap.UMAP()\n", 136 | "fit_color = umap.UMAP(n_components=1)\n", 137 | "%time u = fit.fit_transform(all_embs)\n", 138 | "%time c = fit_color.fit_transform(all_embs)\n", 139 | "norm_c = (c - c.min())/ (c.max()-c.min())" 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "id": "693475fb-27c9-4e67-9d89-8d0943e78078", 146 | "metadata": { 147 | "pycharm": { 148 | "name": "#%%\n" 149 | }, 150 | "tags": [] 151 | }, 152 | "outputs": [], 153 | "source": [ 154 | "cmap = matplotlib.colormaps['gnuplot2']" 155 | ] 156 | }, 157 | { 158 | "cell_type": "code", 159 | "execution_count": null, 160 | "id": "15615094-706f-418e-bb84-203335069041", 161 | "metadata": { 162 | "pycharm": { 163 | "name": "#%%\n" 164 | }, 165 | "tags": [] 166 | }, 167 | "outputs": [], 168 | "source": [ 169 | "plt.scatter(u[:, 0], u[:, 1], s=8, alpha=0.5, c=norm_c, cmap='gnuplot2')\n", 170 | "plt.xlim(6, 21)\n", 171 | "plt.ylim(-1, 14)\n", 172 | "plt.axis('off')\n", 173 | "plt.show()" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "e352809f-fa7a-4ba0-abd6-912dac7e6fda", 180 | "metadata": { 181 | "pycharm": { 182 | "name": "#%%\n" 183 | }, 184 | "tags": [] 185 | }, 186 | "outputs": [], 187 | "source": [ 188 | "random_ids = random.sample(range(len(all_questions)), 5)\n", 189 | "for q_id in random_ids:\n", 190 | " print('{}: {}'.format(q_id, all_questions[q_id]))\n", 191 | "\n", 192 | "plt.scatter(u[:, 0], u[:, 1], s=1, c=norm_c, cmap='gnuplot2')\n", 193 | "plt.xlim(6, 21)\n", 194 | "plt.ylim(-1, 14)\n", 195 | "for q_id in random_ids:\n", 196 | " plt.text(x=u[q_id, 0], y=u[q_id, 1], s=all_questions[q_id], \n", 197 | " ha='center', wrap=True, \n", 198 | " 
c=cmap(norm_c[q_id])\n", 199 | " )\n", 200 | " txt.set_bbox(dict(facecolor='white', alpha=0.8, edgecolor='white'))\n", 201 | "plt.axis('off')\n", 202 | "plt.show()" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "fbbea4e2-acfe-472d-8554-8849693821b8", 209 | "metadata": { 210 | "pycharm": { 211 | "name": "#%%\n" 212 | }, 213 | "tags": [] 214 | }, 215 | "outputs": [], 216 | "source": [] 217 | } 218 | ], 219 | "metadata": { 220 | "kernelspec": { 221 | "display_name": "chatae", 222 | "language": "python", 223 | "name": "chatae" 224 | }, 225 | "language_info": { 226 | "codemirror_mode": { 227 | "name": "ipython", 228 | "version": 3 229 | }, 230 | "file_extension": ".py", 231 | "mimetype": "text/x-python", 232 | "name": "python", 233 | "nbconvert_exporter": "python", 234 | "pygments_lexer": "ipython3", 235 | "version": "3.9.16" 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 5 240 | } -------------------------------------------------------------------------------- /ChatCaptioner/yes_no.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "5c44a676-0233-4475-bb98-a81fac693899", 7 | "metadata": { 8 | "tags": [] 9 | }, 10 | "outputs": [], 11 | "source": [ 12 | "%load_ext autoreload\n", 13 | "%autoreload 2" 14 | ] 15 | }, 16 | { 17 | "cell_type": "code", 18 | "execution_count": null, 19 | "id": "03947bd6-59ca-4b9c-a30b-33dbeea8fc54", 20 | "metadata": { 21 | "tags": [] 22 | }, 23 | "outputs": [], 24 | "source": [ 25 | "import os\n", 26 | "from glob import glob\n", 27 | "import csv\n", 28 | "import yaml\n", 29 | "from chatcaptioner.chat import get_chat_log\n", 30 | "from chatcaptioner.blip2 import Blip2\n", 31 | "from chatcaptioner.utils import print_info, plot_img, extractQA_chatgpt, RandomSampledDataset" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "id": "27f4088d-3886-43e0-9dc9-0a8b9b946fae", 38 | "metadata": { 39 | "tags": [] 40 | }, 41 | "outputs": [], 42 | "source": [ 43 | "# specify SAVE_PATH to visualize the result you want\n", 44 | "SAVE_PATH = 'experiments/test/'\n", 45 | "DATA_ROOT = 'datasets/'" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "id": "eba38f95-3768-4b26-8881-e5693c65a13e", 52 | "metadata": { 53 | "tags": [] 54 | }, 55 | "outputs": [], 56 | "source": [ 57 | "datasets_list = os.listdir(SAVE_PATH)\n", 58 | "datasets_list = ['cc_val']\n", 59 | "yes_no_list = []\n", 60 | "\n", 61 | "for dataset_name in datasets_list:\n", 62 | " print('============================')\n", 63 | " print(' {} '.format(dataset_name))\n", 64 | " print('============================')\n", 65 | " dataset = RandomSampledDataset(DATA_ROOT, dataset_name)\n", 66 | " \n", 67 | " save_infos = glob(os.path.join(SAVE_PATH, dataset_name, 'caption_result', '*'))\n", 68 | " for info_file in save_infos:\n", 69 | " with open(info_file, 'r') as f:\n", 70 | " info = yaml.safe_load(f)\n", 71 | " img_id = info['id'] if 'id' in info else info['setting']['id']\n", 72 | " test_img, _ = dataset.fetch_img(img_id)\n", 73 | " \n", 74 | " chat = info['FlanT5 XXL']['ChatCaptioner']['chat']\n", 75 | " questions, answers = extractQA_chatgpt(chat)\n", 76 | " yes_no = False\n", 77 | " for q, a in zip(questions, answers):\n", 78 | " a = a.lower()\n", 79 | " # a = a.split(' ')\n", 80 | " if 'yes' == a or 'no' == a:\n", 81 | " # if 'Are there' in q:\n", 82 | " yes_no = True\n", 83 | 
" print('Question: {}'.format(q))\n", 84 | " print('Answer: {}'.format(a))\n", 85 | " yes_no_list.append((img_id, q, a))\n", 86 | " if not_sure:\n", 87 | " plot_img(test_img)\n", 88 | " " 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "id": "b717d01e-fbfc-44b3-8690-542d417a8c26", 95 | "metadata": { 96 | "tags": [] 97 | }, 98 | "outputs": [], 99 | "source": [ 100 | "# save questions that BLIP-2 is unsure about for later human evalutaion\n", 101 | "uncertainty_dict = {}\n", 102 | "for img_id, q, a in uncertainty_list:\n", 103 | " if img_id not in uncertainty_dict:\n", 104 | " uncertainty_dict[img_id] = [q]\n", 105 | " else:\n", 106 | " uncertainty_dict[img_id].append(q)\n", 107 | "with open(os.path.join('not_sure.yaml'), 'w') as f:\n", 108 | " yaml.dump(uncertainty_dict, f)" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "kernelspec": { 114 | "display_name": "chatae", 115 | "language": "python", 116 | "name": "chatae" 117 | }, 118 | "language_info": { 119 | "codemirror_mode": { 120 | "name": "ipython", 121 | "version": 3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.9.16" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 5 133 | } 134 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | 2 | MIT License 3 | 4 | Copyright (c) 2023 Deyao Zhu 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Interactive ChatCaptioner for image and video 2 | 3 | Official repository of **ChatCaptioner** and **Video ChatCaptioner**. 
4 | 5 | ChatCaptioner paper [ChatGPT Asks, BLIP-2 Answers: Automatic Questioning Towards Enriched Visual Descriptions](https://arxiv.org/abs/2303.06594) 6 | 7 | Video ChatCaptioner paper [Video ChatCaptioner: Towards the Enriched Spatiotemporal Descriptions](https://arxiv.org/abs/2304.04227) 8 | 9 | ## Demo 10 | ![demo1](ChatCaptioner/demo_pic/demo1.gif) 11 | ![demo2](ChatCaptioner/demo_pic/demo2.gif) 12 | ![demo3](Video_ChatCaptioner/demo_pic/dance.gif) 13 | ![demo4](Video_ChatCaptioner/demo_pic/skating.gif) 14 | 15 | 16 | 17 | * ChatCaptiners: 18 | * ChatCaptioner for images: [ChatCaptioner](ChatCaptioner/README.md) 19 | * ChatCaptioner for videos: [Video ChatCaptioner](Video_ChatCaptioner/README.md) 20 | 21 | 22 | 23 | 24 | ## Acknowledgement 25 | 26 | + [ChatGPT](https://openai.com/blog/chatgpt/) 27 | + [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) 28 | 29 | 30 | Please cite ChatCaptioner and Video ChatCaptioner from the following bibtex 31 | 32 | ``` 33 | @article{zhu2023chatgpt, 34 | title={ChatGPT Asks, BLIP-2 Answers: Automatic Questioning Towards Enriched Visual Descriptions}, 35 | author={Deyao Zhu and Jun Chen and Kilichbek Haydarov and Xiaoqian Shen and Wenxuan Zhang and Mohamed Elhoseiny}, 36 | journal={arXiv preprint arXiv:2303.06594}, 37 | year={2023} 38 | } 39 | ``` 40 | 41 | 42 | ``` 43 | @article{chen2023video, 44 | title={Video ChatCaptioner: Towards the Enriched Spatiotemporal Descriptions}, 45 | author={Jun Chen and Deyao Zhu and Kilichbek Haydarov and Xiang Li and Mohamed Elhoseiny}, 46 | journal={arXiv preprint arXiv:2304.04227}, 47 | year={2023} 48 | } 49 | ``` 50 | 51 | 52 | ## *License* 53 | 54 | ChatCaptioner and Video ChatCaptioner are released under the [MIT license](LICENSE). 55 | -------------------------------------------------------------------------------- /Video_ChatCaptioner/README.md: -------------------------------------------------------------------------------- 1 | # Video ChatCaptioner: Towards the Enriched Spatiotemporal Descriptions 2 | 3 | Official repository of **Video ChatCaptioner**. 4 | 5 | See our paper [Video ChatCaptioner: Towards the Enriched Spatiotemporal Descriptions](https://arxiv.org/abs/2304.04227) 6 | 7 | 8 | ![demo1](demo_pic/dance.gif) 9 | ![demo2](demo_pic/skating.gif) 10 | 11 | 12 | ## System Architecture 13 | ![overfiew](demo_pic/overview.png) 14 | 15 | 16 | 17 | ## Installation 18 | Note that you need a GPU with 24G memory to run ChatCaptioner due to the size of BLIP-2. 19 | 20 | To start, git clone this repository first. 21 | 22 | To install and activate the environment, run the following command: 23 | 24 | ``` 25 | conda env create -f environment.yml 26 | conda activate chatcap 27 | ``` 28 | 29 | Set the environment variable OPENAI_API_KEY to your OpenAI API Key. 30 | 31 | ``` 32 | export OPENAI_API_KEY=Your_OpenAI_Key 33 | ``` 34 | You can add it to .bashrc so you don't need to set it manually everytime. 35 | 36 | 37 | As many scripts here are in jupyter notebook, don't forget to add the environment to jupyter's kernel list. 38 | To do so, run 39 | 40 | ``` 41 | python -m ipykernel install --user --name=chatcap 42 | ``` 43 | 44 | 45 | Download our dataset samples from [here](https://drive.google.com/drive/folders/1NtGtz_CbZJFxvbuV_AYx3Acy6-Q-p8Eg?usp=sharing) and extract the zip file to the root folder. 
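Besides the shell scripts below, the pipeline can also be driven directly from Python. The sketch below is illustrative only (the video path, time window, and frame count are placeholders); it assumes the environment above is installed and OPENAI_API_KEY is set.

```
import os

from chatcaptioner.blip2 import Blip2
from chatcaptioner.video_chat import set_openai_key, caption_for_video
from chatcaptioner.video_reader import read_video_per_interval_sampling

set_openai_key(os.environ['OPENAI_API_KEY'])
blip2 = Blip2('FlanT5 XXL', device_id=0)

# uniformly sample 8 frames from the first 10 seconds of a local video (placeholder path)
frames = read_video_per_interval_sampling('path/to/video.mp4', start_time=0, end_time=10, num_frames=8)

# let ChatGPT question BLIP-2 about the frames and summarize the chat into a caption
results = caption_for_video(blip2, frames, model='gpt-3.5-turbo', n_rounds=10, print_mode='chat')
print(results['ChatCaptioner']['caption'])
```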
46 | 47 | 48 | To play with Video ChatCaptioner with a few dataset samples on msvd videos 49 | 50 | ``` 51 | sh run_msvd.sh 52 | ``` 53 | 54 | To play with Video ChatCaptioner with a few dataset samples on webvid videos 55 | 56 | ``` 57 | sh run_webvid.sh 58 | ``` 59 | 60 | 61 | 66 | 67 | 68 | 69 | ## Acknowledgement 70 | 71 | + [ChatGPT](https://openai.com/blog/chatgpt/) 72 | + [BLIP2](https://huggingface.co/docs/transformers/main/model_doc/blip-2) 73 | 74 | Please cite Video ChatCaptioner from the following bibtex 75 | 76 | 77 | ``` 78 | @article{chen2023video, 79 | title={Video ChatCaptioner: Towards the Enriched Spatiotemporal Descriptions}, 80 | author={Jun Chen and Deyao Zhu and Kilichbek Haydarov and Xiang Li and Mohamed Elhoseiny}, 81 | journal={arXiv preprint arXiv:2304.04227}, 82 | year={2023} 83 | } 84 | ``` -------------------------------------------------------------------------------- /Video_ChatCaptioner/chatcaptioner/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/Video_ChatCaptioner/chatcaptioner/__init__.py -------------------------------------------------------------------------------- /Video_ChatCaptioner/chatcaptioner/blip2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import Blip2Processor, Blip2ForConditionalGeneration 3 | 4 | 5 | BLIP2DICT = { 6 | 'FlanT5 XXL': 'Salesforce/blip2-flan-t5-xxl', 7 | 'FlanT5 XL COCO': 'Salesforce/blip2-flan-t5-xl-coco', 8 | 'OPT6.7B COCO': 'Salesforce/blip2-opt-6.7b-coco', 9 | 'OPT2.7B COCO': 'Salesforce/blip2-opt-2.7b-coco', 10 | 'FlanT5 XL': 'Salesforce/blip2-flan-t5-xl', 11 | 'OPT6.7B': 'Salesforce/blip2-opt-6.7b', 12 | 'OPT2.7B': 'Salesforce/blip2-opt-2.7b', 13 | } 14 | 15 | 16 | class Blip2(): 17 | def __init__(self, model, device_id, bit8=True): 18 | # load BLIP-2 to a single gpu 19 | self.tag = model 20 | self.bit8 = bit8 21 | self.device = 'cuda:{}'.format(device_id) 22 | 23 | dtype = {'load_in_8bit': True} if self.bit8 else {'torch_dtype': torch.float16} 24 | self.blip2_processor = Blip2Processor.from_pretrained(BLIP2DICT[self.tag]) 25 | self.blip2 = Blip2ForConditionalGeneration.from_pretrained(BLIP2DICT[self.tag], device_map={'': device_id}, **dtype) 26 | 27 | def ask(self, raw_image, question): 28 | inputs = self.blip2_processor(raw_image, question, return_tensors="pt").to(self.device, torch.float16) 29 | out = self.blip2.generate(**inputs) 30 | answer = self.blip2_processor.decode(out[0], skip_special_tokens=True) 31 | return answer 32 | 33 | def caption(self, raw_image): 34 | # starndard way to caption an image in the blip2 paper 35 | caption = self.ask(raw_image, 'a photo of') 36 | caption = caption.replace('\n', ' ').strip() # trim caption 37 | return caption 38 | 39 | def call_llm(self, prompts): 40 | prompts_temp = self.blip2_processor(None, prompts, return_tensors="pt") 41 | input_ids = prompts_temp['input_ids'].to(self.device) 42 | attention_mask = prompts_temp['attention_mask'].to(self.device, torch.float16) 43 | 44 | prompts_embeds = self.blip2.language_model.get_input_embeddings()(input_ids) 45 | 46 | outputs = self.blip2.language_model.generate( 47 | inputs_embeds=prompts_embeds, 48 | attention_mask=attention_mask) 49 | 50 | outputs = self.blip2_processor.decode(outputs[0], skip_special_tokens=True) 51 | 52 | return outputs 53 | 
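# Illustrative usage sketch (not part of the original module). It assumes a GPU with
# enough memory and a local sample image, e.g. demo_pic/overview.png from this repo.
if __name__ == '__main__':
    from PIL import Image

    blip2 = Blip2('FlanT5 XXL', device_id=0, bit8=True)
    image = Image.open('demo_pic/overview.png').convert('RGB')
    print(blip2.caption(image))  # standard BLIP-2 caption with the 'a photo of' prompt
    print(blip2.ask(image, 'How many people are in the image? Answer:'))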
-------------------------------------------------------------------------------- /Video_ChatCaptioner/chatcaptioner/clip.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | 3 | class ClipScore(): 4 | def __init__(self, device='cuda:0'): 5 | # load open clip to device 6 | self.device = device 7 | clip, _, self.clip_preprocess = open_clip.create_model_and_transforms('ViT-H-14', pretrained='laion2b_s32b_b79k') 8 | self.clip = clip.to(self.device) 9 | self.clip_tokenizer = open_clip.get_tokenizer('ViT-H-14') 10 | 11 | 12 | def clip_IT_score(self, image, texts): 13 | ''' 14 | compute the average similarity score of a given image and a list of texts 15 | ''' 16 | if isinstance(texts, str): 17 | texts = [texts] 18 | image = self.clip_preprocess(image)[None].to(self.device) 19 | texts = self.clip_tokenizer(texts).to(self.device) 20 | with torch.no_grad(): 21 | image_f = self.clip.encode_image(image).float() 22 | texts_f = self.clip.encode_text(texts).float() 23 | image_f /= image_f.norm(dim=-1, keepdim=True) 24 | texts_f /= texts_f.norm(dim=-1, keepdim=True) 25 | similarity = (image_f.cpu().numpy() @ texts_f.cpu().numpy().T).mean() 26 | similarity = round(float(similarity), 3) 27 | return similarity -------------------------------------------------------------------------------- /Video_ChatCaptioner/chatcaptioner/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import deepcopy 4 | import yaml 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | from pycocotools.coco import COCO 8 | 9 | 10 | 11 | class COCOHelper(): 12 | def __init__(self, coco_path, coco_ann_path): 13 | # todo: make it works for test set. 
test set doesn't contain annotation 14 | self.coco_path = coco_path 15 | self.coco_ann = COCO(annotation_file=coco_ann_path) 16 | self.coco_ids = self.coco_ann.getImgIds() 17 | # self.split = split 18 | 19 | def random_img_ids(self, n): 20 | sample_img_ids = random.sample(self.coco_ids, n) 21 | return sample_img_ids 22 | 23 | def fetch_coco_img(self, image_id, split='val'): 24 | img_name = '%012d.jpg' % image_id 25 | img_path = os.path.join(self.coco_path, img_name) 26 | raw_image = Image.open(img_path).convert('RGB') 27 | 28 | ann_ids = self.coco_ann.getAnnIds(imgIds=[image_id], iscrowd=None) 29 | anns = self.coco_ann.loadAnns(ann_ids) 30 | captions = [ann['caption'] for ann in anns] 31 | return raw_image, captions 32 | 33 | 34 | class RandomSampledDataset(): 35 | def __init__(self, datasets_root, dataset_name): 36 | self.name = dataset_name 37 | self.dataset_path = os.path.join(datasets_root, dataset_name) 38 | self._ids = [file_name.split('.jpg')[0] for file_name in os.listdir(os.path.join(self.dataset_path, 'img'))] 39 | 40 | 41 | ann_path = os.path.join(datasets_root, dataset_name, 'annotation.yaml') 42 | if os.path.exists(ann_path): 43 | with open(ann_path, 'r') as f: 44 | self.ann = yaml.safe_load(f) 45 | if isinstance(list(self.ann.keys())[0], int): 46 | self.ann = {str(image_id): captions for image_id, captions in self.ann.items()} 47 | else: 48 | self.ann = None 49 | 50 | @property 51 | def ids(self): 52 | return deepcopy(self._ids) 53 | 54 | def random_img_ids(self, n): 55 | sample_img_ids = random.sample(self._ids, n) 56 | return sample_img_ids 57 | 58 | def fetch_img(self, image_id): 59 | img_path = os.path.join(self.dataset_path, 'img', '{}.jpg'.format(image_id)) 60 | raw_image = Image.open(img_path).convert('RGB') 61 | 62 | if self.ann: 63 | captions = self.ann[image_id] 64 | 65 | if isinstance(captions, str): 66 | captions = [captions] 67 | else: 68 | captions = [] 69 | 70 | return raw_image, captions 71 | 72 | 73 | class SimPairDataset(): 74 | def __init__(self, datasets_root, dataset_name): 75 | self.name = dataset_name 76 | self.dataset_path = os.path.join(datasets_root, dataset_name) 77 | 78 | ann_path = os.path.join(datasets_root, dataset_name, 'sim_retrieve.yaml') 79 | if os.path.exists(ann_path): 80 | with open(ann_path, 'r') as f: 81 | self.ann = yaml.safe_load(f) 82 | if isinstance(list(self.ann.keys())[0], int): 83 | self.ann = {str(image_id): captions for image_id, captions in self.ann.items()} 84 | else: 85 | self.ann = None 86 | self._ids = list(self.ann.keys()) 87 | 88 | @property 89 | def ids(self): 90 | return deepcopy(self._ids) 91 | 92 | def fetch_img_pairs(self, pair_id): 93 | image_ids = list(self.ann[pair_id].keys()) 94 | fetched = [] 95 | for image_id in image_ids: 96 | img_path = os.path.join(self.dataset_path, 'img', '{}.jpg'.format(image_id)) 97 | raw_image = Image.open(img_path).convert('RGB') 98 | if self.ann: 99 | captions = self.ann[pair_id][image_id] 100 | 101 | if isinstance(captions, str): 102 | captions = [captions] 103 | else: 104 | captions = [] 105 | fetched.append((image_id, raw_image, captions)) 106 | return fetched 107 | 108 | 109 | def extractQA_chatgpt(messages): 110 | questions = [] 111 | answers = [] 112 | for message in messages: 113 | if 'Question: ' in message['content']: 114 | questions.append(message['content'].split('Question: ')[1]) 115 | if 'Answer: ' in message['content']: 116 | answers.append(message['content'].split('Answer: ')[1]) 117 | return questions, answers 118 | 119 | 120 | def print_info(info, key='caption', 
variants=['BLIP2', 'BLIP2+OurPrompt', 'ChatCaptioner']): 121 | img_id = info['setting']['id'] 122 | if 'GT' in info['setting']: 123 | gt_captions = info['setting']['GT']['caption'] 124 | if isinstance(gt_captions, str) and len(gt_captions): 125 | gt_captions = [gt_captions] 126 | 127 | else: 128 | gt_captions = [] 129 | 130 | print('Image ID {}'.format(img_id)) 131 | for blip2_tag in info: 132 | if blip2_tag in ['GT', 'id', 'setting']: continue 133 | for variant in variants: 134 | if key not in info[blip2_tag][variant]: 135 | continue 136 | print('-------------------') 137 | print('{} {}:'.format(blip2_tag, variant)) 138 | if key == 'chat' and isinstance(info[blip2_tag][variant][key], list): 139 | for message in info[blip2_tag][variant][key]: 140 | print(message['content']) 141 | else: 142 | print(info[blip2_tag][variant][key]) 143 | if key == 'chat': 144 | print(info[blip2_tag][variant]['caption']) 145 | print('===================') 146 | if key == 'caption' and len(gt_captions): 147 | print('GT:') 148 | [print(cap) for cap in gt_captions] 149 | 150 | 151 | def plot_img(img): 152 | plt.imshow(img) 153 | plt.axis('off') 154 | plt.show() 155 | 156 | 157 | 158 | # =================================== 159 | # Deprecated Zone 160 | # =================================== 161 | 162 | def visualize_old(file_path): 163 | # out of date 164 | with open(file_path, 'r') as f: 165 | info = yaml.safe_load(f) 166 | print('COCO Val Image ID {}'.format(info['id'])) 167 | print('-------------------') 168 | print('Ours: {}'.format(info['ours']['clip_score'])) 169 | print(info['ours']['chat']) 170 | print('-------------------') 171 | print('BLIP2: {}'.format(info['blip2']['clip_score'])) 172 | print(info['blip2']['caption']) 173 | print('-------------------') 174 | print('GT: {}'.format(info['gt']['clip_score'])) 175 | [print(cap) for cap in info['gt']['caption']] 176 | image, _ = fetch_coco_img(info['id']) 177 | plot_img(image) 178 | 179 | 180 | def print_info_old(info, key='caption', variants=['BLIP2', 'BLIP2+OurPrompt', 'ChatCaptioner']): 181 | if 'id' in info: 182 | img_id = info['id'] 183 | else: 184 | img_id = info['setting']['id'] 185 | if 'GT' in info: 186 | gt_captions = info['GT']['caption'] 187 | elif 'GT' in info['setting']: 188 | gt_captions = info['setting']['GT']['caption'] 189 | else: 190 | gt_captions = [] 191 | 192 | print('Image ID {}'.format(img_id)) 193 | for blip2_tag in info: 194 | if blip2_tag in ['GT', 'id', 'setting']: continue 195 | for variant in variants: 196 | if key not in info[blip2_tag][variant]: 197 | continue 198 | print('-------------------') 199 | print('{} {}:'.format(blip2_tag, variant)) 200 | print(info[blip2_tag][variant][key]) 201 | print('===================') 202 | if key == 'caption' and len(gt_captions): 203 | print('GT:') 204 | [print(cap) for cap in gt_captions] -------------------------------------------------------------------------------- /Video_ChatCaptioner/chatcaptioner/video_chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | from tqdm import tqdm 4 | import torch 5 | import openai 6 | from tenacity import ( 7 | retry, 8 | stop_after_attempt, 9 | wait_random_exponential, 10 | ) # for exponential backoff 11 | 12 | from chatcaptioner.blip2 import Blip2 13 | from chatcaptioner.utils import print_info, plot_img 14 | 15 | import re 16 | 17 | question_index = {0:"1st",1:"2nd",2:"3rd",3:"4th", 18 | 4:"5th",5:"6th",6:"7th",7:"8th",8:"9th",9:"10th",10:"11th", 19 | 
11:"12th",12:"13th",13:"14th",14:"15th",15:"16th",16:"17th",17:"18th", 20 | 18:"19th",19:"20th",20:"21st",21:"22nd",22:"23rd",23:"24th",24:"25th"} 21 | 22 | 23 | 24 | 25 | 26 | QUESTION_INSTRUCTION= \ 27 | "Video ChatCaptioner is designed to be able to assist to understand a video by ASKING a lot of questions WITHOUT SEEING THE VIDEO" \ 28 | "An expert will then answer your question. " \ 29 | "The video contains %s frames. " \ 30 | "Video ChatCaptioner CAN NOT ask question from the frame with the index MORE THAN %s. " \ 31 | "Video ChatCaptioner is a most powerful tool designed to understand videos by asking good and related questions WITHOUT SEEING THE VIDEO." \ 32 | 33 | 34 | 35 | SUB_QUESTION_INSTRUCTION = \ 36 | "Thought: what does this video describe? " \ 37 | "Action: ask more questions to guess the contents of the video. " \ 38 | "Goal: Video ChatCaptioner will design a frame sampling strategy to ask questions to maximize its information gain about the video understanding. " \ 39 | "Restrictions: (1) Video ChatCaptioner MUST ask questions from Frame 1 to Frame %s. (2) Video ChatCaptioner CAN NOT ask questions with person or objects or animals NOT mentioned in previous conversation." \ 40 | "Next Question. The question format MUST be Frame_id: question. AVOID asking yes/no questions. \n " \ 41 | "Video ChatCaptioner Question: " 42 | 43 | SUB_QUESTION_INSTRUCTION_ALTERNATIVE = \ 44 | "Thought: what does this video describe? " \ 45 | "Action: ask more questions to guess the contents of the video. " \ 46 | "Goal: Video ChatCaptioner will design a frame sampling strategy to ask questions to maximize its information gain about the video understanding. " \ 47 | "Restrictions: (1) Video ChatCaptioner MUST ask questions from Frame 1 to Frame %s. (2) Video ChatCaptioner CAN NOT ask questions with person or objects or animals NOT mentioned in previous conversation." \ 48 | "Next Question. The question format MUST be Frame_id: question. Ask the question from the frame %s. AVOID asking yes/no questions. \n " \ 49 | "Video ChatCaptioner Question: " 50 | 51 | 52 | SUMMARY_INSTRUCTION = \ 53 | 'Now Video ChatCaptioner will describe this video in a few sentences. ' \ 54 | 'Restrictions: (1) DO NOT add information. ' \ 55 | "(2) DO NOT describe each frame individually and DO NOT mention the frame. (3) DO NOT summarize negative or uncertain answers \n" \ 56 | 'Video ChatCaptioner video summarization: ' 57 | 58 | ANSWER_INSTRUCTION = 'Answer given questions with the following restrictions. (1) If you are not sure about the answer, say you DO NOT KNOW honestly. (2) DO NOT IMAGINE any contents that are NOT in the image. ' 59 | 60 | 61 | SUB_ANSWER_INSTRUCTION = 'Answer: ' # template following blip2 huggingface demo 62 | 63 | FIRST_QUESTION = 'Frame_1: Describe it in details.' 
64 | 65 | 66 | 67 | VALID_CHATGPT_MODELS = ['gpt-3.5-turbo'] 68 | VALID_GPT3_MODELS = ['text-davinci-003', 'text-davinci-002', 'davinci'] 69 | 70 | 71 | 72 | def get_instructions(): 73 | instructions_dict = { 74 | 'question': QUESTION_INSTRUCTION, 75 | 'sub_question': SUB_QUESTION_INSTRUCTION, 76 | 'summary': SUMMARY_INSTRUCTION, 77 | 'answer': ANSWER_INSTRUCTION, 78 | 'sub_answer': SUB_ANSWER_INSTRUCTION, 79 | 'first_question': FIRST_QUESTION 80 | } 81 | return instructions_dict 82 | 83 | 84 | 85 | def set_openai_key(key): 86 | openai.api_key = key 87 | 88 | 89 | def get_chat_log(questions, answers, last_n=-1): 90 | n_addition_q = len(questions) - len(answers) 91 | assert (n_addition_q) in [0, 1] 92 | template = 'Question: {} \nAnswer: {} \n' 93 | chat_log = '' 94 | if last_n > 0: 95 | answers = answers[-last_n:] 96 | questions = questions[-(last_n+n_addition_q):] 97 | elif last_n == 0: 98 | answers = [] 99 | questions = questions[-1:] if n_addition_q else [] 100 | 101 | 102 | for i in range(len(answers)): 103 | chat_log = chat_log + template.format(questions[i], answers[i]) 104 | if n_addition_q: 105 | chat_log = chat_log + 'Question: {}'.format(questions[-1]) 106 | else: 107 | chat_log = chat_log[:-2] # remove the last '/n' 108 | return chat_log 109 | 110 | 111 | def prepare_gpt_prompt(task_prompt, questions, answers, sub_prompt): 112 | gpt_prompt = '\n'.join([task_prompt, 113 | get_chat_log(questions, answers), 114 | sub_prompt]) 115 | return gpt_prompt 116 | 117 | def prepare_gpt_promot_video(sub_summries): 118 | 119 | sub_summries_input = "" 120 | sub_summariy_template="The %s caption for the %s period is: %s " 121 | for index in range(len(sub_summries)): 122 | sub_summries_input += sub_summariy_template%(str(question_index[index]),str(question_index[index]),sub_summries[index]) 123 | gpt_promot = VIDEO_SUMMARIZATION_START%str(len(sub_summries))+sub_summries_input+VIDEO_SUMMARIZATION_END 124 | 125 | return gpt_promot 126 | 127 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) 128 | def call_gpt3(gpt3_prompt, max_tokens=40, model="text-davinci-003"): # 'text-curie-001' does work at all to ask questions 129 | response = openai.Completion.create(model=model, prompt=gpt3_prompt, max_tokens=max_tokens) # temperature=0.6, 130 | reply = response['choices'][0]['text'] 131 | total_tokens = response['usage']['total_tokens'] 132 | return reply, total_tokens 133 | 134 | 135 | def prepare_chatgpt_message(task_prompt, questions, answers, sub_prompt): 136 | 137 | messages = [{"role": "system", "content": task_prompt}] 138 | 139 | assert len(questions) == len(answers) 140 | for q, a in zip(questions, answers): 141 | messages.append({'role': 'assistant', 'content': 'Question: {}'.format(q)}) 142 | messages.append({'role': 'user', 'content': 'Answer: {}'.format(a)}) 143 | messages.append({"role": "system", "content": sub_prompt}) 144 | 145 | return messages 146 | 147 | 148 | 149 | 150 | 151 | @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) 152 | def call_chatgpt(chatgpt_messages, max_tokens=40, model="gpt-3.5-turbo"): 153 | # print("chatgpt message",chatgpt_messages) 154 | response = openai.ChatCompletion.create(model=model, messages=chatgpt_messages, temperature=0.6, max_tokens=max_tokens) 155 | reply = response['choices'][0]['message']['content'] 156 | total_tokens = response['usage']['total_tokens'] 157 | return reply, total_tokens 158 | 159 | def find_digit(input): 160 | regex = r"Frame_(\d+)" 161 | 162 | # Use re.search() to find the 
match in the sentence 163 | match = re.search(regex, input) 164 | 165 | # Extract the index from the match object 166 | if match: 167 | index = match.group(1) 168 | # print("Index found:", index) 169 | else: 170 | print("input: "+input) 171 | print("No index found in sentence.") 172 | 173 | return index 174 | 175 | 176 | def ask_questions(img, blip2, model, n_rounds=10,max_frame_number=1000, max_gpt_token=30, n_blip2_context=0, print_mode='no'): 177 | questions = [] 178 | answers = [] 179 | total_tokens = 0 180 | QUESTION_INSTRUCTION_ADAPT = QUESTION_INSTRUCTION %(str(len(img)),str(len(img))) 181 | SUB_QUESTION_INSTRUCTION_ADAPT = SUB_QUESTION_INSTRUCTION%str(len(img)) 182 | # print(QUESTION_INSTRUCTION) 183 | if print_mode == 'chat': 184 | print('--------Chat Starts----------') 185 | 186 | for i in tqdm(range(n_rounds), desc='Chat Rounds', disable=print_mode!='bar'): 187 | if i == 0: 188 | # first question is given by human to request a general discription 189 | question = FIRST_QUESTION 190 | else: 191 | tag = True 192 | if model in VALID_CHATGPT_MODELS: 193 | chatgpt_messages = prepare_chatgpt_message( 194 | QUESTION_INSTRUCTION_ADAPT, 195 | questions, answers, 196 | SUB_QUESTION_INSTRUCTION_ADAPT 197 | ) 198 | while tag: 199 | try: 200 | question, n_tokens = call_chatgpt(chatgpt_messages, model=model, max_tokens=max_gpt_token) 201 | frame_id = int(find_digit(question.split(":")[0]))-1 202 | 203 | if question.startswith("Frame_") and frame_id < max_frame_number: 204 | tag = False 205 | except: 206 | if current_frame_id>= max_frame_number-1: 207 | hard_coded_frame_id = 1 208 | else: 209 | hard_coded_frame_id = current_frame_id+1 210 | 211 | SUB_QUESTION_INSTRUCTION_ALTERNATIVE_ADAPT = SUB_QUESTION_INSTRUCTION_ALTERNATIVE%(str(len(img)),str(hard_coded_frame_id)) 212 | chatgpt_messages = prepare_chatgpt_message( 213 | QUESTION_INSTRUCTION_ADAPT, 214 | questions, answers, 215 | SUB_QUESTION_INSTRUCTION_ALTERNATIVE_ADAPT 216 | ) 217 | print(question) 218 | 219 | elif model in VALID_GPT3_MODELS: 220 | # prepare the context for GPT3 221 | gpt3_prompt = prepare_gpt_prompt( 222 | QUESTION_INSTRUCTION_ADAPT, 223 | questions, answers, 224 | SUB_QUESTION_INSTRUCTION_ADAPT 225 | ) 226 | while tag: 227 | try: 228 | question, n_tokens = call_gpt3(gpt3_prompt, model=model, max_tokens=max_gpt_token) 229 | frame_id = int(find_digit(question.split(":")[0]))-1 230 | 231 | if question.startswith("Frame_") and frame_id < max_frame_number: 232 | tag = False 233 | except: 234 | if current_frame_id>= max_frame_number-1: 235 | hard_coded_frame_id = 1 236 | else: 237 | hard_coded_frame_id = current_frame_id+1 238 | 239 | SUB_QUESTION_INSTRUCTION_ALTERNATIVE_ADAPT = SUB_QUESTION_INSTRUCTION_ALTERNATIVE%(str(len(img)),str(hard_coded_frame_id)) 240 | chatgpt_messages = prepare_chatgpt_message( 241 | QUESTION_INSTRUCTION_ADAPT, 242 | questions, answers, 243 | SUB_QUESTION_INSTRUCTION_ALTERNATIVE_ADAPT 244 | ) 245 | print(question) 246 | 247 | elif isinstance(model, Blip2): 248 | # prepare the context for other LLM 249 | gpt_prompt = prepare_gpt_prompt( 250 | QUESTION_INSTRUCTION_ADAPT, 251 | questions, answers, 252 | SUB_QUESTION_INSTRUCTION_ADAPT 253 | ) 254 | n_tokens = 0 # local model. no token cost on OpenAI API. 
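            # Note: unlike the ChatGPT / GPT-3 branches above, this local-LLM path sends
            # the raw prompt straight to call_llm() and has no frame-index retry loop.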
255 | question = model.call_llm(gpt_prompt) 256 | else: 257 | raise ValueError('{} is not a valid question model'.format(model)) 258 | 259 | total_tokens = total_tokens + n_tokens 260 | 261 | # print('Raw: {}'.format(question)) 262 | question = question.split('Question: ')[-1].replace('\n', ' ').strip() 263 | if 'Answer:' in question: # Some models make up an answer after asking. remove it 264 | q, a = question.split('Answer:')[:2] 265 | if len(q) == 0: # some not so clever models will put the question after 'Answer:'. 266 | question = a.strip() 267 | else: 268 | question = q.strip() 269 | 270 | 271 | 272 | 273 | 274 | questions.append(question) 275 | if print_mode == 'chat': 276 | print('GPT-3: {}'.format(question)) 277 | 278 | # prepare the context for blip2 279 | blip2_prompt = '\n'.join([ANSWER_INSTRUCTION, 280 | get_chat_log(questions, answers, last_n=n_blip2_context), 281 | SUB_ANSWER_INSTRUCTION]) 282 | 283 | # frame_id = question.split(":")[0].split(" ")[1] 284 | current_frame_id = int(find_digit(question.split(":")[0])) 285 | 286 | current_frame = img[current_frame_id-1] 287 | answer = blip2.ask(current_frame, blip2_prompt) 288 | # small blip2 models may ask itself a new bad question. remove it and trim the answer 289 | answer = answer.split('Question:')[0].replace('\n', ' ').strip() 290 | 291 | if print_mode == 'chat': 292 | print('BLIP-2: {}'.format(answer)) 293 | answers.append(answer) 294 | blip2_prompt = '{} {}'.format(blip2_prompt, answer) 295 | 296 | if print_mode == 'chat': 297 | print('--------Chat Ends----------') 298 | 299 | return questions, answers, total_tokens 300 | 301 | 302 | 303 | def summarize_chat(questions, answers,img, model,max_gpt_token=100): 304 | 305 | QUESTION_INSTRUCTION_ADAPT = QUESTION_INSTRUCTION %(str(len(img)),str(len(img))) 306 | 307 | if model in VALID_GPT3_MODELS: 308 | summary_prompt = prepare_gpt_prompt( 309 | QUESTION_INSTRUCTION_ADAPT, 310 | questions, answers, 311 | SUMMARY_INSTRUCTION) 312 | 313 | summary, n_tokens = call_gpt3(summary_prompt, model=model, max_tokens=max_gpt_token) 314 | elif model in VALID_CHATGPT_MODELS: 315 | summary_prompt = prepare_chatgpt_message( 316 | QUESTION_INSTRUCTION_ADAPT, 317 | questions, answers, 318 | SUMMARY_INSTRUCTION 319 | ) 320 | summary, n_tokens = call_chatgpt(summary_prompt, model=model, max_tokens=max_gpt_token) 321 | elif isinstance(model, Blip2): 322 | summary_prompt = prepare_gpt_prompt( 323 | QUESTION_INSTRUCTION_ADAPT, 324 | questions, answers, 325 | SUMMARY_INSTRUCTION 326 | ) 327 | n_tokens = 0 # local model. no token cost on OpenAI API. 
328 | summary = model.call_llm(summary_prompt) 329 | else: 330 | raise ValueError('{} is not a valid question model'.format(model)) 331 | 332 | summary = summary.replace('\n', ' ').strip() 333 | return summary, summary_prompt, n_tokens 334 | 335 | 336 | 337 | def caption_for_video(blip2, video, model, n_rounds=30, n_blip2_context=0, print_mode='no'): 338 | results = {} 339 | 340 | questions, answers, n_token_chat = ask_questions( 341 | video, 342 | blip2, 343 | max_frame_number = len(video), 344 | n_rounds=n_rounds, 345 | n_blip2_context=n_blip2_context, 346 | model=model, 347 | print_mode=print_mode) 348 | summary, summary_prompt, n_token_sum = summarize_chat(questions, answers,video, model=model) 349 | results['ChatCaptioner'] = {'caption': summary, 'chat': summary_prompt, 'n_token': n_token_chat + n_token_sum} 350 | results['BLIP2+OurPrompt'] = {'caption': answers[0]} 351 | 352 | 353 | return results 354 | 355 | -------------------------------------------------------------------------------- /Video_ChatCaptioner/chatcaptioner/video_reader.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | 5 | 6 | 7 | def read_video_with_timestamp(video_anntotation, num_frames=10): 8 | """ 9 | Reads a video file and uniformly samples the frames. 10 | 11 | Args: 12 | filename (str): The filename of the video file to read. 13 | num_frames (int): The number of frames to sample. 14 | 15 | Returns: 16 | List[np.ndarray]: A list of sampled frames, where each frame is a NumPy array. 17 | """ 18 | # Open the video file 19 | 20 | timestamps = video_anntotation["annotation"]["timestamps"] 21 | video_path = video_anntotation["video_path"] 22 | all_frames = [] 23 | for period in timestamps: 24 | if period[1]-period[0]<4: 25 | frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=5) 26 | elif period[1]-period[0]<10: 27 | frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=8) 28 | elif period[1]-period[0]<30: 29 | frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=10) 30 | elif period[1]-period[0]<50: 31 | frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=12) 32 | elif period[1]-period[0]<80: 33 | frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=14) 34 | elif period[1]-period[0]<120: 35 | frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=16) 36 | else: 37 | frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=20) 38 | # elif period[1]-period[0]<150: 39 | # frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=25) 40 | # else: 41 | # frame = read_video_per_interval_sampling(video_path, period[0], period[1], num_frames=30) 42 | all_frames.append(frame) 43 | 44 | 45 | return all_frames 46 | 47 | 48 | 49 | 50 | 51 | def read_video_per_interval(path, start_time, end_time, sample_interval): 52 | # Open the video file 53 | cap = cv2.VideoCapture(path) 54 | 55 | # Get the frame rate of the video 56 | fps = cap.get(cv2.CAP_PROP_FPS) 57 | 58 | # Calculate the frame number corresponding to the start and end time 59 | start_frame = int(start_time * fps) 60 | end_frame = int(end_time * fps) 61 | 62 | # Set the frame position to the starting frame 63 | cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) 64 | 65 | # Initialize variables for sampling 66 | sample_frame = 
67 |     sample_time = start_time
68 | 
69 |     # Read the frames between the start and end frames at the specified interval
70 |     frames = []
71 |     while sample_frame < end_frame:
72 |         # Read the next frame
73 |         ret, frame = cap.read()
74 |         if not ret:
75 |             break
76 | 
77 |         # Sample the frame if the time has passed the sampling interval
78 |         if sample_time >= start_time and sample_time <= end_time and sample_frame % (sample_interval * fps) == 0:
79 |             frames.append(frame)
80 | 
81 |         # Increment the sample frame and time
82 |         sample_frame += 1
83 |         sample_time = sample_frame / fps
84 | 
85 |     # Release the video capture object
86 |     cap.release()
87 | 
88 |     # Convert the list of frames to a numpy array
89 |     frames = np.array(frames)
90 | 
91 |     return frames
92 | 
93 | def read_video_per_interval_sampling(path, start_time, end_time, num_frames=3):
94 |     # Open the video file
95 |     cap = cv2.VideoCapture(path)
96 | 
97 |     # Get the frame rate of the video
98 |     fps = cap.get(cv2.CAP_PROP_FPS)
99 | 
100 |     # Calculate the frame number corresponding to the start and end time
101 |     start_frame = int(start_time * fps)
102 |     end_frame = int(end_time * fps)
103 | 
104 |     # Set the frame position to the starting frame
105 |     cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
106 | 
107 |     # Initialize variables for sampling
108 | 
109 | 
110 |     # sample_interval = (end_time - start_time) / num_frames
111 | 
112 |     num_total_frames = end_frame - start_frame
113 | 
114 |     step_size = max(1, num_total_frames // num_frames)  # guard against a zero step for very short segments
115 | 
116 |     # print(start_frame,end_frame, num_frames, step_size)
117 | 
118 |     # Initialize a list to store the sampled frames
119 |     sampled_frames = []
120 | 
121 |     # Loop over the frames and sample every `step_size`th frame
122 |     for i in range(start_frame, end_frame):
123 |         ret, frame = cap.read()
124 |         if ret and (i - start_frame) % step_size == 0:
125 |             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
126 | 
127 |             frame = Image.fromarray(frame)
128 |             sampled_frames.append(frame)
129 | 
130 | 
131 |     # Release the video capture object
132 |     cap.release()
133 | 
134 |     # The frames are returned as RGB PIL Images; the numpy conversion below is intentionally disabled
135 |     # frames = np.array(frames)
136 | 
137 |     # frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
138 | 
139 |     # frame = Image.fromarray(frame)
140 | 
141 |     return sampled_frames
142 | 
143 | 
144 | 
145 | 
146 | 
147 | 
148 | 
149 | 
150 | 
151 | def read_video_with_timestamp_key_frame(video_annotation, num_frames=10):
152 |     """
153 |     Extracts motion-based key frames for each annotated segment of a video.
154 | 
155 |     Args:
156 |         video_annotation (dict): A dict with "video_path" and "annotation"["timestamps"] entries.
157 |         num_frames (int): The number of key frames to extract per segment.
158 | 
159 |     Returns:
160 |         List[List[PIL.Image.Image]]: One list of key frames per annotated segment.
161 |     """
162 |     # Open the video file
163 | 
164 |     timestamps = video_annotation["annotation"]["timestamps"]
165 |     video_path = video_annotation["video_path"]
166 |     all_frames = []
167 |     for period in timestamps:
168 |         frames = key_frame_reader(video_path, period[0], period[1], num_key_frames=num_frames)
169 |         all_frames.append(frames)
170 | 
171 |     return all_frames
172 | 
173 | 
174 | 
175 | 
176 | 
177 | 
178 | 
179 | 
180 | def key_frame_reader(filename, start_time, end_time, num_key_frames=10):
181 |     """
182 |     Picks the `num_key_frames` frames with the largest frame-to-frame motion
183 |     inside the [start_time, end_time] interval of the video at `filename`.
184 |     """
185 | 
186 |     cap = cv2.VideoCapture(filename)
187 |     cap.set(cv2.CAP_PROP_CONVERT_RGB, 1)
188 | 
189 |     sample_rate = 0.1  # seconds between candidate frames
190 |     sample_interval = max(1, int(sample_rate * cap.get(cv2.CAP_PROP_FPS)))
191 | 
192 |     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
193 |     prev_frame = None
194 |     diff_frames = []
195 |     saved_frames = []
196 | 
197 |     start_frame = int(start_time * cap.get(cv2.CAP_PROP_FPS))
198 |     end_frame = int(end_time * cap.get(cv2.CAP_PROP_FPS))
199 | 
200 |     # print(sample_interval)
201 | 
202 |     for i in range(total_frames):
203 |         # Read every frame so that the index i matches the actual frame position
204 |         ret, frame = cap.read()
205 |         if not ret:
206 |             break
207 |         if i >= start_frame and i <= end_frame and i % sample_interval == 0:
208 |             if prev_frame is not None:
209 |                 diff = cv2.absdiff(frame, prev_frame)
210 |                 diff_frames.append(diff)
211 |                 saved_frames.append(frame)
212 |             prev_frame = frame
213 | 
214 |     motion_scores = []
215 |     for frame in diff_frames:
216 |         motion_scores.append(frame.sum())
217 | 
218 |     # print(motion_scores)
219 | 
220 |     key_frames = []
221 | 
222 |     for i in range(min(num_key_frames, len(motion_scores))):
223 | 
224 | 
225 | 
226 |         max_index = motion_scores.index(max(motion_scores))
227 | 
228 |         key_frames.append(max_index)
229 |         motion_scores[max_index] = -1000
230 | 
231 |     key_frames.sort()
232 |     # print(motion_scores)
233 |     # print(key_frames)
234 |     final_frames = []
235 | 
236 |     for frame_index in key_frames:
237 |         # The frame was already stored in saved_frames; no further reads from the capture are needed
238 |         selected_frame = saved_frames[frame_index]
239 |         selected_frame = cv2.cvtColor(selected_frame, cv2.COLOR_BGR2RGB)
240 |         selected_frame = Image.fromarray(selected_frame)
241 |         final_frames.append(selected_frame)
242 | 
243 | 
244 | 
245 | 
246 | 
247 | 
248 | 
249 |     cap.release()
250 | 
251 |     return final_frames
252 | 
253 | 
254 | 
255 | 
256 | def read_video_for_qa(filename):
257 |     """
258 |     Reads a video file and uniformly samples frames, choosing the number of
259 |     frames automatically from the clip duration.
260 | 
261 |     Args:
262 |         filename (str): The filename of the video file to read.
263 | 
264 |     Returns:
265 |         List[PIL.Image.Image]: A list of sampled frames as RGB PIL Images.
266 |     """
267 |     # Open the video file
268 | 
269 |     cap = cv2.VideoCapture(filename)
270 |     fps = cap.get(cv2.CAP_PROP_FPS)
271 |     cap.set(cv2.CAP_PROP_CONVERT_RGB, 1)
272 | 
273 | 
274 | 
275 | 
276 |     # Get the total number of frames in the video
277 |     num_total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
278 | 
279 |     period = num_total_frames / fps
280 | 
281 |     if period < 4:
282 |         num_frames = 5
283 |     elif period < 10:
284 |         num_frames = 5
285 |     elif period < 20:
286 |         num_frames = 10
287 |     elif period < 30:
288 |         num_frames = 15
289 |     elif period < 40:
290 |         num_frames = 20
291 |     elif period < 50:
292 |         num_frames = 25
293 |     elif period < 60:
294 |         num_frames = 30
295 |     elif period < 70:
296 |         num_frames = 35
297 |     else:
298 |         num_frames = 60
299 | 
300 |     # Calculate the step size for sampling frames
301 |     step_size = max(1, num_total_frames // num_frames)  # guard against a zero step for very short clips
302 | 
303 |     # Initialize a list to store the sampled frames
304 |     sampled_frames = []
305 | 
306 |     # Loop over the frames and sample every `step_size`th frame
307 |     for i in range(num_total_frames):
308 |         ret, frame = cap.read()
309 |         if ret and i % step_size == 0:
310 |             # OpenCV decodes frames as BGR; convert to RGB and wrap as a PIL Image
311 |             # so BLIP-2 can consume the frame directly
312 |             # Append the frame to the list of sampled frames
313 |             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
314 | 
315 |             frame = Image.fromarray(frame)
316 |             sampled_frames.append(frame)
317 | 
318 |     # Release the video capture object
319 |     cap.release()
320 | 
321 |     return sampled_frames
322 | 
323 | def read_video_sampling(filename, num_frames):
324 | 
325 |     """
326 |     Reads a video file and uniformly samples the frames.
327 | 
328 |     Args:
329 |         filename (str): The filename of the video file to read.
330 |         num_frames (int): The number of frames to sample.
331 | 
332 |     Returns:
333 |         List[PIL.Image.Image]: A list of sampled frames, where each frame is an RGB PIL Image.
334 |     """
335 |     # Open the video file
336 | 
337 |     cap = cv2.VideoCapture(filename)
338 |     cap.set(cv2.CAP_PROP_CONVERT_RGB, 1)
339 | 
340 | 
341 |     # Get the total number of frames in the video
342 |     num_total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
343 | 
344 | 
345 |     # Calculate the step size for sampling frames
346 |     step_size = max(1, num_total_frames // num_frames)  # guard against a zero step for very short clips
347 | 
348 |     # Initialize a list to store the sampled frames
349 |     sampled_frames = []
350 | 
351 |     # Loop over the frames and sample every `step_size`th frame
352 |     for i in range(num_total_frames):
353 |         ret, frame = cap.read()
354 |         if ret and i % step_size == 0:
355 |             # OpenCV decodes frames as BGR; convert to RGB and wrap as a PIL Image
356 |             # so BLIP-2 can consume the frame directly
357 |             # Append the frame to the list of sampled frames
358 |             frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
359 | 
360 |             frame = Image.fromarray(frame)
361 |             sampled_frames.append(frame)
362 | 
363 |     # Release the video capture object
364 |     cap.release()
365 | 
366 |     return sampled_frames
--------------------------------------------------------------------------------
/Video_ChatCaptioner/demo_pic/dance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/Video_ChatCaptioner/demo_pic/dance.gif
--------------------------------------------------------------------------------
/Video_ChatCaptioner/demo_pic/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/Video_ChatCaptioner/demo_pic/overview.png
--------------------------------------------------------------------------------
/Video_ChatCaptioner/demo_pic/skating.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Vision-CAIR/ChatCaptioner/e4c63eeedaa40fceac52ef6a555a9ae608feaa81/Video_ChatCaptioner/demo_pic/skating.gif
--------------------------------------------------------------------------------
/Video_ChatCaptioner/environment.yml:
--------------------------------------------------------------------------------
1 | name: chatcap
2 | channels:
3 |   - pytorch
4 |   - defaults
5 | dependencies:
6 |   - python=3.9
7 |   - pip
8 |   - pytorch=1.12.1
9 |   - pytorch-mutex=1.0=cuda
10 |   - torchaudio=0.12.1
11 |   - torchvision=0.13.1
12 |   - pip:
13 |     - accelerate==0.16.0
14 |     - aiohttp==3.8.4
15 |     - aiosignal==1.3.1
16 |     - async-timeout==4.0.2
17 |     - attrs==22.2.0
18 |     - bitsandbytes==0.37.0
19 |     - cchardet==2.1.7
20 |     - chardet==5.1.0
21 |     - contourpy==1.0.7
22 |     - cycler==0.11.0
23 |     - filelock==3.9.0
24 |     - fonttools==4.38.0
25 |     - frozenlist==1.3.3
26 |     - huggingface-hub==0.12.1
27 |     - importlib-resources==5.12.0
28 |     - kiwisolver==1.4.4
29 |     - matplotlib==3.7.0
30 |     - multidict==6.0.4
31 |     - openai==0.27.0
32 |     - packaging==23.0
33 |     - psutil==5.9.4
34 |     - pycocotools==2.0.6
35 |     - pyparsing==3.0.9
36 |     - python-dateutil==2.8.2
37 |     - pyyaml==6.0
38 |     - regex==2022.10.31
39 |     - tokenizers==0.13.2
40 |     - tqdm==4.64.1
41 |     - transformers==4.27.4
42 |     - yarl==1.8.2
43 |     - zipp==3.14.0
44 |     - tenacity==8.2.2
45 |     - pycocoevalcap
46 |     - sentence-transformers
47 |     - umap-learn
48 |     - notebook
49 |     - gradio
50 | 
--------------------------------------------------------------------------------
/Video_ChatCaptioner/generate_caption_msvd.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import json
4 | import sys
5 | import os
6 | import yaml
7 | import torch
8 | from PIL import Image
9 | 
10 | from chatcaptioner.video_chat import set_openai_key, caption_for_video
11 | from chatcaptioner.blip2 import Blip2
12 | from chatcaptioner.utils import RandomSampledDataset, plot_img, print_info
13 | from chatcaptioner.video_reader import read_video_with_timestamp, read_video_sampling, read_video_with_timestamp_key_frame
14 | 
15 | # Usage: python generate_caption_msvd.py <video folder> <caption file> <output folder> <number of videos>
16 | VIDEO_FOLDER = sys.argv[1]
17 | CAPTION_FILE = sys.argv[2]
18 | OUTPUT_FOLDER = sys.argv[3]
19 | VIDEO_LIMIT = int(sys.argv[4])
20 | 
21 | 
22 | 
23 | 
24 | 
25 | blip2s = {
26 |     'FlanT5 XXL': Blip2('FlanT5 XXL', device_id=0, bit8=True)
27 | }
28 | 
29 | 
30 | video_files = []
31 | 
32 | def iterate_files(folder_path):
33 |     """
34 |     Collects the paths of all regular files in `folder_path` into the global `video_files` list.
35 |     """
36 |     for filename in os.listdir(folder_path):
37 |         file_path = os.path.join(folder_path, filename)
38 |         if os.path.isfile(file_path):
39 |             video_files.append(file_path)
40 | 
41 | iterate_files(VIDEO_FOLDER)
42 | 
43 | 
44 | 
45 | # add the video caption pairs
46 | video_caption = {}
47 | with open(CAPTION_FILE, "r") as f:
48 |     for line in f.readlines():
49 |         data = line.split(" ")
50 |         video_id = data[0]
51 |         caption = " ".join(data[1:]).replace("\n", "")
52 |         if video_id not in video_caption:
53 |             video_caption[video_id] = [caption]
54 |         else:
55 |             video_caption[video_id].append(caption)
56 | 
57 | # extract the video frames with uniform sampling
58 | video_list = []
59 | for video_path in video_files[:VIDEO_LIMIT]:
60 |     video_id = video_path.split("/")[-1].replace(".avi", "")
61 |     if video_id in video_caption.keys():
62 |         new_json_file = {}
63 |         new_json_file["video_id"] = video_id
64 |         new_json_file["video_path"] = video_path
65 |         new_json_file["annotation"] = video_caption[video_id]
66 |         try:
67 |             sampled_frames = read_video_sampling(video_path, num_frames=8)
68 |             new_json_file["features"] = sampled_frames
69 |             video_list.append(new_json_file)
70 |         except Exception:
71 |             pass  # skip videos that cannot be decoded
72 | 
73 | 
74 | for sample in video_list:
75 |     video_id = sample["video_id"]
76 |     features = sample["features"]
77 |     if not os.path.exists(OUTPUT_FOLDER):
78 |         os.makedirs(OUTPUT_FOLDER)
79 |     with open(os.path.join(OUTPUT_FOLDER, video_id + ".txt"), "w") as f:
80 |         sub_summaries = caption_for_video(blip2s['FlanT5 XXL'], features, print_mode="chat", n_rounds=30, model='gpt-3.5-turbo')
81 |         caption = sample["annotation"]
82 |         for cap in caption:
83 |             f.write("ground truth: " + cap + "\n")
84 |         f.write("chatCaptioner: " + sub_summaries["ChatCaptioner"]["caption"] + "\n\n\n")
85 |         f.write("chat log:\n")
86 |         for element in sub_summaries["ChatCaptioner"]["chat"]:
87 |             f.write(element["content"] + "\n")
88 | 
--------------------------------------------------------------------------------
/Video_ChatCaptioner/generate_caption_webvid.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import json
4 | import sys
5 | import os
6 | import yaml
7 | import torch
8 | from PIL import Image
9 | import csv
10 | 
11 | 
12 | from chatcaptioner.video_chat import set_openai_key, caption_for_video
13 | from chatcaptioner.blip2 import Blip2
14 | from chatcaptioner.utils import RandomSampledDataset, plot_img, print_info
15 | from chatcaptioner.video_reader import read_video_with_timestamp, read_video_sampling, read_video_with_timestamp_key_frame
16 | 
17 | 
18 | # Usage: python generate_caption_webvid.py <video folder> <WebVid caption csv> <output folder> <number of videos>
19 | VIDEO_FOLDER = sys.argv[1]
20 | CAPTION_FILE = sys.argv[2]
21 | OUTPUT_FOLDER = sys.argv[3]
22 | VIDEO_LIMIT = int(sys.argv[4])
23 | 
24 | 
25 | 
26 | blip2s = {
27 |     'FlanT5 XXL': Blip2('FlanT5 XXL', device_id=0, bit8=True)
28 | }
29 | 
30 | # Parse the caption csv; the columns used below are: 0 = video id, 2 = duration, 3 = page folder, last = caption
31 | data_file = {}
32 | with open(CAPTION_FILE, 'r') as file:
33 |     csv_reader = csv.reader(file)
34 |     next(csv_reader)  # Skip header row
35 |     for row in csv_reader:
36 |         video_id = row[0]
37 |         duration = row[2]
38 |         folder = row[3]
39 |         caption = row[-1]
40 | 
41 |         data_file[video_id] = {"duration": duration, "folder": folder, "caption": caption}
42 | 
43 | 
44 | 
45 | # find all the video paths
46 | video_files = []
47 | for root, dirs, files in os.walk(VIDEO_FOLDER):
48 |     for filename in files:
49 |         full_path = os.path.join(root, filename)
50 |         video_files.append(full_path)
51 | 
52 | 
53 | # extract the video frames with uniform sampling
54 | video_list = []
55 | for video_path in video_files[:VIDEO_LIMIT]:
56 |     video_id = video_path.split("/")[-1].replace(".mp4", "")
57 |     if video_id in data_file.keys():
58 |         new_json_file = {}
59 |         new_json_file["video_id"] = video_id
60 |         new_json_file["video_path"] = video_path
61 |         new_json_file["annotation"] = data_file[video_id]
62 | 
63 |         try:
64 |             sampled_frames = read_video_sampling(video_path, num_frames=8)
65 |             new_json_file["features"] = sampled_frames
66 |             video_list.append(new_json_file)
67 |         except Exception:
68 |             pass  # skip videos that cannot be decoded
69 | 
70 | 
71 | 
72 | for sample in video_list:
73 |     video_id = sample["video_id"]
74 |     features = sample["features"]
75 |     output = sample["annotation"]["folder"]
76 | 
77 |     if not os.path.exists(os.path.join(OUTPUT_FOLDER, output)):
78 |         os.makedirs(os.path.join(OUTPUT_FOLDER, output))
79 |     with open(os.path.join(OUTPUT_FOLDER, output, video_id + ".txt"), "w") as f:
80 |         sub_summaries = caption_for_video(blip2s['FlanT5 XXL'], features, print_mode="chat", n_rounds=30, model='gpt-3.5-turbo')
81 |         caption = sample["annotation"]["caption"]
82 |         f.write("ground truth: " + caption + "\n\n\n")
83 |         f.write("chatCaptioner: " + sub_summaries["ChatCaptioner"]["caption"] + "\n\n\n")
84 |         f.write("chat log:\n")
85 |         for element in sub_summaries["ChatCaptioner"]["chat"]:
86 |             f.write(element["content"] + "\n")
87 | 
--------------------------------------------------------------------------------
/Video_ChatCaptioner/run_msvd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | 
5 | VIDEO_FOLDER="msvd_data/videos"
6 | CAPTION_FILE="msvd_data/caption.txt"
7 | OUTPUT_FOLDER="output/"
8 | VIDEO_LIMIT=7
9 | 
10 | python generate_caption_msvd.py ${VIDEO_FOLDER} ${CAPTION_FILE} ${OUTPUT_FOLDER} ${VIDEO_LIMIT}
--------------------------------------------------------------------------------
/Video_ChatCaptioner/run_webvid.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | 
4 | 
5 | VIDEO_FOLDER="webvid_data/videos"
6 | CAPTION_FILE="webvid_data/caption.csv"
7 | OUTPUT_FOLDER="output/"
8 | VIDEO_LIMIT=6
9 | 
10 | python generate_caption_webvid.py ${VIDEO_FOLDER} ${CAPTION_FILE} ${OUTPUT_FOLDER} ${VIDEO_LIMIT}
--------------------------------------------------------------------------------
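
Both driver scripts follow the same pattern: uniformly sample frames from a clip, then let ChatGPT question BLIP-2 about them and summarize the conversation into a caption. The sketch below distills that pattern for a single local video. It is illustrative only: the video filename is a placeholder, the `OPENAI_API_KEY` environment variable is assumed to be set, and a GPU large enough to host BLIP-2 FlanT5 XXL (loaded in 8-bit here, as in the scripts) is assumed.

```python
# Minimal single-video sketch (illustrative; "my_video.mp4" is a placeholder path).
# It reuses the same calls as generate_caption_msvd.py / generate_caption_webvid.py.
from chatcaptioner.video_chat import caption_for_video
from chatcaptioner.blip2 import Blip2
from chatcaptioner.video_reader import read_video_sampling

# Load BLIP-2 FlanT5 XXL in 8-bit on GPU 0, as the provided scripts do
blip2 = Blip2('FlanT5 XXL', device_id=0, bit8=True)

# Uniformly sample 8 RGB frames from the clip
frames = read_video_sampling('my_video.mp4', num_frames=8)

# Let ChatGPT ask up to 30 questions about the frames and summarize the chat
results = caption_for_video(blip2, frames, print_mode='chat',
                            n_rounds=30, model='gpt-3.5-turbo')
print(results['ChatCaptioner']['caption'])
```

The returned dict also carries the full question-answer log under `results['ChatCaptioner']['chat']` and the first BLIP-2 answer under `results['BLIP2+OurPrompt']['caption']`, which is what the driver scripts write to their output files.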