├── Evaluation ├── combine_preds.py ├── infer_utils.py ├── llava │ ├── __init__.py │ ├── constants.py │ ├── conversation.py │ ├── mm_utils.py │ ├── model │ │ ├── __init__.py │ │ ├── apply_delta.py │ │ ├── builder.py │ │ ├── consolidate.py │ │ ├── eval │ │ │ ├── eval_gpt_review.py │ │ │ ├── eval_gpt_review_bench.py │ │ │ ├── eval_gpt_review_visual.py │ │ │ ├── eval_pope.py │ │ │ ├── eval_science_qa.py │ │ │ ├── eval_science_qa_gpt4.py │ │ │ ├── eval_science_qa_gpt4_requery.py │ │ │ ├── eval_textvqa.py │ │ │ ├── generate_webpage_data_from_table.py │ │ │ ├── m4c_evaluator.py │ │ │ ├── model_qa.py │ │ │ ├── model_vqa.py │ │ │ ├── model_vqa_loader.py │ │ │ ├── model_vqa_mmbench.py │ │ │ ├── model_vqa_science.py │ │ │ ├── qa_baseline_gpt35.py │ │ │ ├── run_llava.py │ │ │ ├── summarize_gpt_review.py │ │ │ ├── table │ │ │ │ ├── answer │ │ │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ │ │ ├── answer_bard.jsonl │ │ │ │ │ ├── answer_gpt35.jsonl │ │ │ │ │ ├── answer_llama-13b.jsonl │ │ │ │ │ └── answer_vicuna-13b.jsonl │ │ │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ │ │ ├── model.jsonl │ │ │ │ ├── prompt.jsonl │ │ │ │ ├── question.jsonl │ │ │ │ ├── review │ │ │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ │ │ ├── reviewer.jsonl │ │ │ │ └── rule.json │ │ │ └── webpage │ │ │ │ ├── figures │ │ │ │ ├── alpaca.png │ │ │ │ ├── bard.jpg │ │ │ │ ├── chatgpt.svg │ │ │ │ ├── llama.jpg │ │ │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ │ │ └── vicuna.jpeg │ │ │ │ ├── index.html │ │ │ │ ├── script.js │ │ │ │ └── styles.css │ │ ├── language_model │ │ │ ├── llava_llama.py │ │ │ ├── llava_mpt.py │ │ │ └── mpt │ │ │ │ ├── adapt_tokenizer.py │ │ │ │ ├── attention.py │ │ │ │ ├── blocks.py │ │ │ │ ├── configuration_mpt.py │ │ │ │ ├── custom_embedding.py │ │ │ │ ├── flash_attn_triton.py │ │ │ │ ├── hf_prefixlm_converter.py │ │ │ │ ├── meta_init_context.py │ │ │ │ ├── modeling_mpt.py │ │ │ │ ├── norm.py │ │ │ │ └── param_init_fns.py │ │ ├── llava_arch.py │ │ ├── make_delta.py │ │ ├── multimodal_encoder │ │ │ ├── builder.py │ │ │ └── clip_encoder.py │ │ ├── multimodal_projector │ │ │ └── builder.py │ │ └── utils.py │ ├── serve │ │ ├── __init__.py │ │ ├── cli.py │ │ ├── controller.py │ │ ├── examples │ │ │ ├── extreme_ironing.jpg │ │ │ └── waterview.jpg │ │ ├── gradio_web_server.py │ │ ├── model_worker.py │ │ ├── register_worker.py │ │ └── test_message.py │ ├── train │ │ ├── llama_flash_attn_monkey_patch.py │ │ ├── llava_trainer.py │ │ ├── train.py │ │ └── train_mem.py │ └── utils.py ├── scripts │ ├── videochatgpt_pipeline.sh │ └── zeroshotqa_pipeline.sh ├── videochatgpt │ ├── evaluate_activitynet_qa.py │ ├── evaluate_benchmark.sh │ ├── evaluate_benchmark_1_correctness.py │ ├── evaluate_benchmark_2_detailed_orientation.py │ ├── evaluate_benchmark_3_context.py │ ├── evaluate_benchmark_4_temporal.py │ ├── evaluate_benchmark_5_consistency.py │ ├── infer_consistency.py │ ├── infer_general.py │ └── scripts │ │ ├── gpt_eval.sh │ │ ├── infer_consistency.sh │ │ ├── infer_general.sh │ │ ├── pipeline_consistency.sh │ │ ├── pipeline_context.sh │ │ ├── pipeline_correctness.sh │ │ ├── pipeline_detail.sh │ │ └── pipeline_temporal.sh └── zeroshotqa │ ├── gpt_eval.py │ ├── qa_infer.py │ └── scripts │ ├── zeroshotqa_eval.sh │ ├── zeroshotqa_infer.sh │ └── zeroshotqa_pipeline.sh ├── LICENSE ├── PREPARE_DATASET.md ├── README.md ├── RLAIF ├── README.md ├── data_utils │ ├── common_utils.py │ ├── constants.py │ 
├── data_utils_ppo.py │ ├── data_utils_rm.py │ └── data_utils_sft.py ├── finetune_lora_ppo.py ├── finetune_lora_rm.py ├── finetune_policy_init.py ├── lora_utils.py ├── models │ ├── distributed_utils.py │ ├── ppo_trainer.py │ ├── qlora_model.py │ ├── reward_model.py │ ├── rl_models.py │ ├── rl_trainer.py │ └── trainer_utils.py ├── prompts │ └── fact_rlaif_reward_prompt_video.txt └── scripts │ ├── initialize_policy_model.sh │ ├── parse_largest_ckptname.sh │ ├── train_reward_model.sh │ ├── train_rl_model.sh │ └── zero2.json ├── RLAIF_DataGen └── README.md ├── SFT └── README.md ├── assets └── images │ └── rlaif_feedback_teaser.png ├── llava_setup ├── .gitignore ├── README.md └── fix_llava_padding.patch └── serve ├── __init__.py ├── cli.py ├── controller.py ├── gradio_utils.py ├── gradio_web_server.py ├── model_worker.py ├── register_worker.py ├── test_message.py └── utils.py /Evaluation/combine_preds.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tqdm 3 | import os 4 | import json 5 | 6 | 7 | def load_json(fpath): 8 | with open(fpath, "r") as f: 9 | return json.load(f) 10 | 11 | 12 | def save_json(data, fpath): 13 | with open(fpath, "w") as f: 14 | json.dump(data, f) 15 | 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--pred_dir") 19 | parser.add_argument("--infer_fname", default="infer_all.json") 20 | args = parser.parse_args() 21 | 22 | save_new=True 23 | print(args.pred_dir) 24 | # for possible_out_name in ["infer_all.json", "msvd_qa_infer_all.json", "anet_qa_infer_all.json", "generic.json"]: 25 | for possible_out_name in [args.infer_fname]: 26 | if os.path.exists(os.path.join(args.pred_dir, possible_out_name)): 27 | print("Already exists", os.path.join(args.pred_dir, possible_out_name)) 28 | save_new=False 29 | 30 | if save_new: 31 | all_preds = [] 32 | for fname in os.listdir(args.pred_dir): 33 | if fname.endswith(".json"): 34 | if 'gpt' in fname: continue 35 | all_preds += load_json(os.path.join(args.pred_dir, fname)) 36 | save_json(all_preds, os.path.join(args.pred_dir, args.infer_fname)) 37 | print("Done") -------------------------------------------------------------------------------- /Evaluation/infer_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import numpy as np 5 | from PIL import Image 6 | import requests 7 | from io import BytesIO 8 | import torch 9 | from torchvision.transforms import Compose, Lambda, ToTensor 10 | from torchvision.transforms.functional import to_pil_image 11 | 12 | 13 | def load_json(file_path): 14 | with open(file_path, 'r') as f: 15 | return json.load(f) 16 | 17 | def load_jsonl(file_path): 18 | with open(file_path, 'r') as f: 19 | return [json.loads(l) for l in f] 20 | 21 | def save_json(data, file_path): 22 | with open(file_path, 'w') as f: 23 | json.dump(data, f) 24 | 25 | def save_jsonl(data, file_path): 26 | with open(file_path, 'w') as f: 27 | for d in data: 28 | f.write(json.dumps(d) + '\n') 29 | 30 | def split_list(lst, n): 31 | """Split a list into n (roughly) equal-sized chunks""" 32 | chunk_size = math.ceil(len(lst) / n) # integer division 33 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 34 | 35 | 36 | def get_chunk(lst, n, k): 37 | chunks = split_list(lst, n) 38 | return chunks[k] 39 | 40 | 41 | def load_image(image_file): 42 | if image_file.startswith('http://') or image_file.startswith('https://'): 43 | response = 
requests.get(image_file) 44 | image = Image.open(BytesIO(response.content)).convert('RGB') 45 | else: 46 | image = Image.open(image_file).convert('RGB') 47 | return image 48 | 49 | 50 | def load_frames(frame_names, num_frames=None): 51 | frame_names.sort() 52 | # sample frames 53 | if num_frames is not None and len(frame_names) != num_frames: 54 | duration = len(frame_names) 55 | frame_id_array = np.linspace(0, duration-1, num_frames, dtype=int) 56 | frame_id_list = frame_id_array.tolist() 57 | else: 58 | frame_id_list = range(len(frame_names)) # no sampling requested: keep every frame in sorted order 59 | 60 | results = [] 61 | for frame_idx in frame_id_list: 62 | frame_name = frame_names[frame_idx] 63 | results.append(load_image(frame_name)) 64 | 65 | return results 66 | 67 | 68 | def load_video_into_frames( 69 | video_path, 70 | video_decode_backend='opencv', 71 | num_frames=8, 72 | return_tensor=False, 73 | ): 74 | print("VIDEO PATH !!!", video_path) 75 | if video_decode_backend == 'decord': 76 | import decord 77 | from decord import VideoReader, cpu 78 | decord.bridge.set_bridge('torch') 79 | decord_vr = VideoReader(video_path, ctx=cpu(0)) 80 | duration = len(decord_vr) 81 | frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int) 82 | video_data = decord_vr.get_batch(frame_id_list) 83 | if return_tensor: 84 | video_data = video_data.permute(3, 0, 1, 2) # (T, H, W, C) -> (C, T, H, W) 85 | else: 86 | video_data = [to_pil_image(f) for f in video_data] 87 | elif video_decode_backend == 'frames': 88 | frames = load_frames([os.path.join(video_path, imname) 89 | for imname in os.listdir(video_path)], 90 | num_frames=num_frames) 91 | video_data = frames 92 | if return_tensor: 93 | to_tensor = ToTensor() 94 | video_data = torch.stack([to_tensor(_) for _ in frames]).permute(1, 0, 2, 3) # (T, C, H, W) -> (C, T, H, W) 95 | elif video_decode_backend == 'opencv': 96 | import cv2 97 | cv2_vr = cv2.VideoCapture(video_path) 98 | duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT)) 99 | frame_id_list = np.linspace(0, duration-1, num_frames, dtype=int) 100 | # frame_id_list = np.linspace(0, duration-5, num_frames, dtype=int) 101 | 102 | video_data = [] 103 | for frame_idx in frame_id_list: 104 | cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) 105 | ret, frame = cv2_vr.read() 106 | if not ret: 107 | raise ValueError(f'video error at {video_path}') 108 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 109 | if return_tensor: 110 | video_data.append(torch.from_numpy(frame).permute(2, 0, 1)) 111 | else: 112 | video_data.append(Image.fromarray(frame)) 113 | cv2_vr.release() 114 | if return_tensor: 115 | video_data = torch.stack(video_data, dim=1) 116 | else: 117 | raise NameError(f'video_decode_backend must be one of (decord, frames, opencv) but got {video_decode_backend}') 118 | return video_data 119 | -------------------------------------------------------------------------------- /Evaluation/llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /Evaluation/llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | -------------------------------------------------------------------------------- /Evaluation/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /Evaluation/llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /Evaluation/llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = 
AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': 
ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 
93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | 
image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 | text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, 
yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", type=str) 67 | parser.add_argument("--question-file", type=str) 68 | parser.add_argument("--result-file", type=str) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | return random.choice(range(len(choices))) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | else: 62 | pred = predictions[prob_id] 63 | pred_text = pred['text'] 64 | 65 | if pred_text in args.options: 66 | answer = pred_text 67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 68 | answer = pred_text[0] 69 | else: 70 | pattern = re.compile(r'The answer is ([A-Z]).') 71 | res = pattern.findall(pred_text) 72 | if len(res) == 1: 73 | answer = res[0] # 'A', 'B', ... 
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '<image>' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return 
parser.parse_args() 19 | 20 | 21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = 
read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # 
Model 35 | disable_torch_init() 36 | model_name = os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, 
args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions): 42 | idx = line["question_id"] 43 | image_file = line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = Image.open(os.path.join(args.image_folder, image_file)) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 60 | 61 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 62 | keywords = [stop_str] 63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 64 | 65 | with torch.inference_mode(): 66 | output_ids = model.generate( 67 | input_ids, 68 | images=image_tensor.unsqueeze(0).half().cuda(), 69 | do_sample=True if args.temperature > 0 else False, 70 | temperature=args.temperature, 71 | top_p=args.top_p, 72 | num_beams=args.num_beams, 73 | # no_repeat_ngram_size=3, 74 | max_new_tokens=1024, 75 | use_cache=True) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | 87 | ans_id = shortuuid.uuid() 88 | ans_file.write(json.dumps({"question_id": idx, 89 | "prompt": cur_prompt, 90 | "text": outputs, 91 | "answer_id": ans_id, 92 | "model_id": model_name, 93 | "metadata": {}}) + "\n") 94 | ans_file.flush() 95 | ans_file.close() 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 100 | parser.add_argument("--model-base", type=str, default=None) 101 | parser.add_argument("--image-folder", type=str, default="") 102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 103 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 104 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 105 | parser.add_argument("--num-chunks", type=int, default=1) 106 | parser.add_argument("--chunk-idx", type=int, default=0) 107 | parser.add_argument("--temperature", type=float, default=0.2) 108 | parser.add_argument("--top_p", type=float, default=None) 109 | parser.add_argument("--num_beams", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | eval_model(args) 113 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/model_vqa_loader.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 
| from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | from PIL import Image 16 | import math 17 | 18 | 19 | def split_list(lst, n): 20 | """Split a list into n (roughly) equal-sized chunks""" 21 | chunk_size = math.ceil(len(lst) / n) # integer division 22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 23 | 24 | 25 | def get_chunk(lst, n, k): 26 | chunks = split_list(lst, n) 27 | return chunks[k] 28 | 29 | 30 | # Custom dataset class 31 | class CustomDataset(Dataset): 32 | def __init__(self, questions, image_folder, tokenizer, image_processor, model_config): 33 | self.questions = questions 34 | self.image_folder = image_folder 35 | self.tokenizer = tokenizer 36 | self.image_processor = image_processor 37 | self.model_config = model_config 38 | 39 | def __getitem__(self, index): 40 | line = self.questions[index] 41 | image_file = line["image"] 42 | qs = line["text"] 43 | if self.model_config.mm_use_im_start_end: 44 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 45 | else: 46 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | conv.append_message(conv.roles[0], qs) 50 | conv.append_message(conv.roles[1], None) 51 | prompt = conv.get_prompt() 52 | 53 | image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB') 54 | image_tensor = process_images([image], self.image_processor, self.model_config)[0] 55 | 56 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') 57 | 58 | return input_ids, image_tensor 59 | 60 | def __len__(self): 61 | return len(self.questions) 62 | 63 | 64 | # DataLoader 65 | def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4): 66 | assert batch_size == 1, "batch_size must be 1" 67 | dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config) 68 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False) 69 | return data_loader 70 | 71 | 72 | def eval_model(args): 73 | # Model 74 | disable_torch_init() 75 | model_path = os.path.expanduser(args.model_path) 76 | model_name = get_model_name_from_path(model_path) 77 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 78 | 79 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 80 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 81 | answers_file = os.path.expanduser(args.answers_file) 82 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 83 | ans_file = open(answers_file, "w") 84 | 85 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: 86 | args.conv_mode = args.conv_mode + '_mmtag' 87 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') 88 | 89 | data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config) 90 | 91 | for (input_ids, image_tensor), line in 
tqdm(zip(data_loader, questions), total=len(questions)): 92 | idx = line["question_id"] 93 | cur_prompt = line["text"] 94 | 95 | stop_str = conv_templates[args.conv_mode].sep if conv_templates[args.conv_mode].sep_style != SeparatorStyle.TWO else conv_templates[args.conv_mode].sep2 96 | input_ids = input_ids.to(device='cuda', non_blocking=True) 97 | 98 | with torch.inference_mode(): 99 | output_ids = model.generate( 100 | input_ids, 101 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), 102 | do_sample=True if args.temperature > 0 else False, 103 | temperature=args.temperature, 104 | top_p=args.top_p, 105 | num_beams=args.num_beams, 106 | max_new_tokens=128, 107 | use_cache=True) 108 | 109 | input_token_len = input_ids.shape[1] 110 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 111 | if n_diff_input_output > 0: 112 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 113 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 114 | outputs = outputs.strip() 115 | if outputs.endswith(stop_str): 116 | outputs = outputs[:-len(stop_str)] 117 | outputs = outputs.strip() 118 | 119 | ans_id = shortuuid.uuid() 120 | ans_file.write(json.dumps({"question_id": idx, 121 | "prompt": cur_prompt, 122 | "text": outputs, 123 | "answer_id": ans_id, 124 | "model_id": model_name, 125 | "metadata": {}}) + "\n") 126 | # ans_file.flush() 127 | ans_file.close() 128 | 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser() 131 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 132 | parser.add_argument("--model-base", type=str, default=None) 133 | parser.add_argument("--image-folder", type=str, default="") 134 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 135 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 136 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 137 | parser.add_argument("--num-chunks", type=int, default=1) 138 | parser.add_argument("--chunk-idx", type=int, default=0) 139 | parser.add_argument("--temperature", type=float, default=0.2) 140 | parser.add_argument("--top_p", type=float, default=None) 141 | parser.add_argument("--num_beams", type=int, default=1) 142 | args = parser.parse_args() 143 | 144 | eval_model(args) 145 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/model_vqa_mmbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | import pandas as pd 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path 14 | 15 | from PIL import Image 16 | import math 17 | 18 | 19 | all_options = ['A', 'B', 'C', 'D'] 20 | 21 | 22 | def split_list(lst, n): 23 | """Split a list into n (roughly) equal-sized chunks""" 24 | chunk_size = math.ceil(len(lst) / n) # integer division 25 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 26 | 27 | 28 | def get_chunk(lst, n, 
k): 29 | chunks = split_list(lst, n) 30 | return chunks[k] 31 | 32 | 33 | def is_none(value): 34 | if value is None: 35 | return True 36 | if type(value) is float and math.isnan(value): 37 | return True 38 | if type(value) is str and value.lower() == 'nan': 39 | return True 40 | if type(value) is str and value.lower() == 'none': 41 | return True 42 | return False 43 | 44 | def get_options(row, options): 45 | parsed_options = [] 46 | for option in options: 47 | option_value = row[option] 48 | if is_none(option_value): 49 | break 50 | parsed_options.append(option_value) 51 | return parsed_options 52 | 53 | 54 | def eval_model(args): 55 | # Model 56 | disable_torch_init() 57 | model_path = os.path.expanduser(args.model_path) 58 | model_name = get_model_name_from_path(model_path) 59 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 60 | 61 | questions = pd.read_table(os.path.expanduser(args.question_file)) 62 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 63 | answers_file = os.path.expanduser(args.answers_file) 64 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 65 | ans_file = open(answers_file, "w") 66 | 67 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: 68 | args.conv_mode = args.conv_mode + '_mmtag' 69 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') 70 | 71 | for index, row in tqdm(questions.iterrows(), total=len(questions)): 72 | options = get_options(row, all_options) 73 | cur_option_char = all_options[:len(options)] 74 | 75 | if args.all_rounds: 76 | num_rounds = len(options) 77 | else: 78 | num_rounds = 1 79 | 80 | for round_idx in range(num_rounds): 81 | idx = row['index'] 82 | question = row['question'] 83 | hint = row['hint'] 84 | image = load_image_from_base64(row['image']) 85 | if not is_none(hint): 86 | question = hint + '\n' + question 87 | for option_char, option in zip(all_options[:len(options)], options): 88 | question = question + '\n' + option_char + '. ' + option 89 | qs = cur_prompt = question 90 | if model.config.mm_use_im_start_end: 91 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 92 | else: 93 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 94 | 95 | if args.single_pred_prompt: 96 | if args.lang == 'cn': 97 | qs = qs + '\n' + "请直接回答选项字母。" 98 | else: 99 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 
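# A hedged, self-contained sketch (not part of model_vqa_mmbench.py) of the
# multiple-choice text assembled by the loop above for a single round; the hint,
# question, and options below are made-up placeholders, not real MMBench rows.
def build_mc_prompt(hint, question, options, single_pred_prompt=True):
    all_chars = ['A', 'B', 'C', 'D']
    text = question if hint is None else hint + '\n' + question
    for option_char, option in zip(all_chars[:len(options)], options):
        text = text + '\n' + option_char + '. ' + option
    if single_pred_prompt:
        text = text + '\n' + "Answer with the option's letter from the given choices directly."
    return text

# build_mc_prompt("A bar chart with two bars.", "Which bar is taller?", ["Red", "Blue"])
# -> "A bar chart with two bars.\nWhich bar is taller?\nA. Red\nB. Blue\nAnswer with the option's letter from the given choices directly."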
100 | 101 | conv = conv_templates[args.conv_mode].copy() 102 | conv.append_message(conv.roles[0], qs) 103 | conv.append_message(conv.roles[1], None) 104 | prompt = conv.get_prompt() 105 | 106 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 107 | 108 | image_tensor = process_images([image], image_processor, model.config)[0] 109 | # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 110 | 111 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 112 | 113 | with torch.inference_mode(): 114 | output_ids = model.generate( 115 | input_ids, 116 | images=image_tensor.unsqueeze(0).half().cuda(), 117 | do_sample=True if args.temperature > 0 else False, 118 | temperature=args.temperature, 119 | top_p=args.top_p, 120 | num_beams=args.num_beams, 121 | # no_repeat_ngram_size=3, 122 | max_new_tokens=1024, 123 | use_cache=True) 124 | 125 | input_token_len = input_ids.shape[1] 126 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 127 | if n_diff_input_output > 0: 128 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 129 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 130 | outputs = outputs.strip() 131 | if outputs.endswith(stop_str): 132 | outputs = outputs[:-len(stop_str)] 133 | outputs = outputs.strip() 134 | 135 | ans_id = shortuuid.uuid() 136 | ans_file.write(json.dumps({"question_id": idx, 137 | "round_id": round_idx, 138 | "prompt": cur_prompt, 139 | "text": outputs, 140 | "options": options, 141 | "option_char": cur_option_char, 142 | "answer_id": ans_id, 143 | "model_id": model_name, 144 | "metadata": {}}) + "\n") 145 | ans_file.flush() 146 | 147 | # rotate options 148 | options = options[1:] + options[:1] 149 | cur_option_char = cur_option_char[1:] + cur_option_char[:1] 150 | ans_file.close() 151 | 152 | if __name__ == "__main__": 153 | parser = argparse.ArgumentParser() 154 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 155 | parser.add_argument("--model-base", type=str, default=None) 156 | parser.add_argument("--image-folder", type=str, default="") 157 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 158 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 159 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 160 | parser.add_argument("--num-chunks", type=int, default=1) 161 | parser.add_argument("--chunk-idx", type=int, default=0) 162 | parser.add_argument("--temperature", type=float, default=0.2) 163 | parser.add_argument("--top_p", type=float, default=None) 164 | parser.add_argument("--num_beams", type=int, default=1) 165 | parser.add_argument("--all-rounds", action="store_true") 166 | parser.add_argument("--single-pred-prompt", action="store_true") 167 | parser.add_argument("--lang", type=str, default="en") 168 | args = parser.parse_args() 169 | 170 | eval_model(args) 171 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/model_vqa_science.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation 
import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for i, line in enumerate(tqdm(questions)): 42 | idx = line["id"] 43 | question = line['conversations'][0] 44 | qs = question['value'].replace('<image>', '').strip() 45 | cur_prompt = qs 46 | 47 | if 'image' in line: 48 | image_file = line["image"] 49 | image = Image.open(os.path.join(args.image_folder, image_file)) 50 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 51 | images = image_tensor.unsqueeze(0).half().cuda() 52 | if getattr(model.config, 'mm_use_im_start_end', False): 53 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 54 | else: 55 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 56 | cur_prompt = '<image>' + '\n' + cur_prompt 57 | else: 58 | images = None 59 | 60 | if args.single_pred_prompt: 61 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 62 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
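# A hedged sketch (not part of model_vqa_science.py) of how the --num-chunks and
# --chunk-idx flags used above shard the question list across workers; split_list
# below restates the helper defined in this file so the snippet runs on its own.
import math

def split_list(lst, n):
    chunk_size = math.ceil(len(lst) / n)
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

questions = list(range(10))              # stand-in for the loaded question list
for chunk_idx in range(3):               # e.g. one process per GPU with --num-chunks 3
    print(chunk_idx, split_list(questions, 3)[chunk_idx])
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6, 7]
# 2 [8, 9]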
63 | 64 | conv = conv_templates[args.conv_mode].copy() 65 | conv.append_message(conv.roles[0], qs) 66 | conv.append_message(conv.roles[1], None) 67 | prompt = conv.get_prompt() 68 | 69 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 70 | 71 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 72 | keywords = [stop_str] 73 | stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None 74 | 75 | with torch.inference_mode(): 76 | output_ids = model.generate( 77 | input_ids, 78 | images=images, 79 | do_sample=True if args.temperature > 0 else False, 80 | temperature=args.temperature, 81 | max_new_tokens=1024, 82 | use_cache=True, 83 | stopping_criteria=stopping_criteria, 84 | ) 85 | 86 | input_token_len = input_ids.shape[1] 87 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 88 | if n_diff_input_output > 0: 89 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 90 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 91 | outputs = outputs.strip() 92 | if outputs.endswith(stop_str): 93 | outputs = outputs[:-len(stop_str)] 94 | outputs = outputs.strip() 95 | 96 | # prompt for answer 97 | if args.answer_prompter: 98 | outputs_reasoning = outputs 99 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 100 | 101 | with torch.inference_mode(): 102 | output_ids = model.generate( 103 | input_ids, 104 | images=images, 105 | do_sample=True if args.temperature > 0 else False, 106 | temperature=args.temperature, 107 | max_new_tokens=64, 108 | use_cache=True, 109 | stopping_criteria=[stopping_criteria]) 110 | 111 | input_token_len = input_ids.shape[1] 112 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 113 | if n_diff_input_output > 0: 114 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 115 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 116 | outputs = outputs.strip() 117 | if outputs.endswith(stop_str): 118 | outputs = outputs[:-len(stop_str)] 119 | outputs = outputs.strip() 120 | outputs = outputs_reasoning + '\n The answer is ' + outputs 121 | 122 | ans_id = shortuuid.uuid() 123 | ans_file.write(json.dumps({"question_id": idx, 124 | "prompt": cur_prompt, 125 | "text": outputs, 126 | "answer_id": ans_id, 127 | "model_id": model_name, 128 | "metadata": {}}) + "\n") 129 | ans_file.flush() 130 | ans_file.close() 131 | 132 | if __name__ == "__main__": 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 135 | parser.add_argument("--model-base", type=str, default=None) 136 | parser.add_argument("--image-folder", type=str, default="") 137 | parser.add_argument("--question-file", type=str, default="tables/question.json") 138 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 139 | parser.add_argument("--conv-mode", type=str, default="llava_v0") 140 | parser.add_argument("--num-chunks", type=int, default=1) 141 | parser.add_argument("--chunk-idx", type=int, default=0) 142 | parser.add_argument("--temperature", type=float, default=0.2) 143 | parser.add_argument("--answer-prompter", action="store_true") 144 | 
parser.add_argument("--single-pred-prompt", action="store_true") 145 | args = parser.parse_args() 146 | 147 | eval_model(args) 148 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | 16 | 17 | def load_image(image_file): 18 | if image_file.startswith('http') or image_file.startswith('https'): 19 | response = requests.get(image_file) 20 | image = Image.open(BytesIO(response.content)).convert('RGB') 21 | else: 22 
| image = Image.open(image_file).convert('RGB') 23 | return image 24 | 25 | 26 | def eval_model(args): 27 | # Model 28 | disable_torch_init() 29 | 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) 32 | 33 | qs = args.query 34 | if model.config.mm_use_im_start_end: 35 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 36 | else: 37 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 38 | 39 | if 'llama-2' in model_name.lower(): 40 | conv_mode = "llava_llama_2" 41 | elif "v1" in model_name.lower(): 42 | conv_mode = "llava_v1" 43 | elif "mpt" in model_name.lower(): 44 | conv_mode = "mpt" 45 | else: 46 | conv_mode = "llava_v0" 47 | 48 | if args.conv_mode is not None and conv_mode != args.conv_mode: 49 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 50 | else: 51 | args.conv_mode = conv_mode 52 | 53 | conv = conv_templates[args.conv_mode].copy() 54 | conv.append_message(conv.roles[0], qs) 55 | conv.append_message(conv.roles[1], None) 56 | prompt = conv.get_prompt() 57 | 58 | image = load_image(args.image_file) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 60 | 61 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 62 | 63 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 64 | keywords = [stop_str] 65 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 66 | 67 | with torch.inference_mode(): 68 | output_ids = model.generate( 69 | input_ids, 70 | images=image_tensor, 71 | do_sample=True, 72 | temperature=0.2, 73 | max_new_tokens=1024, 74 | use_cache=True, 75 | stopping_criteria=[stopping_criteria]) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | print(outputs) 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 91 | parser.add_argument("--model-base", type=str, default=None) 92 | parser.add_argument("--image-file", type=str, required=True) 93 | parser.add_argument("--query", type=str, required=True) 94 | parser.add_argument("--conv-mode", type=str, default=None) 95 | args = parser.parse_args() 96 | 97 | eval_model(args) 98 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', 
default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. 
Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. 
The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/Evaluation/llava/model/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/Evaluation/llava/model/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/Evaluation/llava/model/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | 
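# A hedged sketch (not part of any file above) of how a prompt.jsonl record and a
# reviewer.jsonl record could be combined into a GPT-4 review request; the question
# and answers are hypothetical, and the real request logic lives in the
# eval_gpt_review*.py scripts rather than in this snippet.
import json

prompt_rec = {
    "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.",
    "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n"
                       "[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n",
    "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants ..."},
}
reviewer_rec = {"metadata": {"temperature": 0.2, "max_tokens": 1024}}

user_content = prompt_rec["prompt_template"].format(
    question="What is the person in the image doing?",
    answer_1="Ironing clothes on the back of a taxi.",
    answer_2="Standing next to a car.",
    prompt=prompt_rec["defaults"]["prompt"],
)
messages = [
    {"role": "system", "content": prompt_rec["system_prompt"]},
    {"role": "user", "content": user_content},
]
print(json.dumps({"messages": messages, **reviewer_rec["metadata"]}, indent=2))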
-------------------------------------------------------------------------------- /Evaluation/llava/model/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/Evaluation/llava/model/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /Evaluation/llava/model/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /Evaluation/llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): # HERE! 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | images: Optional[torch.FloatTensor] = None, 67 | return_dict: Optional[bool] = None, 68 | ) -> Union[Tuple, CausalLMOutputWithPast]: 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 76 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 77 | outputs = self.model( 78 | input_ids=input_ids, 79 | attention_mask=attention_mask, 80 | past_key_values=past_key_values, 81 | inputs_embeds=inputs_embeds, 82 | use_cache=use_cache, 83 | output_attentions=output_attentions, 84 | output_hidden_states=output_hidden_states, 85 | return_dict=return_dict 86 | ) 87 | 88 | hidden_states = outputs[0] 89 | logits = self.lm_head(hidden_states) 90 | 91 | loss = None 92 | if labels is not None: 93 | # Shift so that tokens < n predict n 94 | shift_logits = logits[..., :-1, :].contiguous() 95 | shift_labels = labels[..., 1:].contiguous() 96 | # Flatten the tokens 97 | loss_fct = CrossEntropyLoss() 98 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 99 | shift_labels = shift_labels.view(-1) 100 | # Enable model/pipeline parallelism 101 | shift_labels = shift_labels.to(shift_logits.device) 102 | loss = loss_fct(shift_logits, shift_labels) 103 | 104 | if not return_dict: 105 | output = (logits,) + outputs[1:] 106 | return (loss,) + output if loss is not None else output 107 | 108 | return CausalLMOutputWithPast( 109 | loss=loss, 
110 | logits=logits, 111 | past_key_values=outputs.past_key_values, 112 | hidden_states=outputs.hidden_states, 113 | attentions=outputs.attentions, 114 | ) 115 | 116 | def prepare_inputs_for_generation( 117 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 118 | ): 119 | if past_key_values: 120 | input_ids = input_ids[:, -1:] 121 | 122 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 123 | if inputs_embeds is not None and past_key_values is None: 124 | model_inputs = {"inputs_embeds": inputs_embeds} 125 | else: 126 | model_inputs = {"input_ids": input_ids} 127 | 128 | model_inputs.update( 129 | { 130 | "past_key_values": past_key_values, 131 | "use_cache": kwargs.get("use_cache"), 132 | "attention_mask": attention_mask, 133 | "images": kwargs.get("images", None), 134 | } 135 | ) 136 | return model_inputs 137 | 138 | AutoConfig.register("llava", LlavaConfig) 139 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 140 | -------------------------------------------------------------------------------- /Evaluation/llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
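# A hedged sketch of what the AutoConfig.register / AutoModelForCausalLM.register
# calls at the end of llava_llama.py (above) enable: once the llava package has
# been imported, a checkpoint whose config has model_type == "llava" resolves to
# LlavaLlamaForCausalLM through the standard Auto classes. The checkpoint path is
# a placeholder, and this assumes the Evaluation/llava package is importable.
from transformers import AutoConfig, AutoModelForCausalLM

import llava.model.language_model.llava_llama  # noqa: F401  (runs the register() calls)

config = AutoConfig.from_pretrained("/path/to/llava-checkpoint")   # hypothetical path
print(config.model_type)                                           # expected: "llava"
model = AutoModelForCausalLM.from_pretrained("/path/to/llava-checkpoint")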
14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. 
This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /Evaluation/llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('', special_tokens=True) 19 | tokenizer.pad_token = '' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., ), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 
34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /Evaluation/llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /Evaluation/llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return 
super().forward(input) -------------------------------------------------------------------------------- /Evaluation/llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /Evaluation/llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, 
weight=True, dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /Evaluation/llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, 
args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /Evaluation/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /Evaluation/llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return 
(self.config.image_size // self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /Evaluation/llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /Evaluation/llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /Evaluation/llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/Evaluation/llava/serve/__init__.py -------------------------------------------------------------------------------- /Evaluation/llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if 'llama-2' in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ('user', 'assistant') 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | # Similar operation in model_worker.py 56 | image_tensor = process_images([image], image_processor, args) 57 | if type(image_tensor) is list: 58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 59 | else: 60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 61 | 62 | # ##### Debug 63 | # image = None 64 | 65 | while True: 66 | try: 67 | inp = input(f"{roles[0]}: ") 68 | except EOFError: 69 | inp = "" 70 | if not inp: 71 | print("exit...") 72 | break 73 | 74 | print(f"{roles[1]}: ", end="") 75 | 76 | if image is not None: 77 | # first message 78 | if model.config.mm_use_im_start_end: 79 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + 
DEFAULT_IM_END_TOKEN + '\n' + inp 80 | else: 81 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 82 | conv.append_message(conv.roles[0], inp) 83 | image = None 84 | else: 85 | # later messages 86 | conv.append_message(conv.roles[0], inp) 87 | conv.append_message(conv.roles[1], None) 88 | prompt = conv.get_prompt() 89 | 90 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 92 | keywords = [stop_str] 93 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 94 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 95 | 96 | with torch.inference_mode(): 97 | output_ids = model.generate( 98 | input_ids, 99 | images=image_tensor, 100 | do_sample=True, 101 | temperature=args.temperature, 102 | max_new_tokens=args.max_new_tokens, 103 | streamer=streamer, 104 | use_cache=True, 105 | stopping_criteria=[stopping_criteria]) 106 | 107 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 108 | conv.messages[-1][-1] = outputs 109 | 110 | if args.debug: 111 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 117 | parser.add_argument("--model-base", type=str, default=None) 118 | parser.add_argument("--image-file", type=str, required=True) 119 | parser.add_argument("--device", type=str, default="cuda") 120 | parser.add_argument("--conv-mode", type=str, default=None) 121 | parser.add_argument("--temperature", type=float, default=0.2) 122 | parser.add_argument("--max-new-tokens", type=int, default=512) 123 | parser.add_argument("--load-8bit", action="store_true") 124 | parser.add_argument("--load-4bit", action="store_true") 125 | parser.add_argument("--debug", action="store_true") 126 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 127 | args = parser.parse_args() 128 | main(args) 129 | -------------------------------------------------------------------------------- /Evaluation/llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/Evaluation/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /Evaluation/llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/Evaluation/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /Evaluation/llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /Evaluation/llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /Evaluation/llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from 
flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def 
replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /Evaluation/llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | import sys 7 | sys.path.append('/dataset/llms/llava/LLaVA_Video_temp') 8 | # sys.path.append('/dataset/dcahn/llms/llava/LLaVA_Video_temp') 9 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 10 | 11 | replace_llama_attn_with_flash_attn() 12 | 13 | from llava.train.train import train 14 | 15 | if __name__ == "__main__": 16 | train() 17 | -------------------------------------------------------------------------------- /Evaluation/llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
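# `handler` (defined below) is the module-level file handler shared by all loggers; build_logger() creates it lazily on first use and attaches it to every registered logger.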
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /Evaluation/scripts/videochatgpt_pipeline.sh: -------------------------------------------------------------------------------- 1 | NEW_PYPTH=$PWD/../.. 2 | NEW_PYPTH=$(builtin cd $NEW_PYPTH; pwd) 3 | export PYTHONPATH=$PYTHONPATH:$NEW_PYPTH 4 | DATA_PATH=playground/data 5 | 6 | # ================== CHANGE HERE ================== 7 | MODEL_PATH=SNUMPR/vlm_rlaif_video_llava_7b 8 | MODEL_BASE=none 9 | OUTPUT_DIR=results/vlm_rlaif_video_llava_7b 10 | export cache_dir=./cache_dir 11 | export API_KEY="YOUR OPENAI API KEY HERE" 12 | 13 | TASKNAMES=( temporal ) 14 | TASKNAMES=( temporal ) 15 | DATA_DIR=/dataset/dcahn/llms/YuraLLM/playground/data/VideoChatGPT_Eval/original_data 16 | FRAMES_PATH=/dataset/dcahn/llms/YuraLLM/playground/data/VideoChatGPT_Eval/Test_Videos 17 | # ================== CHANGE HERE ================== 18 | OUTPUT_DIR=$OUTPUT_DIR/videochatgpt 19 | 20 | for TASKNAME in ${TASKNAMES[@]}; do 21 | bash Evaluation/videochatgpt/scripts/pipeline_$TASKNAME.sh \ 22 | $MODEL_PATH \ 23 | $MODEL_BASE \ 24 | $OUTPUT_DIR \ 25 | $TASKNAME \ 26 | $DATA_DIR \ 27 | $FRAMES_PATH 28 | wait 29 | done -------------------------------------------------------------------------------- /Evaluation/scripts/zeroshotqa_pipeline.sh: -------------------------------------------------------------------------------- 1 | NEW_PYPTH=$PWD/../.. 
2 | NEW_PYPTH=$(builtin cd $NEW_PYPTH; pwd) 3 | export PYTHONPATH=$PYTHONPATH:$NEW_PYPTH 4 | export cache_dir="cache_dir" 5 | 6 | # ================== CHANGE HERE ================== 7 | MODEL_PATH=SNUMPR/vlm_rlaif_video_llava_7b 8 | MODEL_BASE=none 9 | OUTPUT_DIR=results/vlm_rlaif_video_llava_7b 10 | ANNOT_PATH=playground/data/eval_dataset/zeroshotqa/annotations 11 | FRAMES_PATH="playground/data/eval_dataset/zeroshotqa/video_frames" 12 | export API_KEY="YOUR OPENAI API KEY HERE" 13 | 14 | TASKNAMES=( anet msrvtt msvd ) 15 | # ================== CHANGE HERE ================== 16 | 17 | for TASKNAME in ${TASKNAMES[@]}; do 18 | bash Evaluation/zeroshotqa/scripts/zeroshotqa_infer.sh \ 19 | $MODEL_PATH \ 20 | $MODEL_BASE \ 21 | $OUTPUT_DIR/zeroshotqa \ 22 | $TASKNAME \ 23 | $ANNOT_PATH/$TASKNAME"_qa.json" \ 24 | $FRAMES_PATH/$TASKNAME \ 25 | $CHUNKS 26 | wait 27 | bash Evaluation/zeroshotqa/scripts/zeroshotqa_eval.sh \ 28 | $OUTPUT_DIR/zeroshotqa \ 29 | $TASKNAME 30 | wait 31 | done 32 | -------------------------------------------------------------------------------- /Evaluation/videochatgpt/evaluate_benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Define common arguments for all scripts 4 | PRED_DIR="" 5 | OUTPUT_DIR="" 6 | API_KEY="" 7 | NUM_TASKS="" 8 | 9 | # Run the "correctness" evaluation script 10 | python evaluate_benchmark_1_correctness.py \ 11 | --pred_path "${PRED_DIR}/correctness_pred" \ 12 | --output_dir "${OUTPUT_DIR}/correctness_eval" \ 13 | --output_json "${OUTPUT_DIR}/correctness_results.json" \ 14 | --api_key $API_KEY \ 15 | --num_tasks $NUM_TASKS 16 | 17 | # Run the "detailed orientation" evaluation script 18 | python evaluate_benchmark_2_detailed_orientation.py \ 19 | --pred_path "${PRED_DIR}/detailed_orientation_pred" \ 20 | --output_dir "${OUTPUT_DIR}/detailed_eval" \ 21 | --output_json "${OUTPUT_DIR}/detailed_orientation_results.json" \ 22 | --api_key $API_KEY \ 23 | --num_tasks $NUM_TASKS 24 | 25 | # Run the "contextual understanding" evaluation script 26 | python evaluate_benchmark_3_context.py \ 27 | --pred_path "${PRED_DIR}/contextual_pred" \ 28 | --output_dir "${OUTPUT_DIR}/context_eval" \ 29 | --output_json "${OUTPUT_DIR}/contextual_understanding_results.json" \ 30 | --api_key $API_KEY \ 31 | --num_tasks $NUM_TASKS 32 | 33 | # Run the "temporal understanding" evaluation script 34 | python evaluate_benchmark_4_temporal.py \ 35 | --pred_path "${PRED_DIR}/temporal_understanding_pred" \ 36 | --output_dir "${OUTPUT_DIR}/temporal_eval" \ 37 | --output_json "${OUTPUT_DIR}/temporal_understanding_results.json" \ 38 | --api_key $API_KEY \ 39 | --num_tasks $NUM_TASKS 40 | 41 | # Run the "consistency" evaluation script 42 | python evaluate_benchmark_5_consistency.py \ 43 | --pred_path "${PRED_DIR}/consistency_pred" \ 44 | --output_dir "${OUTPUT_DIR}/consistency_eval" \ 45 | --output_json "${OUTPUT_DIR}/consistency_results.json" \ 46 | --api_key $API_KEY \ 47 | --num_tasks $NUM_TASKS 48 | 49 | 50 | echo "All evaluations completed!" 
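# Example configuration (hypothetical values) for the empty variables defined at the top of this script:
#   PRED_DIR="results/vlm_rlaif_video_llava_7b/videochatgpt"
#   OUTPUT_DIR="results/vlm_rlaif_video_llava_7b/videochatgpt/gpt_eval"
#   API_KEY="<your OpenAI API key>"
#   NUM_TASKS=10   # number of parallel GPT evaluation workers, as in gpt_eval.sh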
51 | -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/gpt_eval.sh: -------------------------------------------------------------------------------- 1 | NUM_TASKS=10 2 | # API_KEY="OPENAI KEY HERE" 3 | source Evaluation/.api_keys/key1 4 | 5 | OUTPUT_DIR=$1 6 | TASKID=$2 7 | 8 | TASKNAMES=( correctness detailed_orientation context temporal consistency ) 9 | INFERNAMES=( generic generic generic temporal consistency ) 10 | TASKNAME="${TASKNAMES[$TASKID-1]}" 11 | INFERFNAME="${INFERNAMES[$TASKID-1]}" 12 | 13 | PRED_DIR=$OUTPUT_DIR/$INFERFNAME 14 | PRED_PATH=$PRED_DIR/$INFERFNAME".json" 15 | 16 | OUT_JSON=$OUTPUT_DIR/$INFERFNAME/gpt_$TASKNAME".json" 17 | OUTPUT_DIR=$OUTPUT_DIR/$INFERFNAME/gpt_eval/$TASKNAME 18 | 19 | if [ $INFERFNAME=generic ]; then 20 | python3 scripts/eval_script/combine_preds.py --pred_dir $PRED_DIR --infer_fname $INFERFNAME.json 21 | fi 22 | 23 | echo $PRED_PATH $OUTPUT_DIR 24 | python3 Evaluation/videochatgpt/evaluate_benchmark_${TASKID}"_"$TASKNAME.py \ 25 | --pred_path $PRED_PATH \ 26 | --output_dir $OUTPUT_DIR \ 27 | --output_json $OUT_JSON \ 28 | --api_key $API_KEY \ 29 | --num_tasks $NUM_TASKS 30 | -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/infer_consistency.sh: -------------------------------------------------------------------------------- 1 | NEW_PYPTH=$PWD/../.. 2 | NEW_PYPTH=$(builtin cd $NEW_PYPTH; pwd) 3 | export PYTHONPATH=$PYTHONPATH:$NEW_PYPTH 4 | 5 | MODEL_PATH=$1 6 | MODEL_BASE=$2 7 | OUTPUT_DIR=$3 8 | TASKNAME=${4:-consistency} 9 | VIDEOCHATGPT_EVAL_PATH=$5 10 | FRAMES_PATH=$6 11 | OUTPUT_DIR=$OUTPUT_DIR/$TASKNAME 12 | 13 | GPU_IDS=( 0 1 2 3 4 5 6 7 ) 14 | SPLITS=( 0 1 2 3 4 5 6 7 ) 15 | N_SPLIT=${#GPU_IDS[@]} 16 | 17 | for DEVICE_ID in ${GPU_IDS[@]}; do 18 | CUDA_VISIBLE_DEVICES=$DEVICE_ID \ 19 | python3 evaluate/video_chatgpt/run_inference_benchmark_consistency.py \ 20 | --model-path $MODEL_PATH \ 21 | --model-base $MODEL_BASE \ 22 | --frames_path $FRAMES_PATH \ 23 | --gt_file $VIDEOCHATGPT_EVAL_PATH/$TASKNAME"_qa.json" \ 24 | --output_dir $OUTPUT_DIR \ 25 | --output_name $N_SPLIT"_${SPLITS[$DEVICE_ID]}" \ 26 | --images \ 27 | --num_frames 50 \ 28 | --rlhf_ckpt \ 29 | --chunks $N_SPLIT \ 30 | --chunk_idx ${SPLITS[$DEVICE_ID]} \ 31 | --resume \ 32 | & 33 | done 34 | wait -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/infer_general.sh: -------------------------------------------------------------------------------- 1 | NEW_PYPTH=$PWD/../.. 
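# Sharded inference for the Video-ChatGPT generative benchmark: each GPU id in GPU_IDS below runs one chunk of the evaluation split in parallel (see --chunks / --chunk_idx).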
2 | NEW_PYPTH=$(builtin cd $NEW_PYPTH; pwd) 3 | export PYTHONPATH=$PYTHONPATH:$NEW_PYPTH 4 | 5 | MODEL_PATH=$1 6 | MODEL_BASE=$2 7 | OUTPUT_DIR=$3 8 | TASKNAME=$4 9 | VIDEOCHATGPT_EVAL_PATH=$5 10 | FRAMES_PATH=$6 11 | OUTPUT_DIR=$OUTPUT_DIR/$TASKNAME 12 | 13 | GPU_IDS=( 0 1 2 3 4 5 6 7 ) 14 | SPLITS=( 0 1 2 3 4 5 6 7 ) 15 | N_SPLIT=${#GPU_IDS[@]} 16 | 17 | for DEVICE_ID in ${GPU_IDS[@]}; do 18 | CUDA_VISIBLE_DEVICES=$DEVICE_ID \ 19 | python3 Evaluation/videochatgpt/infer_general.py \ 20 | --model-path $MODEL_PATH \ 21 | --model-base $MODEL_BASE \ 22 | --frames_path $FRAMES_PATH \ 23 | --gt_file $VIDEOCHATGPT_EVAL_PATH/$TASKNAME"_qa.json" \ 24 | --output_dir $OUTPUT_DIR \ 25 | --output_name $N_SPLIT"_${SPLITS[$DEVICE_ID]}" \ 26 | --images \ 27 | --num_frames 50 \ 28 | --rlhf_ckpt \ 29 | --chunks $N_SPLIT \ 30 | --chunk_idx ${SPLITS[$DEVICE_ID]} \ 31 | --resume \ 32 | & 33 | done 34 | wait -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/pipeline_consistency.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH=$1 2 | MODEL_BASE=$2 3 | OUTPUT_DIR=$3 4 | TASKNAME=$4 5 | DATA_DIR=$5 6 | FRAMES_PATH=$6 7 | 8 | bash Evaluation/videochatgpt/scripts/infer_consistency.sh \ 9 | $MODEL_PATH \ 10 | $MODEL_BASE \ 11 | $OUTPUT_DIR \ 12 | $TASKNAME \ 13 | $DATA_DIR \ 14 | $FRAMES_PATH 15 | wait 16 | 17 | bash Evaluation/videochatgpt/scripts/gpt_eval.sh $OUTPUT_DIR 5 18 | wait -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/pipeline_context.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH=$1 2 | MODEL_BASE=$2 3 | OUTPUT_DIR=$3 4 | TASKNAME=$4 5 | DATA_DIR=$5 6 | FRAMES_PATH=$6 7 | 8 | # Generic inference 9 | bash Evaluation/videochatgpt/scripts/infer_general.sh \ 10 | $MODEL_PATH \ 11 | $MODEL_BASE \ 12 | $OUTPUT_DIR \ 13 | $TASKNAME \ 14 | $DATA_DIR \ 15 | $FRAMES_PATH 16 | wait 17 | 18 | bash Evaluation/videochatgpt/scripts/gpt_eval.sh $OUTPUT_DIR 3 19 | wait -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/pipeline_correctness.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH=$1 2 | MODEL_BASE=$2 3 | OUTPUT_DIR=$3 4 | TASKNAME=$4 5 | DATA_DIR=$5 6 | FRAMES_PATH=$6 7 | 8 | # Generic inference 9 | bash Evaluation/videochatgpt/scripts/infer_general.sh \ 10 | $MODEL_PATH \ 11 | $MODEL_BASE \ 12 | $OUTPUT_DIR \ 13 | $TASKNAME \ 14 | $DATA_DIR \ 15 | $FRAMES_PATH 16 | wait 17 | 18 | bash Evaluation/videochatgpt/scripts/gpt_eval.sh $OUTPUT_DIR 1 19 | wait -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/pipeline_detail.sh: -------------------------------------------------------------------------------- 1 | MODEL_PATH=$1 2 | MODEL_BASE=$2 3 | OUTPUT_DIR=$3 4 | TASKNAME=$4 5 | DATA_DIR=$5 6 | FRAMES_PATH=$6 7 | 8 | # Generic inference 9 | bash Evaluation/videochatgpt/scripts/infer_general.sh \ 10 | $MODEL_PATH \ 11 | $MODEL_BASE \ 12 | $OUTPUT_DIR \ 13 | $TASKNAME \ 14 | $DATA_DIR \ 15 | $FRAMES_PATH 16 | wait 17 | 18 | bash Evaluation/videochatgpt/scripts/gpt_eval.sh $OUTPUT_DIR 2 19 | wait -------------------------------------------------------------------------------- /Evaluation/videochatgpt/scripts/pipeline_temporal.sh: -------------------------------------------------------------------------------- 1 | 
MODEL_PATH=$1 2 | MODEL_BASE=$2 3 | OUTPUT_DIR=$3 4 | TASKNAME=$4 5 | DATA_DIR=$5 6 | FRAMES_PATH=$6 7 | 8 | # Generic inference 9 | bash Evaluation/videochatgpt/scripts/infer_general.sh \ 10 | $MODEL_PATH \ 11 | $MODEL_BASE \ 12 | $OUTPUT_DIR \ 13 | $TASKNAME \ 14 | $DATA_DIR \ 15 | $FRAMES_PATH 16 | wait 17 | 18 | bash Evaluation/videochatgpt/scripts/gpt_eval.sh $OUTPUT_DIR 4 19 | wait -------------------------------------------------------------------------------- /Evaluation/zeroshotqa/scripts/zeroshotqa_eval.sh: -------------------------------------------------------------------------------- 1 | NUM_TASKS=10 2 | 3 | PRED_DIR=$1 4 | TASKNAME=$2 5 | 6 | PRED_PATH=$PRED_DIR/$TASKNAME 7 | OUT_JSON=$TASKNAME".json" 8 | 9 | python3 Evaluation/combine_preds.py --pred_dir $PRED_PATH 10 | PRED_PATH=$PRED_PATH/infer_all.json 11 | 12 | echo $PRED_PATH 13 | python3 Evaluation/zeroshotqa/gpt_eval.py \ 14 | --pred_path $PRED_PATH \ 15 | --output_json $OUT_JSON \ 16 | --api_key $API_KEY \ 17 | --num_tasks $NUM_TASKS 18 | -------------------------------------------------------------------------------- /Evaluation/zeroshotqa/scripts/zeroshotqa_infer.sh: -------------------------------------------------------------------------------- 1 | NEW_PYPTH=$PWD/../.. 2 | NEW_PYPTH=$(builtin cd $NEW_PYPTH; pwd) 3 | export PYTHONPATH=$PYTHONPATH:$NEW_PYPTH 4 | 5 | # MODEL_NAME=$1 6 | # CKPT_NAME=$2 7 | # source scripts/model_paths/$MODEL_NAME 8 | # TASKNAME=anet 9 | MODEL_PATH=$1 10 | MODEL_BASE=$2 11 | OUTPUT_DIR=$3 12 | TASKNAME=$4 13 | ANNOT_PATH=$5 14 | FRAMES_PATH=$6 15 | 16 | PRED_DIR=$3/$TASKNAME 17 | mkdir -p $PRED_DIR 18 | 19 | # GPU_IDS=( 0 1 2 3 4 5 6 7 ) 20 | GPU_IDS=( 7 ) 21 | SPLITS=( 0 1 2 3 4 5 6 7 ) 22 | N_SPLIT=8 23 | 24 | for DEVICE_ID in ${GPU_IDS[@]}; do 25 | CUDA_VISIBLE_DEVICES=$DEVICE_ID python3 Evaluation/zeroshotqa/qa_infer.py \ 26 | --model-path $MODEL_PATH \ 27 | --model-base $MODEL_BASE \ 28 | --gt_file_qa $ANNOT_PATH \ 29 | --chunks $N_SPLIT \ 30 | --chunk_idx ${SPLITS[$DEVICE_ID]} \ 31 | --output_dir $PRED_DIR \ 32 | --output_name $N_SPLIT"_${SPLITS[$DEVICE_ID]}" \ 33 | --images \ 34 | --frames_path $FRAMES_PATH \ 35 | --num_frames 50 \ 36 | --resume \ 37 | & 38 | done 39 | wait 40 | 41 | -------------------------------------------------------------------------------- /Evaluation/zeroshotqa/scripts/zeroshotqa_pipeline.sh: -------------------------------------------------------------------------------- 1 | NEW_PYPTH=$PWD/../.. 
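# Sharded zero-shot QA inference: one chunk per GPU id in GPU_IDS below; the per-chunk JSON outputs are merged afterwards by Evaluation/combine_preds.py (invoked from zeroshotqa_eval.sh).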
2 | NEW_PYPTH=$(builtin cd $NEW_PYPTH; pwd) 3 | export PYTHONPATH=$PYTHONPATH:$NEW_PYPTH 4 | DATA_PATH=playground/data 5 | 6 | # ================== CHANGE HERE ================== 7 | MODEL_PATH=SNUMPR/vlm_rlaif_video_llava_7b 8 | MODEL_BASE=none 9 | OUTPUT_DIR=results/vlm_rlaif_video_llava_7b 10 | FRAMES_PATH="playground/data/video_frames" 11 | export API_KEY="YOUR OPENAI API KEY HERE" 12 | 13 | TASKNAMES=( anet msrvtt msvd tgif ) 14 | # ================== CHANGE HERE ================== 15 | 16 | for TASKNAME in ${TASKNAMES[@]}; do 17 | bash Evaluation/zeroshotqa/scripts/zeroshotqa_infer.sh \ 18 | $MODEL_PATH \ 19 | $MODEL_BASE \ 20 | $OUTPUT_DIR/zeroshotqa \ 21 | $TASKNAME \ 22 | $FRAMES_PATH/$TASKNAME \ 23 | $CHUNKS 24 | wait 25 | bash Evaluation/zeroshotqa/scripts/zeroshotqa_eval.sh \ 26 | $OUTPUT_DIR/zeroshotqa \ 27 | $TASKNAME 28 | wait 29 | done 30 | -------------------------------------------------------------------------------- /PREPARE_DATASET.md: -------------------------------------------------------------------------------- 1 | # Preparing Training & Evaluation Dataset 2 | ## 🗃️ Training Dataset 3 | - **Note** Our Dataset is built upon four sources of datasets. 4 | 1. [Video-ChatGPT Video Instruction Dataset](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/docs/train_video_chatgpt.md) 5 | - ActivityNet, WebVid videos 6 | - 100K instructions 7 | 2. [Video Localized Narratives Dataset](https://github.com/google/video-localized-narratives/blob/main/data_preparation.md) 8 | 3. [How2QA](https://github.com/ych133/How2R-and-How2QA) 9 | 4. [NextQA](https://doc-doc.github.io/docs/nextqa.html) 10 | 5. [WebVid](https://github.com/m-bain/webvid) 11 | 12 |   13 | - 📜 **Instructions**: Download all of our video instructions from 🤗 [SNUMPR/vlm_rlaif_datasets](https://huggingface.co/SNUMPR/vlm_rlaif_datasets) 14 | | Dataset Usage | Filename | Source of Videos | 15 | |----------|---------------|---------------| 16 | | SFT (short) | SFT_short.json | All | 17 | | SFT (long) | SFT_long.json | All | 18 | | Preference dataset (RM) | RM_13b_v1_dataset_39k.json | ANet | 19 | | PPO init | PPO_init.json | ANet | 20 | | RLAIF | RL_data.json | ANet | 21 | 22 |   23 | 24 | - 🎥 **Videos**: Download source videos following the instructions below, and then extract 50 frames per each video to train the model. 25 | 26 | 1. **Video-ChatGPT Instruction Dataset - ActivityNet videos**: 27 | - **Frames** (🤗 [SNUMPR/vlm_rlaif_train_anet_frames](https://huggingface.co/datasets/SNUMPR/vlm_rlaif_train_anet_frames)): Our version of preprocessed videos, extracted 50 frames per each video 28 | - [Videos](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/hanoona_bangalath_mbzuai_ac_ae/EatOpE7j68tLm2XAd0u6b8ABGGdVAwLMN6rqlDGM_DwhVA?e=90WIuW): video as mp4 format from original paper 29 | 2. Video Localized Narratives Dataset 30 | - See [download instructions](https://github.com/google/video-localized-narratives/blob/main/data_preparation.md) for the original dataset 31 | - Download source video from four datasets; OoPs, OVIS, kinetics400 (UVO), kinetics 32 | - Extract 50 frames per each video into `OOPs_50frames`, `OVIS_50frames`, `kinetics400_50frames`, `kinetics_50frames` 33 | 3. How2QA 34 | - See [download instructions](https://github.com/ych133/How2R-and-How2QA) to download videos 35 | - Extract 50 frames per each video into `how2qa_50frames` 36 | 4. 
NeXTQA 37 | - Download [Google Drive](https://drive.google.com/file/d/1jTcRCrVHS66ckOUfWRb-rXdzJ52XAWQH/view) link provided by original authors to download video files 38 | - Extract 50 frames per each video into `nextqa_50frames` 39 | 40 | 5. WebVid 41 | - Follow the official [WebVid dataset](https://github.com/m-bain/webvid) README to download the videos. 42 | - Extract 50 frames per each video into `webvid_50frames` 43 | 44 | 45 | ```Shell 46 | # 📁 Training data folder structure 47 | TRAIN_DATA_ROOT # (playground/data/train_dataset in default) 48 | ├── instructions 49 | └── videos 50 | ├── anet_vidchatgpt_50frames 51 | ├── OOPs_50frames 52 | ├── OVIS_50frames 53 | ├── kinetics400_50frames 54 | ├── kinetics_50frames 55 | ├── how2qa_50frames 56 | ├── nextqa_50frames 57 | └── webvid_50frames 58 | ``` 59 | 60 | ```plain text 61 | // Example structure 62 | { 63 | 'id': 'sampleid', 64 | 'src_data': 'original data source', 65 | 'conversations': [ 66 | {'role': 'human', 'value': ''}, 67 | {'role': 'gpt', 'value': ''} 68 | ] 69 | 'images': [ 70 | 'video_dir/image_01.jpg', 71 | 'video_dir/image_02.jpg', 72 | ... 73 | ] 74 | } 75 | 76 | ``` 77 |   78 |   79 | 80 | ## 🗃️ Evaluation Dataset 81 | 82 | 83 | ```Shell 84 | # 📁 Evaluation folder structure 85 | EVAL_DATA_ROOT # (playground/data/eval_dataset in default) 86 | ├── zeroshotqa 87 | │ ├── annotations 88 | │ └── frames 89 | │ ├── anet 90 | │ ├── msvd 91 | │ └── msrvtt 92 | └── videogenerativebench 93 | ├── annotations 94 | └── frames 95 | ``` 96 | ### Zero-shot QA 97 | 98 | - 🤗 [**SNUMPR/vlm_rlaif_eval_datasets**](https://huggingface.co/datasets/SNUMPR/vlm_rlaif_eval_datasets/tree/main/zeroshotqa) Download our preprocessed zero-shot QA benchmark from this link. 99 | - For original videos and test split, follow instructions from [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA/blob/main/TRAIN_AND_VALIDATE.md) to download the Zero-shot QA dataset. 100 | 101 | 102 | ### Video Generative Benchmark 103 | - Download evaluation dataset & videos for zero-shot question answering from [Video-ChatGPT Qualitative Evaluation](https://github.com/mbzuai-oryx/Video-ChatGPT/blob/main/quantitative_evaluation/README.md). 104 | - [Videos](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/hanoona_bangalath_mbzuai_ac_ae/EatOpE7j68tLm2XAd0u6b8ABGGdVAwLMN6rqlDGM_DwhVA?e=90WIuW), [Descriptions](https://mbzuaiac-my.sharepoint.com/:u:/g/personal/hanoona_bangalath_mbzuai_ac_ae/EYqblLdszspJkayPvVIm5s0BCvl0m6q6B-ipmrNg-pqn6A?e=QFzc1U) 105 | - Extract 50 frames per each videos 106 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🤖 VLM-RLAIF (ACL'24 Oral) 2 | > [**Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback**](https://dcahn12.github.io/projects/vlm-rlaif/), 3 | [Daechul Ahn](https://dcahn12.github.io)1,3, 4 | [Yura Choi](https://yuuraa.github.io)1,3, 5 | [Youngjae Yu](https://yj-yu.github.io/home/)1, 6 | [Dongyeop Kang](https://dykang.github.io)2, 7 | [Jonghyun Choi](https://ppolon.github.io)3,†
8 | 1Yonsei University, 9 | 2University of Minnesota, 10 | 3Seoul National University
11 | Corresponding Author
12 | [ACL 2024](https://2024.aclweb.org) (To appear) 13 | 14 | [![model-checkpoint](https://img.shields.io/badge/Model-RLAIF-blue)](https://huggingface.co/SNUMPR/vlm_rlaif_video_llava_7b) 15 | [![model-checkpoint-sft](https://img.shields.io/badge/Model-SFT-blue)](https://huggingface.co/SNUMPR/vlm_sft_video_llava_7b) 16 | [![paper](https://img.shields.io/badge/Paper-Arxiv-green)](https://arxiv.org/abs/2402.03746) 17 | 18 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/tuning-large-multimodal-models-for-videos/video-based-generative-performance)](https://paperswithcode.com/sota/video-based-generative-performance?p=tuning-large-multimodal-models-for-videos) 19 | 20 |   21 | 22 | ## 📣 News 23 | - [Aug 07, 2024] We update our trained lora checkpoint of reward model & policy model initialization to Hugginface 24 | - [Aug 06, 2024] Our model is available in HuggingFace Spaces! 25 | - [Jul 16, 2024] 🎙️ **VLM-RLAIF** has been selected for ✨***oral presentation***✨ at **ACL 2024**! See you in Bangkok 🇹🇭 26 | - [Jun 16, 2024] 🔥 Our next work on aligning large video multimodal model, **i-SRT**🚄, is now available 27 | [[arXiv](https://arxiv.org/pdf/2406.11280v1), [code](https://github.com/snumprlab/srt)] 28 | - [May 31, 2024] 🥳 **VLM-RLAIF** is accepted to **ACL 2024** ! 29 | 30 |   31 | 32 | 33 | ## 👀 Overview 34 | 38 | > **Abstract:** *Recent advancements in large language models have influenced the development of video large multimodal models (VLMMs). Previous approaches for VLMMs involve Supervised Fine-Tuning (SFT) with instruction-tuned datasets, integrating LLM with visual encoders, and additional learnable parameters. Here, aligning video with text, and vice versa, remains a challenge, primarily due to the insufficient quality and quantity of multimodal instruction-tune data compared to that of text-only. This discrepancy often results in alignments that poorly ground the video content. To address this, we present a novel alignment strategy that employs a multimodal AI system equipped with Reinforcement Learning from AI Feedback (RLAIF), providing self-preference feedback to refine itself and facilitating the alignment of video and text modalities. Our approach uniquely integrates detailed video descriptions as context into a multimodal AI system during preference feedback generation to enrich the understanding of video content, a process we call context-aware reward modeling. Empirical evaluations on various video benchmarks demonstrate that our VLM-RLAIF outperforms existing approaches, including the SFT model.* 39 | 40 |
41 | 42 |

Pipeline of VLM-RLAIF

43 |
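As a concrete illustration of the context-aware reward modeling described above, the minimal sketch below shows how detailed video captions can be folded into the AI judge's prompt so that its preference feedback stays grounded in the video content. It is illustrative only; the actual prompts and training code live under [`RLAIF/`](./RLAIF) (e.g. `RLAIF/prompts/fact_rlaif_reward_prompt_video.txt`). The `FACTUAL_PROMPT` string is copied from `RLAIF/data_utils/constants.py`, while the question, captions, and candidate responses are made-up examples.

```python
# Illustrative sketch only: not the training implementation in RLAIF/.
# FACTUAL_PROMPT is copied from RLAIF/data_utils/constants.py; everything else is a made-up example.
FACTUAL_PROMPT = "Specifically, the AI's response should be fully supported by the combination of the following captions:\n"


def build_judge_prompt(question, captions, response_a, response_b):
    """Assemble a context-aware preference prompt: the AI judge sees detailed video
    captions alongside the question and the two candidate responses."""
    caption_block = "".join(f"- {caption}\n" for caption in captions)
    return (
        f"Question about the video: {question}\n\n"
        + FACTUAL_PROMPT
        + caption_block
        + f"\nResponse A:\n{response_a}\n"
        + f"\nResponse B:\n{response_b}\n"
        + "\nWhich response is better supported by the captions? Answer with 'A' or 'B'."
    )


if __name__ == "__main__":
    prompt = build_judge_prompt(
        question="What is the man in the video doing?",
        captions=[
            "A man irons clothes on a board attached to the back of a moving taxi.",
            "Cars pass by on a busy city street.",
        ],
        response_a="The man is ironing clothes while riding on the back of a taxi in traffic.",
        response_b="The man is cooking a meal in a kitchen.",
    )
    print(prompt)  # This prompt would be sent to the AI judge to obtain a preference label.
```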
44 | 45 | 46 | ## 🗃️ Dataset and Checkpoints 47 | > Check [PREPARE_DATASET.md](./PREPARE_DATASET.md) to prepare training & validation datasets 48 | 49 | | Model | Size | Checkpoint | corr. | detail. | context | temp. | const. | 50 | |----------|----------|-----------|---|---|---|---|---| 51 | | RLAIF | 7B | [SNUMPR/vlm_rlaif_video_llava_7b](https://huggingface.co/SNUMPR/vlm_rlaif_video_llava_7b)| 3.63 | 3.25 | 4.00 | 3.23 | 3.32 | 52 | | SFT | 7B | [SNUMPR/vlm_sft_video_llava_7b](https://huggingface.co/SNUMPR/vlm_sft_video_llava_7b) | 2.79 | 2.82 | 3.37 | 2.28 | 2.49 | 53 | 54 | Lora Checkpoints (used to train the model w/ PPO) 55 | | Model | Size | Lora Checkpoint | 56 | |----------|----------|-----------| 57 | | Policy init | 7B | [SNUMPR/vlm_policy_init_7b_lora](https://huggingface.co/SNUMPR/vlm_policy_init_7b_lora) | 58 | | Reward model | 13B | [SNUMPR/vlm_rm_13b_lora](https://huggingface.co/SNUMPR/vlm_rm_13b_lora) | 59 | 60 |   61 | 62 | | Dataset Usage | Link | 63 | |----------|----------| 64 | | SFT (short) | [SNUMPR/vlm_rlaif_datasets/SFT_short.json](https://huggingface.co/datasets/SNUMPR/vlm_rlaif_datasets/blob/main/SFT_short.json) | 65 | | SFT (long) | [SNUMPR/vlm_rlaif_datasets/SFT_long.json](https://huggingface.co/datasets/SNUMPR/vlm_rlaif_datasets/blob/main/SFT_long.json) | 66 | | Preference dataset (for RM) | [SNUMPR/vlm_rlaif_datasets/RM_13b_v1_dataset_39k.json](https://huggingface.co/datasets/SNUMPR/vlm_rlaif_datasets/blob/main/RM_13b_v1_dataset_39k.json) | 67 | | PPO init | [SNUMPR/vlm_rlaif_datasets/PPO_init.json](https://huggingface.co/datasets/SNUMPR/vlm_rlaif_datasets/blob/main/PPO_init.json) | 68 | | RLAIF | [SNUMPR/vlm_rlaif_datasets/RL_data.json](https://huggingface.co/datasets/SNUMPR/vlm_rlaif_datasets/blob/main/RL_data.json) | 69 | 70 |   71 | 72 | 73 | ## 📊 Evaluation 74 | > Check [PREPARE_DATASET.md](./PREPARE_DATASET.md) to prepare training & validation datasets 75 | - **Zero-shot QA** 76 | ```bash 77 | bash Evaluation/zeroshotqa/scripts/zeroshotqa_pipeline.sh 78 | ``` 79 | - **Video Generative Benchmark** 80 | ```bash 81 | bash Evaluation/scripts/videochatgpt_pipeline.sh 82 | ``` 83 |   84 | 85 | 86 | ## 💻 Training w/ RLAIF 87 | - Refer to the [RLAIF](./RLAIF) folder to train reward model, policy model, and do PPO 88 | 89 |   90 | 91 | ## 🔧 Data Generation 92 | **Available Soon** 93 | 94 |   95 | 96 | 97 | 98 | ## 📚 Citation 99 | ``` 100 | @inproceedings{ahnCYKC24, 101 | author = {Daechul Ahn and Yura Choi and Youngjae Yu and Dongyeop Kang and Jonghyun Choi}, 102 | title = {Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback}, 103 | booktitle = {ACL}, 104 | year = {2024} 105 | } 106 | ``` 107 |   108 | 109 | ## License 110 | - The majority of this project is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file. 
111 | - The service is a research preview intended for non-commercial use only, subject to the model license of LLaMA. 112 |   113 | 114 | ## Acknowledgement 115 | - [LLaVA](https://github.com/haotian-liu/LLaVA.git) 116 | - [LLaVA-RLHF](https://github.com/llava-rlhf/LLaVA-RLHF.git) 117 | - [VideoChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT.git) 118 | - [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA.git) 119 | -------------------------------------------------------------------------------- /RLAIF/README.md: -------------------------------------------------------------------------------- 1 | # RL from AI Feedback 2 | 3 | This RLAIF codebase is mainly adapted from the [LLaVA-RLHF](https://github.com/llava-rlhf/LLaVA-RLHF.git) codebase, which is adapted from the [SALMON](https://github.com/Edward-Sun/SALMON) codebase. 4 | 5 | ## 0. Setup 6 | 7 | Please refer to [`llava_setup`](../llava_setup) for instructions on how to set up the customized llava package. 8 | 9 | Additionally, you **should** run the following commands to make sure the versions of some essential packages are correct: 10 | 11 | ```bash 12 | pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118 13 | pip install deepspeed==0.9.5 14 | pip install peft==0.4.0 15 | pip install transformers==4.34.0 16 | pip install bitsandbytes==0.41.0 17 | pip install datasets 18 | ``` 19 | 20 | We use the following **SFT checkpoints** on Hugging Face to initialize the RM and the policy model. Follow steps 1 and 2 to run the whole PPO pipeline starting from the SFT models. 21 | - [SNUMPR/vlm_sft_video_llava_13b](https://huggingface.co/SNUMPR/vlm_sft_video_llava_13b) -> initialize RM 22 | - [SNUMPR/vlm_sft_video_llava_7b](https://huggingface.co/SNUMPR/vlm_sft_video_llava_7b) -> initialize Policy model 23 | 24 | Alternatively, you can use our LoRA weights of the trained RM and policy model and proceed directly to step 3. 25 | - [SNUMPR/vlm_rm_13b_lora](https://huggingface.co/SNUMPR/vlm_rm_13b_lora) 26 | - [SNUMPR/vlm_policy_init_7b_lora](https://huggingface.co/SNUMPR/vlm_policy_init_7b_lora) 27 | 28 | 29 | 30 | ## 1. Train the Reward Model 31 | **Note**: For both 7B and 13B policy models, we use the same 13B reward model. 32 | ```bash 33 | bash RLAIF/scripts/train_reward_model.sh \ 34 | openai/clip-vit-large-patch14-336 \ # vision tower init 35 | dataset/videos \ # path to videos 36 | dataset/RM_13b_v1_dataset_39k.json \ # training pref data 37 | dataset/RM_13b_v1_dataset_39k.json \ # validation pref data 38 | SNUMPR/vlm_sft_video_llava_13b \ # sft model 39 | checkpoints/Video_LLaVA_RM_13b_lora \ # path to save trained reward model 40 | ``` 41 | 42 | 43 | ## 2. Initialize the Policy Model 44 | ```bash 45 | bash RLAIF/scripts/initialize_policy_model.sh \ 46 | openai/clip-vit-large-patch14-336 \ # vision tower init 47 | dataset/videos \ # path to videos 48 | dataset/PPO_init.json \ # training data 49 | SNUMPR/vlm_sft_video_llava_7b \ # sft model 50 | checkpoints/Video_LLaVA_Policy_Init_7b_lora \ # path to save policy model init 51 | ``` 52 | 53 | ## 3.
Train the RL Model with PPO 54 | #### Using your trained model 55 | ```bash 56 | bash RLAIF/scripts/train_rl_model.sh \ 57 | openai/clip-vit-large-patch14-336 \ # vision tower init 58 | dataset/videos \ # path to videos 59 | dataset/RL_data.json \ # training data 60 | SNUMPR/vlm_sft_video_llava_7b \ # sft model 61 | checkpoints/Video_LLaVA_Policy_Init_7b_lora \ # path to trained policy model lora 62 | checkpoints/Video_LLaVA_RM_13b_lora \ # path to trained reward model lora 63 | checkpoints/Video_LLaVA_RLAIF_7b \ # path to save ppo model 64 | True \ # use latest checkpoint of reward model 65 | ``` 66 | 67 | #### Using provided RM and policy init 68 | 1. Download lora weights for RM & initialized Policy model 69 | ```bash 70 | cd checkpoints 71 | git clone https://huggingface.co/SNUMPR/vlm_policy_init_7b_lora # Clone policy init lora 72 | git clone https://huggingface.co/SNUMPR/vlm_rm_video_llava_13b_lora # Clone reward model lora 73 | ``` 74 | 2. Train w/ PPO 75 | ```bash 76 | bash RLAIF/scripts/train_rl_model.sh \ 77 | openai/clip-vit-large-patch14-336 \ # vision tower init 78 | dataset/videos \ # path to videos 79 | dataset/RL_data.json \ # training data 80 | SNUMPR/vlm_sft_video_llava_7b \ # sft model 81 | checkpoints/vlm_policy_init_7b_lora \ # path to trained policy model lora 82 | checkpoints/vlm_rm_video_llava_13b_lora \ # path to trained reward model lora 83 | checkpoints/Video_LLaVA_RLAIF_7b \ # path to save ppo model 84 | False \ # use provided checkpoint of reward model 85 | ``` -------------------------------------------------------------------------------- /RLAIF/data_utils/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The LLaVA-RLHF Team 2 | # Copyright 2023 The Alpaca Team 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # PPO Constants 17 | from enum import Enum 18 | 19 | FACTUAL_PROMPT = "Specifically, the AI's response should be fully supported by the combination of the following captions:\n" 20 | 21 | class AnswerType(Enum): 22 | GENERAL = 1 23 | A_IN_ABCD = 2 24 | B_IN_ABCD = 3 25 | C_IN_ABCD = 4 26 | D_IN_ABCD = 5 27 | NO_IN_YESNO = 6 28 | YES_IN_YESNO = 7 29 | -------------------------------------------------------------------------------- /RLAIF/lora_utils.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the MIT license found in the 2 | # LICENSE file in the root directory of this source tree. 
3 | 4 | import glob 5 | import os 6 | from os.path import exists, join, isdir 7 | import shutil 8 | import sys 9 | from typing import Optional, Dict, Sequence, List 10 | 11 | import torch 12 | import transformers 13 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 14 | 15 | from models.reward_model import RewardModel 16 | 17 | DEFAULT_PAD_TOKEN = "[PAD]" 18 | 19 | 20 | class SavePeftModelCallback(transformers.TrainerCallback): 21 | def save_model(self, args, state, kwargs): 22 | print("Saving PEFT checkpoint...") 23 | 24 | global_rank = int(os.environ.get("RANK", 0)) 25 | 26 | if global_rank == 0: 27 | print("Saving model checkpoint to %s" % args.output_dir) 28 | if state.best_model_checkpoint is not None: 29 | checkpoint_folder = state.best_model_checkpoint 30 | else: 31 | checkpoint_folder = os.path.join( 32 | args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}" 33 | ) 34 | 35 | peft_model_path = os.path.join(checkpoint_folder, "adapter_model") 36 | reward_head_path = os.path.join(checkpoint_folder, "reward_head") 37 | 38 | if isinstance(kwargs["model"], RewardModel): 39 | kwargs["model"].backbone_model.save_pretrained(peft_model_path) 40 | torch.save( 41 | kwargs["model"].reward_head.state_dict(), 42 | reward_head_path, 43 | ) 44 | else: 45 | kwargs["model"].save_pretrained(peft_model_path) 46 | 47 | pytorch_model_paths = glob.glob( 48 | os.path.join(checkpoint_folder, "pytorch_model*.bin") 49 | ) 50 | for pytorch_model_path in pytorch_model_paths: 51 | if os.path.exists(pytorch_model_path): 52 | os.remove(pytorch_model_path) 53 | 54 | optimizer_path = os.path.join(checkpoint_folder, "optimizer.pt") 55 | if os.path.exists(optimizer_path): 56 | os.remove(optimizer_path) 57 | 58 | else: 59 | print("Skipping PEFT checkpoint save on rank %d" % global_rank) 60 | 61 | def on_save(self, args, state, control, **kwargs): 62 | self.save_model(args, state, kwargs) 63 | return control 64 | 65 | def on_train_end(self, args, state, control, **kwargs): 66 | def touch(fname, times=None): 67 | global_rank = int(os.environ.get("RANK", 0)) 68 | if global_rank == 0: 69 | with open(fname, "a"): 70 | os.utime(fname, times) 71 | 72 | touch(join(args.output_dir, "completed")) 73 | self.save_model(args, state, kwargs) 74 | 75 | 76 | def print_trainable_parameters(args, model): 77 | """ 78 | Prints the number of trainable parameters in the model. 79 | """ 80 | trainable_params = 0 81 | all_param = 0 82 | for _, param in model.named_parameters(): 83 | all_param += param.numel() 84 | if param.requires_grad: 85 | trainable_params += param.numel() 86 | if args.bits == 4: 87 | trainable_params /= 2 88 | print( 89 | f"trainable params: {trainable_params} || " 90 | f"all params: {all_param} || " 91 | f"trainable: {100 * trainable_params / all_param}" 92 | ) 93 | 94 | 95 | def get_last_checkpoint(checkpoint_dir): 96 | if isdir(checkpoint_dir): 97 | is_completed = exists(join(checkpoint_dir, "completed")) 98 | if is_completed: 99 | return None, True # already finished 100 | max_step = 0 101 | for filename in os.listdir(checkpoint_dir): 102 | if isdir(join(checkpoint_dir, filename)) and filename.startswith( 103 | "checkpoint" 104 | ): 105 | max_step = max(max_step, int(filename.replace("checkpoint-", ""))) 106 | if max_step == 0: 107 | return None, is_completed # training started, but no checkpoint 108 | checkpoint_dir = join(checkpoint_dir, f"checkpoint-{max_step}") 109 | print(f"Found a previous checkpoint at: {checkpoint_dir}") 110 | return checkpoint_dir, is_completed # checkpoint found! 
111 | return None, False # first training 112 | -------------------------------------------------------------------------------- /RLAIF/models/distributed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Alpaca Team 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utilities for PyTorch's distributed training. 16 | 17 | Compatible with torchrun / elastic. 18 | 19 | Internal map: 20 | https://github.com/lxuechen/ml-swissknife/blob/main/ml_swissknife/distributed_utils.py 21 | """ 22 | 23 | import os 24 | import sys 25 | from typing import Optional 26 | 27 | import torch 28 | import torch.distributed as dist 29 | 30 | 31 | def setup(rank: Optional[int] = None, world_size: Optional[int] = None): 32 | if rank is None: 33 | rank = get_rank() 34 | if world_size is None: 35 | world_size = get_world_size() 36 | 37 | if world_size <= 1: 38 | return rank, world_size 39 | 40 | if not dist.is_initialized(): 41 | if sys.platform == "win32": 42 | # Distributed package only covers collective communications with Gloo 43 | # backend and FileStore on Windows platform. Set init_method parameter 44 | # in init_process_group to a local file. 45 | # Example init_method="file:///f:/libtmp/some_file" 46 | init_method = "file:///f:/libtmp/dist-tmp" 47 | dist.init_process_group( 48 | backend="gloo", 49 | init_method=init_method, 50 | rank=rank, 51 | world_size=world_size, 52 | ) 53 | elif torch.cuda.is_available(): 54 | dist.init_process_group(backend="nccl", rank=rank, world_size=world_size) 55 | else: 56 | dist.init_process_group(backend="gloo", rank=rank, world_size=world_size) 57 | 58 | return rank, world_size 59 | 60 | 61 | def cleanup(): 62 | dist.destroy_process_group() 63 | 64 | 65 | def get_rank(): 66 | return int(os.getenv("RANK", 0)) 67 | 68 | 69 | def get_local_rank(): 70 | return int(os.getenv("LOCAL_RANK", 0)) 71 | 72 | 73 | def get_world_size(): 74 | return int(os.getenv("WORLD_SIZE", 1)) 75 | 76 | 77 | def should_save(): 78 | """Return True if the current process is the main process.""" 79 | return get_rank() <= 0 80 | 81 | 82 | def all_gather_and_cat(tensor: torch.Tensor, dim=0): 83 | if get_world_size() > 1: 84 | tensor_list = [torch.empty_like(tensor) for _ in range(get_world_size())] 85 | dist.all_gather(tensor_list, tensor) 86 | tensor = torch.cat(tensor_list, dim=dim) 87 | return tensor 88 | 89 | 90 | is_main_process = should_save 91 | -------------------------------------------------------------------------------- /RLAIF/models/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The Alpaca Team 2 | # Copyright 2022 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Optional 17 | 18 | from torch import nn, optim 19 | from transformers import Trainer 20 | from transformers.optimization import get_scheduler 21 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 22 | from transformers.trainer_pt_utils import get_parameter_names 23 | 24 | 25 | def create_optimizer( 26 | args, model: nn.Module, optimizer: Optional[optim.Optimizer] = None 27 | ): 28 | """Create optimizer for trainer. 29 | 30 | This is detached version of the `Trainer.create_optimizer` method. 31 | We don't support sagemaker and fairscale for simplicity. 32 | 33 | Reference: 34 | https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py 35 | """ 36 | opt_model = model 37 | 38 | if optimizer is None: 39 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 40 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 41 | optimizer_grouped_parameters = [ 42 | { 43 | "params": [ 44 | p 45 | for n, p in opt_model.named_parameters() 46 | if (n in decay_parameters and p.requires_grad) 47 | ], 48 | "weight_decay": args.weight_decay, 49 | }, 50 | { 51 | "params": [ 52 | p 53 | for n, p in opt_model.named_parameters() 54 | if (n not in decay_parameters and p.requires_grad) 55 | ], 56 | "weight_decay": 0.0, 57 | }, 58 | ] 59 | 60 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args) 61 | 62 | optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 63 | return optimizer 64 | 65 | 66 | def create_scheduler(args, optimizer, lr_scheduler, num_training_steps): 67 | """Create scheduler for trainer. 68 | 69 | This is detached version of the `Trainer.create_scheduler` method. 70 | 71 | Reference: 72 | https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py 73 | """ 74 | if lr_scheduler is None: 75 | lr_scheduler = get_scheduler( 76 | args.lr_scheduler_type, 77 | optimizer=optimizer, 78 | num_warmup_steps=args.get_warmup_steps(num_training_steps), 79 | num_training_steps=num_training_steps, 80 | ) 81 | return lr_scheduler 82 | -------------------------------------------------------------------------------- /RLAIF/prompts/fact_rlaif_reward_prompt_video.txt: -------------------------------------------------------------------------------- 1 | USER: Please evaluate the quality of your last response. There are several dimensions you should consider in your evaluation: 2 | 3 | 1. Accurate: The AI should provide factual and accurate information from the video, and refrain from making statements that are not supported by the video or inconsistent with the video. {factual_prompt} 4 | 2. Helpful: The AI’s response should precisely serve the user's needs and interests, while grounding the response in the video. 5 | 3. Language Natural: The AI should employ language that flows smoothly and is free from repetitive or awkward constructs. 6 | 4. Concise: The AI should efficiently address the task or answer the question, communicating the necessary information with brevity and clarity. 
7 | 8 | A good response should be accurate, helpful, language natural, and concise. ASSISTANT: Following your definitions, the quality score of my last response is -------------------------------------------------------------------------------- /RLAIF/scripts/initialize_policy_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 7 | export PYTHONPATH="$PWD:$PYTHONPATH" 8 | export GPUS_PER_NODE=8 9 | export OMP_NUM_THREADS=8 10 | 11 | # MODEL CONFIG 12 | VISION_TOWER=$1 13 | VIDEO_PATH=$2 14 | DATA_PATH=$3 15 | BASE_MODEL_PATH=$4 16 | POLICY_SAVE_PATH=$5 17 | 18 | # TRAINING CONFIG 19 | NUM_EPOCHS=1 20 | LEARNING_RATE=1e-4 21 | BATCH_SIZE=8 22 | GRAD_ACCUMULATION=2 23 | 24 | deepspeed \ 25 | finetune_policy_init.py \ 26 | --deepspeed scripts/zero2.json \ 27 | --do_train \ 28 | --do_eval \ 29 | --seed 42 \ 30 | --per_device_train_batch_size $BATCH_SIZE \ 31 | --per_device_eval_batch_size 8 \ 32 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 33 | --model_name_or_path $BASE_MODEL_PATH \ 34 | --image_folder $VIDEO_PATH \ 35 | --vision_tower $VISION_TOWER \ 36 | --learning_rate $LEARNING_RATE \ 37 | --mm_vision_select_layer -2 \ 38 | --mm_use_im_start_end False \ 39 | --mm_use_im_patch_token False \ 40 | --freeze_mm_mlp_adapter True \ 41 | --query_len 1280 \ 42 | --response_len 768 \ 43 | --dataset $DATA_PATH \ 44 | --dataset_format "v1" \ 45 | --eval_size 500 \ 46 | --bits 16 \ 47 | --lora_r 64 \ 48 | --lora_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 49 | --output_dir $POLICY_SAVE_PATH \ 50 | --num_train_epochs $NUM_EPOCHS \ 51 | --group_by_length False \ 52 | --evaluation_strategy "steps" \ 53 | --eval_steps 50 \ 54 | --save_strategy "steps" \ 55 | --save_steps 1000000 \ 56 | --save_total_limit 1 \ 57 | --weight_decay 0.0 \ 58 | --warmup_ratio 0.03 \ 59 | --lr_scheduler_type "cosine" \ 60 | --logging_steps 5 \ 61 | --report_to "tensorboard" \ 62 | --ddp_backend "nccl" \ 63 | --bf16 True \ 64 | --ddp_find_unused_parameters False \ 65 | --resume_from_training True \ 66 | --image_aspect_ratio 'pad' 67 | -------------------------------------------------------------------------------- /RLAIF/scripts/parse_largest_ckptname.sh: -------------------------------------------------------------------------------- 1 | 2 | MODEL_DIR="/dataset/llms/LLaVA_RLHF/LLaVA_Video-RLHF/pretrained" 3 | SFT_MODEL_NAME=llava-v1.5-7b-lora_w_lora_16_sftv2_short1632_and_then_long_rank32_alpha32_lr1e4 4 | CKPTS_DIR=$SFT_MODEL_NAME"_allmodels" 5 | RM_LORA_PATH=$CKPTS_DIR/RM_v2data 6 | 7 | 8 | largest_number=0 9 | largest_directory="" 10 | 11 | echo $MODEL_DIR/$RM_LORA_PATH/ 12 | for directory in $MODEL_DIR/$RM_LORA_PATH/checkpoint-*; do 13 | # Check if the entry is a directory 14 | if [ -d "$directory" ]; then 15 | # Extract the number from the directory name 16 | # number=$(basename "$directory" | sed 's/checkpoint-//') 17 | number=$(basename "$directory" | sed 's/[^0-9]*//g') 18 | 19 | # Compare the number with the largest number found so far 20 | if [ "$number" -gt "$largest_number" ]; then 21 | largest_number="$number" 22 | # largest_directory="$directory" 23 | largest_directory=$(basename "$directory") 24 | fi 25 | fi 26 | done 27 | 28 | 29 | echo "Largest Directory: $largest_directory" -------------------------------------------------------------------------------- /RLAIF/scripts/train_reward_model.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 7 | export PYTHONPATH="$PWD:$PYTHONPATH" 8 | export GPUS_PER_NODE=8 9 | export OMP_NUM_THREADS=8 10 | 11 | VISION_TOWER=$1 12 | VIDEO_PATH=$2 13 | PREFERENCE_DATA_PATH=$3 14 | PREFERENCE_EVAL_DATA_PATH=$4 15 | SFT_MODEL_PATH=$5 16 | RM_SAVE_PATH=$6 17 | 18 | 19 | # TRAINING CONFIG 20 | NUM_EPOCHS=1 21 | LEARNING_RATE=2e-5 22 | BATCH_SIZE=2 23 | GRAD_ACCUMULATION=1 24 | 25 | torchrun \ 26 | --standalone \ 27 | --nnodes=1 \ 28 | --nproc-per-node=$GPUS_PER_NODE \ 29 | RLHF/finetune_lora_rm.py \ 30 | --do_train \ 31 | --do_eval \ 32 | --seed 42 \ 33 | --per_device_train_batch_size $BATCH_SIZE \ 34 | --per_device_eval_batch_size $BATCH_SIZE \ 35 | --gradient_accumulation_steps $GRAD_ACCUMULATION \ 36 | --model_name_or_path $SFT_MODEL_PATH \ 37 | --image_folder $VIDEO_PATH/ \ 38 | --vision_tower $VISION_TOWER \ 39 | --learning_rate $LEARNING_RATE \ 40 | --mm_vision_select_layer -2 \ 41 | --mm_use_im_start_end False \ 42 | --mm_use_im_patch_token False \ 43 | --freeze_mm_mlp_adapter True \ 44 | --model_max_length 2048 \ 45 | --query_len 1280 \ 46 | --response_len 768 \ 47 | --dataset_path $PREFERENCE_DATA_PATH \ 48 | --eval_dataset_path $PREFERENCE_EVAL_DATA_PATH \ 49 | --dataset_name "none" \ 50 | --eval_dataset_name "none" \ 51 | --eval_size 500 \ 52 | --bits 16 \ 53 | --lora_r 64 \ 54 | --lora_modules q_proj k_proj v_proj o_proj gate_proj up_proj down_proj \ 55 | --output_dir $RM_SAVE_PATH \ 56 | --num_train_epochs $NUM_EPOCHS \ 57 | --group_by_length False \ 58 | --evaluation_strategy "steps" \ 59 | --eval_steps 400 \ 60 | --save_strategy "steps" \ 61 | --save_steps 400 \ 62 | --save_total_limit 5 \ 63 | --weight_decay 0.0 \ 64 | --warmup_ratio 0.03 \ 65 | --lr_scheduler_type "constant_with_warmup" \ 66 | --logging_steps 10 \ 67 | --report_to "tensorboard" \ 68 | --ddp_backend "nccl" \ 69 | --bf16 True \ 70 | --ddp_find_unused_parameters False \ 71 | --resume_from_training True \ 72 | --reward_prompt_file "./prompts/fact_rlaif_reward_prompt_video.txt" \ 73 | --image_aspect_ratio 'pad' 74 | -------------------------------------------------------------------------------- /RLAIF/scripts/train_rl_model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | set -x 5 | 6 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 7 | export MODEL_DIR="/dataset/llms/LLaVA_RLHF/LLaVA_Video-RLHF/pretrained" 8 | export PYTHONPATH="$PWD:$PYTHONPATH" 9 | export GPUS_PER_NODE=8 10 | export OMP_NUM_THREADS=8 11 | export TRANSFORMERS_OFFLINE=1 12 | 13 | # ====================================== CHANGE HERE ====================================== 14 | VISION_TOWER=$1 15 | VIDEO_PATH=$2 16 | DATA_PATH=$3 17 | SFT_MODEL_PATH=$4 18 | POLICY_INIT_PATH=$5 19 | RM_MODEL_PATH=$6 20 | RLHF_SAVE_PATH=$7 21 | CALC_RM_CKPT=$8 22 | 23 | # TRAINING CONFIG 24 | LEARNING_RATE=3e-5 25 | KL_COEF=0.1 26 | EPOCH=1 27 | ROLLOUT_BATCH_SIZE=256 28 | STEP_BATCH_SZIE=128 29 | ROLLOUT_PER_DEVICE_BATCH_SIZE=16 30 | REWARD_MODEL_PER_DEVICE_BATCH_SIZE=8 31 | STEP_PER_DEVICE_BATCH_SIZE=8 32 | NOPTEPOCHS=2 33 | 34 | # FACT-RLHF CONFIG 35 | INCOMPLETE_RESPONSE=-8.0 36 | LENGTH_BONUS=-10.0 37 | CORRECT_BONUS=2.0 38 | # ========================================================================================== 39 | 40 | # Get Largest RM PATH or use provided checkpoint 41 | if [ "$CALC_RM_CKPT" = true ]; then 42 | 
largest_number=0 43 | RM_CKPT_NAME="" 44 | for directory in $RM_MODEL_PATH/checkpoint-*; do 45 | # Check if the entry is a directory 46 | if [ -d "$directory" ]; then 47 | # Extract the number from the directory name 48 | number=$(basename "$directory" | sed 's/[^0-9]*//g') 49 | # Compare the number with the largest number found so far 50 | if [ "$number" -gt "$largest_number" ]; then 51 | largest_number="$number" 52 | # largest_directory="$directory" 53 | RM_CKPT_NAME=$(basename "$directory") 54 | fi 55 | fi 56 | done 57 | RM_MODEL_PATH=$RM_MODEL_PATH/$RM_CKPT_NAME 58 | else 59 | RM_MODEL_PATH=$RM_MODEL_PATH 60 | fi 61 | 62 | 63 | torchrun \ 64 | --standalone \ 65 | --nnodes=1 \ 66 | --nproc-per-node=$GPUS_PER_NODE \ 67 | finetune_lora_ppo.py \ 68 | --do_train \ 69 | --seed 42 \ 70 | --step_batch_size $STEP_BATCH_SZIE \ 71 | --step_per_device_batch_size $STEP_PER_DEVICE_BATCH_SIZE \ 72 | --rollout_batch_size $ROLLOUT_BATCH_SIZE \ 73 | --rollout_per_device_batch_size $ROLLOUT_PER_DEVICE_BATCH_SIZE \ 74 | --reward_model_per_device_batch_size $REWARD_MODEL_PER_DEVICE_BATCH_SIZE \ 75 | --base_model_name $SFT_MODEL_PATH \ 76 | --policy_model_name_or_path $POLICY_INIT_PATH \ 77 | --reward_model_name_or_path $RM_MODEL_PATH \ 78 | --learning_rate $LEARNING_RATE \ 79 | --init_value_with_reward True \ 80 | --warmup_steps 5 \ 81 | --dataset_path $DATA_PATH \ 82 | --train_splits "train" \ 83 | --output_dir $RLHF_SAVE_PATH \ 84 | --total_epochs $EPOCH \ 85 | --group_by_length False \ 86 | --evaluation_strategy "no" \ 87 | --save_strategy "steps" \ 88 | --save_steps 10 \ 89 | --save_total_limit 100000 \ 90 | --weight_decay 0.0 \ 91 | --lr_scheduler_type "cosine" \ 92 | --logging_steps 1 \ 93 | --report_to "tensorboard" \ 94 | --ddp_backend "nccl" \ 95 | --bf16 True \ 96 | --penalty_reward_value $INCOMPLETE_RESPONSE \ 97 | --length_bonus_score $LENGTH_BONUS \ 98 | --correct_bonus_score $CORRECT_BONUS \ 99 | --relative_stop_token_penalty True \ 100 | --penalize_no_stop_token True \ 101 | --ddp_find_unused_parameters False \ 102 | --resume_from_training True \ 103 | --kl_coef $KL_COEF \ 104 | --max_grad_norm 1.0 \ 105 | --whitening_async_stats "full_batch" \ 106 | --clean_tokens_after_eos True \ 107 | --temperature 1.0 \ 108 | --whiten_rewards False \ 109 | --model_max_length 2048 \ 110 | --query_len 128 \ 111 | --response_len 896 \ 112 | --noptepochs $NOPTEPOCHS \ 113 | --image_folder $VIDEO_PATH \ 114 | --vision_tower different \ 115 | --mm_vision_select_layer -2 \ 116 | --mm_use_im_start_end False \ 117 | --mm_use_im_patch_token False \ 118 | --freeze_mm_mlp_adapter True \ 119 | --reward_prompt_file "./prompts/fact_rlaif_reward_prompt_video.txt" \ 120 | --image_aspect_ratio 'pad' -------------------------------------------------------------------------------- /RLAIF/scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | "train_micro_batch_size_per_gpu": "auto", 6 | "train_batch_size": "auto", 7 | "gradient_accumulation_steps": "auto", 8 | "zero_optimization": { 9 | "stage": 2, 10 | "overlap_comm": true, 11 | "contiguous_gradients": true, 12 | "sub_group_size": 1e9, 13 | "reduce_bucket_size": "auto" 14 | } 15 | } -------------------------------------------------------------------------------- /RLAIF_DataGen/README.md: -------------------------------------------------------------------------------- 1 | # Preference Data Generation for RLAIF 2 | - Response generation: A / B 3 | - Context generation 4 | -
Preference selection 5 | 6 | [Coming Soon] 7 | -------------------------------------------------------------------------------- /SFT/README.md: -------------------------------------------------------------------------------- 1 | # Supervised Fine-Tuning 2 | - Two-staged supervised fine-tuning 3 | 4 | [Coming Soon] 5 | -------------------------------------------------------------------------------- /assets/images/rlaif_feedback_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/assets/images/rlaif_feedback_teaser.png -------------------------------------------------------------------------------- /llava_setup/.gitignore: -------------------------------------------------------------------------------- 1 | LLaVA -------------------------------------------------------------------------------- /llava_setup/README.md: -------------------------------------------------------------------------------- 1 | # Install LLaVA 2 | 3 | We use LLaVA version `6cea223` for training the SFT and RLAIF models, as in LLaVA-RLHF. 4 | 5 | ## Apply the custom patch 6 | 7 | ```bash 8 | git clone https://github.com/haotian-liu/LLaVA.git 9 | 10 | cd LLaVA 11 | 12 | git reset --hard 6cea223 13 | 14 | git apply < ../fix_llava_padding.patch 15 | ``` 16 | 17 | ## Install LLaVA 18 | 19 | Next, please follow the instructions in the [original repository](https://github.com/haotian-liu/LLaVA/tree/6cea223532a7ab7bda8116336c59772faccdcbca#install) to install LLaVA. 20 | 21 | ## Update Packages 22 | 23 | Finally, please update the following packages: 24 | 25 | ```bash 26 | pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118 27 | pip install deepspeed==0.9.3 28 | pip install peft==0.4.0 29 | pip install transformers==4.31.0 30 | pip install bitsandbytes==0.41.0 31 | pip install datasets 32 | ``` 33 | 34 | **Note:** please install Pytorch 2.0.1 following the guidelines [here](https://pytorch.org/get-started/previous-versions/#v201). We found that the flash-attention implementation in the newest Pytorch Stable (2.1.0) could lead to buggy results. The codebase is tested with `torch==2.0.1+cu118`. 
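A quick way to confirm that the pinned versions are the ones actually active in your environment (a minimal sanity check, not part of the original instructions):

```python
import bitsandbytes, deepspeed, peft, torch, transformers

# Expected per the commands above: torch 2.0.1+cu118, deepspeed 0.9.3,
# peft 0.4.0, transformers 4.31.0, bitsandbytes 0.41.0
for name, module in [("torch", torch), ("deepspeed", deepspeed), ("peft", peft),
                     ("transformers", transformers), ("bitsandbytes", bitsandbytes)]:
    print(f"{name}: {module.__version__}")
```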
35 | -------------------------------------------------------------------------------- /serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yonseivnl/vlm-rlaif/610cac72c16ae3862b0d83a0dc3708d73a42a4cc/serve/__init__.py -------------------------------------------------------------------------------- /serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | 6 | from videollava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, \ 7 | DEFAULT_VIDEO_TOKEN 8 | from videollava.conversation import conv_templates, SeparatorStyle 9 | from videollava.model.builder import load_pretrained_model 10 | from videollava.serve.utils import load_image, image_ext, video_ext 11 | from videollava.utils import disable_torch_init 12 | from videollava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | 16 | import requests 17 | from PIL import Image 18 | from io import BytesIO 19 | from transformers import TextStreamer 20 | 21 | 22 | 23 | 24 | 25 | def main(args): 26 | # Model 27 | disable_torch_init() 28 | 29 | model_name = get_model_name_from_path(args.model_path) 30 | tokenizer, model, processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, 31 | args.load_8bit, args.load_4bit, 32 | device=args.device, cache_dir=args.cache_dir) 33 | image_processor, video_processor = processor['image'], processor['video'] 34 | if 'llama-2' in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ('user', 'assistant') 51 | else: 52 | roles = conv.roles 53 | 54 | tensor = [] 55 | special_token = [] 56 | args.file = args.file if isinstance(args.file, list) else [args.file] 57 | for file in args.file: 58 | if os.path.splitext(file)[-1].lower() in image_ext: 59 | file = image_processor.preprocess(file, return_tensors='pt')['pixel_values'][0].to(model.device, dtype=torch.float16) 60 | special_token += [DEFAULT_IMAGE_TOKEN] 61 | elif os.path.splitext(file)[-1].lower() in video_ext: 62 | file = video_processor(file, return_tensors='pt')['pixel_values'][0].to(model.device, dtype=torch.float16) 63 | special_token += [DEFAULT_IMAGE_TOKEN] * model.get_video_tower().config.num_frames 64 | else: 65 | raise ValueError(f'Support video of {video_ext} and image of {image_ext}, but found {os.path.splitext(file)[-1].lower()}') 66 | print(file.shape) 67 | tensor.append(file) 68 | 69 | 70 | 71 | 72 | while True: 73 | try: 74 | inp = input(f"{roles[0]}: ") 75 | except EOFError: 76 | inp = "" 77 | if not inp: 78 | print("exit...") 79 | break 80 | 81 | print(f"{roles[1]}: ", end="") 82 | 83 | if file is not None: 84 | # first message 85 | if getattr(model.config, "mm_use_im_start_end", False): 86 | inp = ''.join([DEFAULT_IM_START_TOKEN + i + DEFAULT_IM_END_TOKEN for i in 
special_token]) + '\n' + inp 87 | else: 88 | inp = ''.join(special_token) + '\n' + inp 89 | conv.append_message(conv.roles[0], inp) 90 | file = None 91 | else: 92 | # later messages 93 | conv.append_message(conv.roles[0], inp) 94 | conv.append_message(conv.roles[1], None) 95 | prompt = conv.get_prompt() 96 | 97 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 98 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 99 | keywords = [stop_str] 100 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 101 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 102 | 103 | with torch.inference_mode(): 104 | output_ids = model.generate( 105 | input_ids, 106 | images=tensor, # video as fake images 107 | do_sample=True if args.temperature > 0 else False, 108 | temperature=args.temperature, 109 | max_new_tokens=args.max_new_tokens, 110 | streamer=streamer, 111 | use_cache=True, 112 | stopping_criteria=[stopping_criteria]) 113 | 114 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 115 | conv.messages[-1][-1] = outputs 116 | 117 | if args.debug: 118 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument("--model-path", type=str, default="LanguageBind/Video-LLaVA-7B") 124 | parser.add_argument("--model-base", type=str, default=None) 125 | parser.add_argument("--cache-dir", type=str, default=None) 126 | parser.add_argument("--file", nargs='+', type=str, required=True) 127 | parser.add_argument("--device", type=str, default="cuda") 128 | parser.add_argument("--conv-mode", type=str, default=None) 129 | parser.add_argument("--temperature", type=float, default=0.2) 130 | parser.add_argument("--max-new-tokens", type=int, default=512) 131 | parser.add_argument("--load-8bit", action="store_true") 132 | parser.add_argument("--load-4bit", action="store_true") 133 | parser.add_argument("--debug", action="store_true") 134 | args = parser.parse_args() 135 | main(args) 136 | -------------------------------------------------------------------------------- /serve/gradio_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import TextStreamer 3 | 4 | from videollava.constants import IMAGE_TOKEN_INDEX 5 | from videollava.conversation import conv_templates, SeparatorStyle 6 | from videollava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_token 7 | from videollava.model.builder import load_pretrained_model 8 | from videollava.utils import disable_torch_init 9 | 10 | title_markdown = (""" 11 |
12 | 13 | Video-LLaVA🚀 14 | 15 |
16 |

Video-LLaVA: Learning United Visual Representation by Alignment Before Projection

17 |
If you like our project, please give us a star ✨ on Github for the latest update.
18 |
19 |
20 | 21 | 22 |
23 |
24 | 25 | 26 | 27 |
28 |
29 | """) 30 | 31 | block_css = """ 32 | #buttons button { 33 | min-width: min(120px,100%); 34 | } 35 | """ 36 | 37 | tos_markdown = (""" 38 | ### Terms of use 39 | By using this service, users are required to agree to the following terms: 40 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. 41 | Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. 42 | For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. 43 | """) 44 | 45 | learn_more_markdown = (""" 46 | ### License 47 | The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. 48 | """) 49 | 50 | 51 | class Chat: 52 | def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', cache_dir=None): 53 | disable_torch_init() 54 | model_name = get_model_name_from_path(model_path) 55 | self.tokenizer, self.model, processor, context_len = load_pretrained_model(model_path, model_base, model_name, 56 | load_8bit, load_4bit, 57 | device=device, cache_dir=cache_dir) 58 | self.image_processor = processor['image'] 59 | self.video_processor = processor['video'] 60 | self.conv_mode = conv_mode 61 | self.conv = conv_templates[conv_mode].copy() 62 | self.device = self.model.device 63 | print(self.model) 64 | 65 | def get_prompt(self, qs, state): 66 | state.append_message(state.roles[0], qs) 67 | state.append_message(state.roles[1], None) 68 | return state 69 | 70 | @torch.inference_mode() 71 | def generate(self, images_tensor: list, prompt: str, first_run: bool, state): 72 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 73 | 74 | state = self.get_prompt(prompt, state) 75 | prompt = state.get_prompt() 76 | # print('\n\n\n') 77 | # print(prompt) 78 | 79 | 80 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) 81 | 82 | temperature = 0.2 83 | 84 | max_new_tokens = 1024 85 | 86 | stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2 87 | keywords = [stop_str] 88 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 89 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 90 | print(prompt, input_ids, len(images_tensor), images_tensor[0].shape) 91 | with torch.inference_mode(): 92 | output_ids = model.generate( 93 | input_ids, 94 | images=images_tensor, 95 | do_sample=True, 96 | temperature=temperature, 97 | max_new_tokens=max_new_tokens, 98 | streamer=streamer, 99 | use_cache=True, 100 | stopping_criteria=[stopping_criteria]) 101 | 102 | input_token_len = input_ids.shape[1] 103 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 104 | if n_diff_input_output > 0: 105 | print(f'[Warning] {n_diff_input_output} 
output_ids are not the same as the input_ids') 106 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 107 | outputs = outputs.strip() 108 | if outputs.endswith(stop_str): 109 | outputs = outputs[:-len(stop_str)] 110 | outputs = outputs.strip() 111 | 112 | print('response', outputs) 113 | return outputs, state 114 | -------------------------------------------------------------------------------- /serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from videollava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | 
main() 63 | -------------------------------------------------------------------------------- /serve/utils.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import requests 4 | from PIL import Image 5 | 6 | 7 | def load_image(image_file): 8 | if image_file.startswith('http://') or image_file.startswith('https://'): 9 | response = requests.get(image_file) 10 | image = Image.open(BytesIO(response.content)).convert('RGB') 11 | else: 12 | image = Image.open(image_file).convert('RGB') 13 | return image 14 | 15 | video_ext = ['.mp4', '.mov', '.mkv', '.avi'] 16 | image_ext = ['.jpg', '.png', '.bmp', '.jpeg'] 17 | --------------------------------------------------------------------------------