├── .gitignore ├── CodeFuse-VLM-14B-performance.png ├── CodeFuse-VLM-arch.png ├── CodeFuse_UserGroup.png ├── LEGAL.md ├── LICENSE ├── MFT-VLM-arch.png ├── README.md ├── README_cn.md ├── accelerate_ds_config.yaml ├── demo.ipynb ├── init_env.sh ├── llava ├── .DS_Store ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── .DS_Store │ ├── GPT_eval_llava_bench.py │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_gqa.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── eval_vizwiz.py │ ├── generate_webpage_data_from_table.py │ ├── llava_bench_vqa.py │ ├── mmbench_vqa.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_science.py │ ├── model_vqa_yuque.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ ├── textvqa_bench_vqa.py │ ├── vqa_sketch2code.py │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css ├── merge_pretrain_cross_attn_to_qwenvl.py ├── merge_pretrain_weights_to_qwenvl.py ├── mm_utils.py ├── model │ ├── .DS_Store │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mpt.py │ │ ├── llava_qwen.py │ │ ├── mpt │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ │ └── qwen │ │ │ ├── configuration_qwen.py │ │ │ ├── modeling_qwen.py │ │ │ ├── qwen_generation_utils.py │ │ │ └── tokenization_qwen.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ └── visual.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── serve │ ├── .DS_Store │ ├── __init__.py │ ├── assets │ │ ├── android-dsl-mapping.json │ │ ├── ios-dsl-mapping.json │ │ └── web-dsl-mapping.json │ ├── classes │ │ ├── Compiler.py │ │ ├── Node.py │ │ ├── Utils.py │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── Compiler.cpython-38.pyc │ │ │ ├── Node.cpython-38.pyc │ │ │ ├── Utils.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ ├── cli.py │ ├── controller.py │ ├── demo.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── serve │ │ ├── .dockerenv │ │ ├── .pouch.first │ │ └── .pouch_runc_init │ └── test_message.py ├── train │ ├── __pycache__ │ │ ├── llama_flash_attn_monkey_patch.cpython-38.pyc │ │ ├── llava_trainer.cpython-38.pyc │ │ └── train.cpython-38.pyc │ ├── llama_flash_attn_monkey_patch.py │ ├── 
llava_trainer.py │ ├── train.py │ └── train_mem.py └── utils.py ├── playground ├── .DS_Store └── data │ ├── .DS_Store │ ├── coco2014_val_gpt4_qa_30x3.jsonl │ ├── coco2014_val_qa_eval │ ├── qa90_gpt4_answer.jsonl │ └── qa90_questions.jsonl │ └── prompts │ ├── complex_reasoning │ ├── 000_caps.txt │ ├── 000_conv.txt │ ├── 001_caps.txt │ ├── 001_conv.txt │ ├── 002_caps.txt │ ├── 002_conv.txt │ └── system_message.txt │ ├── conversation │ ├── 000_caps.txt │ ├── 000_conv.txt │ ├── 001_caps.txt │ ├── 001_conv.txt │ └── system_message.txt │ └── detail_description │ ├── 000_caps.txt │ ├── 000_conv.txt │ ├── 001_caps.txt │ ├── 001_conv.txt │ ├── 002_caps.txt │ ├── 002_conv.txt │ └── system_message.txt ├── plot.ipynb ├── pyproject.toml └── scripts ├── acc_ds_config_zero3.json ├── convert_sqa_to_llava.py ├── convert_sqa_to_llava_base_prompt.py ├── finetune.sh ├── finetune_full_schedule.sh ├── finetune_lora.sh ├── finetune_multinode.sh ├── finetune_qlora.sh ├── finetune_yuque_qwen.sh ├── merge_lora_weights.py ├── merge_qwen_vl_weights.sh ├── new_ds_config.json ├── new_ds_config_zero3.json ├── pretrain.sh ├── pretrain_llava_qwen.sh ├── pretrain_multinode.sh ├── render_sketch2code.sh ├── zero2.json ├── zero3.json └── zero3_offload.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Compiled Files 2 | *.pyc 3 | *.DS_Store 4 | -------------------------------------------------------------------------------- /CodeFuse-VLM-14B-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/CodeFuse-VLM-14B-performance.png -------------------------------------------------------------------------------- /CodeFuse-VLM-arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/CodeFuse-VLM-arch.png -------------------------------------------------------------------------------- /CodeFuse_UserGroup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/CodeFuse_UserGroup.png -------------------------------------------------------------------------------- /LEGAL.md: -------------------------------------------------------------------------------- 1 | Legal Disclaimer 2 | 3 | Within this source code, the comments in Chinese shall be the original, governing version. Any comment in other languages are for reference only. In the event of any conflict between the Chinese language version comments and other language version comments, the Chinese language version shall prevail. 
4 | 5 | 法律免责声明 6 | 7 | 关于代码注释部分,中文注释为官方版本,其它语言注释仅做参考。中文注释可能与其它语言注释存在不一致,当中文注释与其它语言注释存在不一致时,请以中文注释为准。 -------------------------------------------------------------------------------- /MFT-VLM-arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/MFT-VLM-arch.png -------------------------------------------------------------------------------- /README_cn.md: -------------------------------------------------------------------------------- 1 | ## CodeFuse-VLM 2 | CodeFuse-VLM 是一个多模态大语言模型框架,该框架为用户提供多种视觉编码器,模态对齐模块和大语言模型的选择,以适配用户对不同任务的需求。 3 | 4 | 随着huggingface开源社区的不断更新,会有更多的vision encoder 和 LLM 底座发布,这些vision encoder 和 LLM底座都有各自的强项,例如 code-llama 适合生成代码类任务,但是不适合生成中文类的任务;因此我们搭建了CodeFuse-VLM 框架,支持多种视觉模型和语言大模型,使得CodeFuse-VLM可以适应不同种类的任务。 5 | 6 | ![img.jpg](./CodeFuse-VLM-arch.png) 7 | 8 | 我们在CodeFuse-VLM 框架下, 使用Qwen-VL的视觉编码器, cross attention模态对齐模块, 和 Qwen-14B 模型训练了 CodeFuse-VLM-14B 9 | 10 | CodeFuse-VLM-14B 在多个benchmarks 上的性能超过了Qwen-VL和LLAVA-1.5 11 | ![img.jpg](./CodeFuse-VLM-14B-performance.png) 12 | 13 | 各个模型得分如下表所示: 14 | 模型 | MMBench | MMBench-CN | VqaV2 | GQA | TextVQA | Vizwiz 15 | | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | ------------- | 16 | LLAVA-1.5 | 67.7 | 63.6 | 80.0 | 63.3 | 61.3 | 53.6 17 | Qwen-VL | 60.6 | 56.7 | 78.2 | 57.5 | 63.8 | 38.9 18 | CodeFuse-VLM-14B | 75.7 | 69.8 | 79.3 | 59.4 | 63.9 | 45.3 19 | 20 | 我们的模型在MMBenchmark 多模态大模型榜单上取得了很高的排名: https://mmbench.opencompass.org.cn/leaderboard 21 | 22 | 这是我们模型的展示视频 23 | 24 | https://private-user-images.githubusercontent.com/22836551/300386230-8e64f615-ac0e-447e-9695-c96b254d484f.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MDY1MjExODksIm5iZiI6MTcwNjUyMDg4OSwicGF0aCI6Ii8yMjgzNjU1MS8zMDAzODYyMzAtOGU2NGY2MTUtYWMwZS00NDdlLTk2OTUtYzk2YjI1NGQ0ODRmLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDAxMjklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwMTI5VDA5MzQ0OVomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWQ5NzNjM2U1ZWU4NDU0Yzc5NmE4ZTM1NzY2ZjU4YjRjY2ZhNjMzODk0ZDgzMDg4N2FjYjZhYTllM2E3NTAyMWQmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.pr-ad7rKYBgk26DTItj2q2q9I5dRWnBNHbV9M7GSVCo 25 | 26 | 27 | ## Contents 28 | - [Install](#Install) 29 | - [Datasets](#Datasets) 30 | - [Multimodal Alignment](#Multimodal-Alignment) 31 | - [Visual Instruction Tuning](#Visual-Instruction-Tuning) 32 | - [Evaluation](#Evaluation) 33 | 34 | ## Install 35 | 请执行 sh init\_env.sh 36 | 37 | ## Datasets 38 | 使用了以下数据集训练模型: 39 | 40 | 数据集 | 任务种类 | 样本量 41 | | ------------- | ------------- | ------------- | 42 | synthdog-en | OCR | 800,000 43 | synthdog-zh | OCR | 800,000 44 | cc3m(downsampled)| Image Caption | 600,000 45 | cc3m(downsampled)| Image Caption | 600,000 46 | SBU | Image Caption | 850,000 47 | Visual Genome VQA (Downsampled) | Visual Question Answer(VQA) | 500,000 48 | Visual Genome Region descriptions (Downsampled) | Reference Grouding | 500,000 49 | Visual Genome objects (Downsampled) | Grounded Caption | 500,000 50 | OCR VQA (Downsampled) | OCR and VQA | 500,000 51 | 52 | 请到各个数据集的官网上下载这些数据。 53 | 54 | ## Multimodal Alignment 55 | 请执行 sh scripts/pretrain.sh 或者 sh scripts/pretrain\_multinode.sh 56 | 57 | 58 | ## Visual Instruction Tuning 59 | 请执行 sh scripts/finetune.sh 
或者 sh scripts/finetune\_multinode.sh 60 | 61 | ## Evaluation 62 | 请执行 llava/eval/ 当中的python脚本. 可以通过下面的代码来加载我们预训练的CodeFuse-VLM-14B: 63 | 64 | ``` 65 | import os 66 | from llava.model.builder import load_mixed_pretrained_model 67 | 68 | model_path = '/pretrained/model/path' 69 | tokenizer, model, image_processor, context_len = load_mixed_pretrained_model(model_path, None, 'qwen-vl-14b', os.path.join(model_path, 'Qwen-VL-visual'), 'cross_attn', os.path.join(model_path, 'mm_projector/mm_projector.bin')) 70 | ``` 71 | 72 | 您也可以先运行下面的脚本来合并各个模型组件:scripts/merge\_qwen\_vl\_weights.sh,然后通过下面的代码加载合并后的模型: 73 | ``` 74 | from llava.model import LlavaQWenForCausalLM 75 | 76 | model = LlavaQWenForCausalLM.from_pretrained('/path/to/our/pretrained/model') 77 | ``` 78 | 79 | ## CodeFuse-VLM 产品视频 80 | 这是我们模型支持的产品的视频 81 | 82 | https://private-user-images.githubusercontent.com/22836551/300398424-201f667d-6b6b-4548-b3e6-724afc4b3071.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MDY1MjE5MTIsIm5iZiI6MTcwNjUyMTYxMiwicGF0aCI6Ii8yMjgzNjU1MS8zMDAzOTg0MjQtMjAxZjY2N2QtNmI2Yi00NTQ4LWIzZTYtNzI0YWZjNGIzMDcxLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNDAxMjklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjQwMTI5VDA5NDY1MlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWI0ZmJmZWNlNDZmNWM3NzA0OThlMmY1ODY4MDkxNWY5ZWNiNzRiYjJkYmE4NjEzM2EwYWRiNWY2ODc3N2ViYjEmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JmFjdG9yX2lkPTAma2V5X2lkPTAmcmVwb19pZD0wIn0.BIvWGNx0XV7RoauxB0c2noEdbfZfu8-16LPHtCaCJ9k 83 | 84 | ## 加入我们 85 | img 86 | 87 | 我们是平台技术事业群风险智能团队,负责蚂蚁蚂蚁集团平台工程的智能化,团队成立3年多以来,支持了蚂蚁集团云计算基础设施智能化运维的升级改造。团队的Mission是,通过世界级的技术创新和影响,构建有广泛用户的算法服务和平台,支撑内外部产品和业务落地。团队秉承创新基因,在支撑业务落地的同时,推动技术影响。3年以来在ICLR、NeurIPS、KDD、ACL等顶会发表论文20余篇,创新业务结果获得两次蚂蚁技术最高奖T-Star,1次蚂蚁集团最高奖SuperMA。开源项目CodeFuse获得4K点赞(2024年2月),Huggingface和modelscope上模型累积下载量超过150万次。 88 | 89 | **我们正在寻找行业中的佼佼者加入我们的团队!如果您希望在一个充满活力、创新和卓越文化的环境中发展您的职业生涯,欢迎您查看我们的社招&校招机会,加入我们,一起创造下一个行业里程碑。** 90 | 91 | **校招**:https://hrrecommend.antgroup.com/guide.html?code=8uoP5mlus5DqQYbE_EnqcE2FD5JZH21MwvMUIb9mb6X3osXPuBraG54SyM8GLn_7 92 | 93 | **社招**:https://talent.antgroup.com/off-campus-position?positionId=1933830 -------------------------------------------------------------------------------- /accelerate_ds_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_config_file: ./scripts/acc_ds_config.json 4 | zero3_init_flag: False 5 | # steps_per_print: 1 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | dynamo_backend: 'NO' 9 | fsdp_config: {} 10 | machine_rank: 0 11 | main_training_function: main 12 | megatron_lm_config: {} 13 | mixed_precision: 'bf16' 14 | machine_rank: 0 15 | num_machines: 2 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | use_cpu: false -------------------------------------------------------------------------------- /init_env.sh: -------------------------------------------------------------------------------- 1 | # 这是容器的初始化文件,此脚本会在容器启动后运行,可以在此写上常用包的安装脚本,例如:pip install torch 2 | #pip install deepspeed==0.8.3 3 | #pip install transformers==4.30.0 4 | #pip install accelerate==0.20.3 5 | #pip install /mnt/user/qumu/libs/peft-662ebe593e5d4a2d64a4ee0a0c61c807f7a62617 6 | #pip install BitsAndBytes==0.39.0 7 | #pip install xformers 8 | #pip install ujson 9 | #pip install jsonlines 10 | 11 | pip 
install SentencePiece==0.1.99 -i https://pypi.antfin-inc.com/simple/ 12 | pip install alps==2.3.1.8 -i https://pypi.antfin-inc.com/simple/ 13 | pip install deepspeed==0.9.5 -i https://pypi.antfin-inc.com/simple/ 14 | pip install accelerate==0.23.0 -i https://pypi.antfin-inc.com/simple/ 15 | pip install transformers==4.32.0 -i https://pypi.antfin-inc.com/simple/ 16 | pip install peft==0.5.0 -i https://pypi.antfin-inc.com/simple/ 17 | pip install tiktoken==0.5.1 -i https://pypi.antfin-inc.com/simple/ 18 | pip install transformers_stream_generator==0.0.4 -i https://pypi.antfin-inc.com/simple/ -------------------------------------------------------------------------------- /llava/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/.DS_Store -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | from .model import LlavaQWenForCausalLM 3 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | REF_START_TOKEN = "<ref>" 14 | REF_END_TOKEN = "</ref>" 15 | BOX_START_TOKEN = "<box>" 16 | BOX_END_TOKEN = "</box>" -------------------------------------------------------------------------------- /llava/eval/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/eval/.DS_Store -------------------------------------------------------------------------------- /llava/eval/GPT_eval_llava_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | from Crypto.Cipher import AES 9 | from binascii import b2a_hex, a2b_hex 10 | import requests 11 | import ast 12 | 13 | def aes_encrypt(data, key): 14 | """AES encryption helper; if data is not a multiple of 16 bytes (the text to encrypt must be a multiple of 16!), pad it up to a multiple of 16 15 | :param key: 16 | :param data: 17 | """ 18 | iv = "1234567890123456" 19 | cipher = AES.new(key.encode('utf-8'), AES.MODE_CBC, iv.encode('utf-8')) # set the AES cipher mode; CBC mode is used here 20 | block_size = AES.block_size 21 | 22 | # check whether data is a multiple of 16 bytes; if not, pad it with b'\0' 23 | if len(data) % block_size != 0: 24 | add = block_size - (len(data) % block_size) 25 | else: 26 | add = 0 27 | data = data.encode('utf-8') + b'\0' * add 28 | encrypted = cipher.encrypt(data) # AES encryption 29 | result = b2a_hex(encrypted) # b2a_hex encode: convert the binary ciphertext to hex 30 | return result.decode('utf-8') 31 | 32 | def aes_decode(data, key): 33 | """AES decryption 34 | :param key: 35 | :param data: 36 | """ 37 | iv = '1234567890123456' 38 | cipher = AES.new(key.encode('utf-8'), AES.MODE_CBC, iv.encode('utf-8')) 39 | result2 = a2b_hex(data) # convert the hex string back to binary 40 | decrypted = cipher.decrypt(result2) 41 | return decrypted.rstrip(b'\0') # strip the trailing '\0' padding added during encryption 42 | 43 | NUM_SECONDS_TO_SLEEP = 0.5 44 | 45 | 46 | def get_eval(content: str,
max_tokens: int): 47 | serviceName = "your service name" 48 | visitDomain = "your domain" 49 | visitBiz = "your visit biz" 50 | visitBizLine = "your bizline" 51 | api_key = "your api key" 52 | key = "your key" 53 | param = { 54 | "serviceName": serviceName, 55 | "visitDomain": visitDomain, 56 | "visitBiz": visitBiz, 57 | "visitBizLine": visitBizLine, 58 | "cacheInterval": -1, 59 | "queryConditions": { 60 | "model": "gpt-3.5-turbo-16k", 61 | "api_key": api_key, 62 | 63 | 64 | "messages": [{"role": "user", "content": content}] 65 | } 66 | } 67 | url = 'your url' 68 | data = json.dumps(param) 69 | key = key # secret key used for AES encryption 70 | str = aes_encrypt(data, key) 71 | post_data = { 72 | "encryptedParam": str 73 | } 74 | headers = { 75 | 'Content-Type': 'application/json' 76 | } 77 | while True: 78 | try: 79 | response = requests.post(url, data=json.dumps(post_data), headers=headers) 80 | x = response.json()["data"]["values"]["data"] 81 | ast_str = ast.literal_eval("'" + x + "'") 82 | 83 | js = ast_str.replace('"', '"') 84 | js = js.replace("'", "'") 85 | data = json.loads(js) 86 | 87 | ret = data["choices"][0]["message"]["content"] 88 | break 89 | except openai.error.RateLimitError: 90 | pass 91 | except Exception as e: 92 | print(e) 93 | time.sleep(NUM_SECONDS_TO_SLEEP) 94 | return ret 95 | 96 | 97 | def parse_score(review): 98 | try: 99 | score_pair = review.split('\n')[0] 100 | score_pair = score_pair.replace(', ', " ").replace(',', ' ') 101 | sp = score_pair.split(' ') 102 | if len(sp) == 2: 103 | return [float(sp[0]), float(sp[1])] 104 | else: 105 | print('error', review) 106 | return [-1, -1] 107 | except Exception as e: 108 | print(e) 109 | print('error', review) 110 | return [-1, -1] 111 | 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 115 | parser.add_argument('-q', '--question') 116 | parser.add_argument('-c', '--context') 117 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 118 | parser.add_argument('-o', '--output') 119 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 120 | args = parser.parse_args() 121 | 122 | f_q = open(os.path.expanduser(args.question)) 123 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 124 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 125 | 126 | if os.path.isfile(os.path.expanduser(args.output)): 127 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 128 | else: 129 | cur_reviews = [] 130 | 131 | review_file = open(f'{args.output}', 'a') 132 | 133 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 134 | image_to_context = {context['image']: context for context in context_list} 135 | 136 | handles = [] 137 | idx = 0 138 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 139 | ques = json.loads(ques_js) 140 | ans1 = json.loads(ans1_js) 141 | ans2 = json.loads(ans2_js) 142 | 143 | inst = image_to_context[ques['image']] 144 | 145 | if isinstance(inst['caption'], list): 146 | cap_str = '\n'.join(inst['caption']) 147 | else: 148 | cap_str = inst['caption'] 149 | 150 | 151 | content = (f'[Context]\n{cap_str}\n\n' 152 | f'[Question]\n{ques["text"]}\n\n' 153 | f'[assistant 1]\n{ans1["answer"]}\n\n[End of assistant 1]\n\n' 154 | f'[assistant 2]\n{ans2["text"]}\n\n[End of assistant 2]\n\n' 155 | f'[System]\nYou are a judge evaluating the 2 answers given to you.
Please rate each answer between score 0 and 100, and use comma to separate the 2 scores.\n\n') 156 | cur_js = { 157 | 'id': idx+1, 158 | 'question_id': ques['question_id'], 159 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 160 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 161 | 'category': "LLAVA" 162 | } 163 | if idx >= len(cur_reviews): 164 | review = get_eval(content, args.max_tokens) 165 | scores = parse_score(review) 166 | cur_js['content'] = review 167 | cur_js['tuple'] = scores 168 | review_file.write(json.dumps(cur_js) + '\n') 169 | review_file.flush() 170 | else: 171 | print(f'Skipping {idx} as we already have it.') 172 | idx += 1 173 | print(idx) 174 | review_file.close() 175 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 
1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = 
image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 
93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | predictions = [json.loads(line) for line in open(args.result_file)] 45 | predictions = {pred['question_id']: pred for pred in predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | results = {'correct': [], 'incorrect': []} 49 | sqa_results = {} 50 | sqa_results['acc'] = None 51 | sqa_results['correct'] = None 52 | sqa_results['count'] = None 53 | sqa_results['results'] = {} 54 | sqa_results['outputs'] = {} 55 | 56 | for prob_id, prob in split_problems.items(): 57 | if prob_id not in predictions: 58 | continue 59 | pred = predictions[prob_id] 60 | pred_text = pred['text'] 61 | 62 | pattern = re.compile(r'The answer is ([A-Z]).') 63 | res = pattern.findall(pred_text) 64 | if len(res) == 1: 65 | answer = res[0] # 'A', 'B', ... 
66 | else: 67 | answer = "FAILED" 68 | 69 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 70 | 71 | analysis = { 72 | 'question_id': prob_id, 73 | 'parsed_ans': answer, 74 | 'ground_truth': args.options[prob['answer']], 75 | 'question': pred['prompt'], 76 | 'pred': pred_text, 77 | 'is_multimodal': '' in pred['prompt'], 78 | } 79 | 80 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 81 | sqa_results['outputs'][prob_id] = pred_text 82 | 83 | if pred_idx == prob['answer']: 84 | results['correct'].append(analysis) 85 | else: 86 | results['incorrect'].append(analysis) 87 | 88 | correct = len(results['correct']) 89 | total = len(results['correct']) + len(results['incorrect']) 90 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 91 | 92 | sqa_results['acc'] = correct / total * 100 93 | sqa_results['correct'] = correct 94 | sqa_results['count'] = total 95 | 96 | with open(args.output_file, 'w') as f: 97 | json.dump(results, f, indent=2) 98 | with open(args.output_result, 'w') as f: 99 | json.dump(sqa_results, f, indent=2) 100 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 
64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 20 | 21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == 
'__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/llava_bench_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import os.path as osp 5 | import json 6 | from tqdm import tqdm 7 | import threading 8 | import pandas as pd 9 | import base64 10 | import io 11 | import random 12 | 
import shortuuid 13 | 14 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 15 | from llava.conversation import conv_templates, SeparatorStyle 16 | from llava.model.builder import load_pretrained_model, load_pretrained_model_custom_proj, load_mixed_pretrained_model 17 | from llava.utils import disable_torch_init 18 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 19 | 20 | from PIL import Image 21 | import math 22 | 23 | 24 | def eval_model(args, questions, start, end, ans_file): 25 | # Model 26 | disable_torch_init() 27 | model_path = os.path.expanduser(args.model_path) 28 | #model_name = get_model_name_from_path(model_path) + "-lora" 29 | model_name = "yuque_qwen-7b-lora" 30 | #tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 31 | #tokenizer, model, image_processor, context_len = load_pretrained_model_custom_proj(model_path, args.model_base, model_name, args.mm_projector) 32 | tokenizer, model, image_processor, context_len = load_mixed_pretrained_model(model_path, args.model_base, model_name, args.vision_tower, args.mm_projector_type, args.mm_projector) 33 | tokenizer.pad_token_id = tokenizer.eod_id 34 | model = model.cuda() 35 | 36 | ans_fp = open(ans_file, "w") 37 | 38 | for i, question in enumerate(questions): 39 | image_fn = osp.join(args.image_folder, question['image']) 40 | image = Image.open(image_fn).convert('RGB') 41 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 42 | 43 | im_start = torch.tensor(tokenizer.im_start_id) ##每次对话起始符,无论用户还是机器 44 | im_end = torch.tensor(tokenizer.im_end_id) ##每次对话起始符,无论用户还是机器 45 | nl_tokens = torch.tensor(tokenizer('\n').input_ids) 46 | _system = torch.tensor(tokenizer('system').input_ids) ##全样本就一个的system 47 | _user = torch.tensor(tokenizer('user').input_ids) 48 | _assistant = torch.tensor(tokenizer('assistant').input_ids) 49 | 50 | inputs = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" 51 | prompt = question['text'] 52 | prompt = "<|im_start|>user\n" + "Picture 1:\n" + prompt + "<|im_end|>\n" + "<|im_start|>assistant\n" 53 | inputs += prompt 54 | 55 | tokens = tokenizer( 56 | inputs, 57 | max_length=tokenizer.model_max_length, 58 | padding=True, 59 | truncation=True, 60 | return_tensors="pt", 61 | ) 62 | input_ids = tokens.input_ids.cuda() 63 | 64 | #stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 65 | stop_str = tokenizer.pad_token 66 | keywords = [stop_str] 67 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 68 | 69 | with torch.inference_mode(): 70 | output_ids = model.generate( 71 | input_ids, 72 | images=image_tensor.unsqueeze(0).cuda(), 73 | do_sample=True, 74 | temperature=0.2, 75 | top_p=0.3, 76 | top_k=0, 77 | #num_beams=1, 78 | # no_repeat_ngram_size=3, 79 | max_new_tokens=2048, 80 | return_dict_in_generate=False, 81 | use_cache=True) 82 | 83 | input_token_len = input_ids.shape[1] 84 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 85 | if n_diff_input_output > 0: 86 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 87 | output_text = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 88 | 89 | output_text = output_text.replace("<|im_end|>", "").replace("<|im_start|>", "").replace("\n", "") 90 | #import pdb; pdb.set_trace() 91 | 92 | 
out_j_dict = {} 93 | 94 | for k in question: 95 | out_j_dict[k] = question[k] 96 | out_j_dict['answer'] = output_text 97 | out_j_dict_str = json.dumps(out_j_dict) + "\n" 98 | ans_fp.write(out_j_dict_str) 99 | ans_fp.close() 100 | 101 | 102 | 103 | 104 | if __name__ == "__main__": 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 107 | parser.add_argument("--model-base", type=str, default=None) 108 | parser.add_argument("--image-folder", type=str, default="") 109 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 110 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 111 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 112 | parser.add_argument("--num-chunks", type=int, default=1) 113 | parser.add_argument("--chunk-idx", type=int, default=0) 114 | parser.add_argument("--temperature", type=float, default=0.2) 115 | parser.add_argument("--top_p", type=float, default=None) 116 | parser.add_argument("--num_beams", type=int, default=1) 117 | parser.add_argument("--mm-projector", type=str, default=None) 118 | parser.add_argument("--mm-projector-type", type=str, default=None) 119 | parser.add_argument("--vision-tower", type=str, default=None) 120 | args = parser.parse_args() 121 | 122 | thread_num = 1 123 | 124 | # questions file 125 | with open(os.path.expanduser(args.question_file), "r") as f: 126 | questions = [json.loads(ln) for ln in f] 127 | 128 | # answers file 129 | answers_file = os.path.expanduser(args.answers_file) 130 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 131 | ans_file = answers_file 132 | 133 | eval_model(args, questions, 0, len(questions), ans_file) 134 | #eval_model(args, questions, 0, 20, ans_file) -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 35 | disable_torch_init() 36 | model_name = os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = 
json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions): 42 | idx = line["question_id"] 43 | image_file = line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 
48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = Image.open(os.path.join(args.image_folder, image_file)) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 60 | 61 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 62 | keywords = [stop_str] 63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 64 | 65 | with torch.inference_mode(): 66 | output_ids = model.generate( 67 | input_ids, 68 | images=image_tensor.unsqueeze(0).half().cuda(), 69 | do_sample=True, 70 | temperature=args.temperature, 71 | top_p=args.top_p, 72 | num_beams=args.num_beams, 73 | # no_repeat_ngram_size=3, 74 | max_new_tokens=1024, 75 | use_cache=True) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | 87 | ans_id = shortuuid.uuid() 88 | ans_file.write(json.dumps({"question_id": idx, 89 | "prompt": cur_prompt, 90 | "text": outputs, 91 | "answer_id": ans_id, 92 | "model_id": model_name, 93 | "metadata": {}}) + "\n") 94 | ans_file.flush() 95 | ans_file.close() 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 100 | parser.add_argument("--model-base", type=str, default=None) 101 | parser.add_argument("--image-folder", type=str, default="") 102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 103 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 104 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 105 | parser.add_argument("--num-chunks", type=int, default=1) 106 | parser.add_argument("--chunk-idx", type=int, default=0) 107 | parser.add_argument("--temperature", type=float, default=0.2) 108 | parser.add_argument("--top_p", type=float, default=None) 109 | parser.add_argument("--num_beams", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | eval_model(args) 113 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | 
model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | 16 | 17 | def load_image(image_file): 18 | if image_file.startswith('http') or image_file.startswith('https'): 19 | response = requests.get(image_file) 20 | image = Image.open(BytesIO(response.content)).convert('RGB') 21 | else: 22 | image = Image.open(image_file).convert('RGB') 23 | return image 24 | 25 | 26 | def eval_model(args): 27 | # Model 28 | disable_torch_init() 29 | 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) 32 | 33 | qs = args.query 34 | if model.config.mm_use_im_start_end: 35 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 36 | else: 37 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 38 | 39 | if 'llama-2' in model_name.lower(): 40 | conv_mode = "llava_llama_2" 41 | elif "v1" in model_name.lower(): 42 | conv_mode = "llava_v1" 43 | elif "mpt" in model_name.lower(): 44 | conv_mode = "mpt" 45 | else: 46 | conv_mode = "llava_v0" 47 | 48 | if args.conv_mode is not None and conv_mode != args.conv_mode: 49 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, 
args.conv_mode)) 50 | else: 51 | args.conv_mode = conv_mode 52 | 53 | conv = conv_templates[args.conv_mode].copy() 54 | conv.append_message(conv.roles[0], qs) 55 | conv.append_message(conv.roles[1], None) 56 | prompt = conv.get_prompt() 57 | 58 | image = load_image(args.image_file) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 60 | 61 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 62 | 63 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 64 | keywords = [stop_str] 65 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 66 | 67 | with torch.inference_mode(): 68 | output_ids = model.generate( 69 | input_ids, 70 | images=image_tensor, 71 | do_sample=True, 72 | temperature=0.2, 73 | max_new_tokens=1024, 74 | use_cache=True, 75 | stopping_criteria=[stopping_criteria]) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | print(outputs) 87 | 88 | if __name__ == "__main__": 89 | parser = argparse.ArgumentParser() 90 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 91 | parser.add_argument("--model-base", type=str, default=None) 92 | parser.add_argument("--image-file", type=str, required=True) 93 | parser.add_argument("--query", type=str, required=True) 94 | parser.add_argument("--conv-mode", type=str, default=None) 95 | args = parser.parse_args() 96 | 97 | eval_model(args) 98 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-f', '--files', nargs='*', default=None) 13 | parser.add_argument('-i', '--ignore', nargs='*', default=None) 14 | return parser.parse_args() 15 | 16 | 17 | if __name__ == '__main__': 18 | args = parse_args() 19 | 20 | if args.ignore is not None: 21 | args.ignore = [int(x) for x in args.ignore] 22 | 23 | if args.files is not None and len(args.files) > 0: 24 | review_files = args.files 25 | else: 26 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))] 27 | 28 | for review_file in sorted(review_files): 29 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 30 | scores = defaultdict(list) 31 | print(config) 32 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 33 | for review_str in f: 34 | review = json.loads(review_str) 35 | if args.ignore is not None and review['question_id'] in args.ignore: 36 | continue 37 | if 'category' in review: 38 | scores[review['category']].append(review['tuple']) 39 | 
scores['all'].append(review['tuple']) 40 | else: 41 | if 'tuple' in review: 42 | scores['all'].append(review['tuple']) 43 | else: 44 | scores['all'].append(review['score']) 45 | for k, v in sorted(scores.items()): 46 | stats = np.asarray(v).mean(0).tolist() 47 | stats = [round(x, 3) for x in stats] 48 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 49 | print(k, round(stats[1]/stats[0]*100, 1)) 50 | print('=================================') 51 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. 
Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | 
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /llava/merge_pretrain_cross_attn_to_qwenvl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 4 | from transformers.generation import GenerationConfig 5 | import torch 6 | import json 7 | import os 8 | import os.path as osp 9 | import shortuuid 10 | import numpy as np 11 | from transformers import Trainer 12 | 13 | import shortuuid 14 | 15 | model_base_path = "/mnt/project/LLAVA/Qwen-VL-Chat/" 16 | cross_attn_path = "/mnt/project/LLAVA/pretrain_weights/qwen-vl-7b-yuque-box-frame-1212/checkpoint-13000/mm_projector.bin" 17 | 18 | tokenizer = AutoTokenizer.from_pretrained(model_base_path, use_fast=False, trust_remote_code=True) 19 | model = AutoModelForCausalLM.from_pretrained(model_base_path, low_cpu_mem_usage=True, trust_remote_code=True) 20 | cross_attn_state_dict = torch.load(cross_attn_path) 21 | 22 | visual_state_dict = model.transformer.visual.state_dict() 23 | 24 | def match_k(visual_state_dict, cross_attn_state_dict): 25 | print("number of parameters to change", len(cross_attn_state_dict)) 26 | cnt = 0 27 | for k in cross_attn_state_dict: 28 | real_k = ".".join(k.split(".")[2:]) 29 | if real_k in visual_state_dict: 30 | visual_state_dict[real_k] = cross_attn_state_dict[k] 31 | cnt += 1 32 | print("number of parameters changed", cnt) 33 | return visual_state_dict 34 | 35 | visual_state_dict = match_k(visual_state_dict, cross_attn_state_dict) 36 | model.transformer.visual.load_state_dict(visual_state_dict) 37 | 38 | output_dir = "/mnt/project/LLAVA/pretrain_weights/merged/qwen-vl-7b-yuque-box-frame-1212/" 39 | trainer = Trainer(model=model, tokenizer=tokenizer) 40 | 
trainer._save(output_dir) 41 | 42 | -------------------------------------------------------------------------------- /llava/merge_pretrain_weights_to_qwenvl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig 4 | from transformers.generation import GenerationConfig 5 | import torch 6 | import json 7 | import os 8 | import os.path as osp 9 | import shortuuid 10 | import numpy as np 11 | from transformers import Trainer 12 | import shortuuid 13 | import argparse 14 | from llava.model import * 15 | from dataclasses import dataclass, field 16 | from typing import Dict, Optional, Sequence, List 17 | 18 | 19 | @dataclass 20 | class VisionArguments: 21 | vision_tower: Optional[str] = field(default="/mnt/project/LLAVA/Qwen-VL-visual/") 22 | mm_vision_select_layer: Optional[int] = field(default=-2) 23 | mm_vision_select_feature: Optional[str] = field(default="patch") 24 | mm_projector_type: Optional[str] = field(default="linear") 25 | pretrain_mm_mlp_adapter: Optional[str] = field(default="") 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument("--LLM-path", type=str, default=None) 31 | parser.add_argument("--mm-projector-type", type=str, default='linear') 32 | parser.add_argument("--mm-projector", type=str, default=None) 33 | parser.add_argument("--vision-tower", type=str, default=None) 34 | parser.add_argument("--output-path", type=str, default="./Qwen-VL") 35 | 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | if __name__ == "__main__": 41 | args = parse_args() 42 | model_path = args.LLM_path 43 | projector_type = args.mm_projector_type 44 | projector_path = args.mm_projector 45 | vision_tower_path = args.vision_tower 46 | 47 | vision_args = VisionArguments() 48 | vision_args.mm_projector_type = projector_type 49 | vision_args.vision_tower = vision_tower_path 50 | vision_args.pretrain_mm_mlp_adapter = projector_path 51 | 52 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 53 | model = LlavaQWenForCausalLM.from_pretrained( 54 | model_path, 55 | config=cfg_pretrained, 56 | ) 57 | model.get_model().initialize_vision_modules(vision_args) 58 | vision_tower = model.get_vision_tower() 59 | if not vision_tower.is_loaded: 60 | vision_tower.load_model() 61 | image_processor = vision_tower.image_processor 62 | 63 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, use_fast=True) 64 | 65 | output_dir = args.output_path 66 | trainer = Trainer(model=model, tokenizer=tokenizer) 67 | trainer._save(output_dir) 68 | 69 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def process_images(images, image_processor, model_cfg): 15 | return image_processor(images, return_tensors='pt')['pixel_values'] 16 | 17 | 18 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 19 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 20 | 21 | def insert_separator(X, sep): 22 | 
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 23 | 24 | input_ids = [] 25 | offset = 0 26 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 27 | offset = 1 28 | input_ids.append(prompt_chunks[0][0]) 29 | 30 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 31 | input_ids.extend(x[offset:]) 32 | 33 | if return_tensors is not None: 34 | if return_tensors == 'pt': 35 | return torch.tensor(input_ids, dtype=torch.long) 36 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 37 | return input_ids 38 | 39 | 40 | def get_model_name_from_path(model_path): 41 | model_path = model_path.strip("/") 42 | model_paths = model_path.split("/") 43 | if model_paths[-1].startswith('checkpoint-'): 44 | return model_paths[-2] + "_" + model_paths[-1] 45 | else: 46 | return model_paths[-1] 47 | 48 | 49 | 50 | 51 | class KeywordsStoppingCriteria(StoppingCriteria): 52 | def __init__(self, keywords, tokenizer, input_ids): 53 | self.keywords = keywords 54 | self.keyword_ids = [] 55 | for keyword in keywords: 56 | cur_keyword_ids = tokenizer(keyword).input_ids 57 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 58 | cur_keyword_ids = cur_keyword_ids[1:] 59 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 60 | self.tokenizer = tokenizer 61 | self.start_len = input_ids.shape[1] 62 | 63 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 64 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 65 | offset = min(output_ids.shape[1] - self.start_len, 3) 66 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 67 | for keyword_id in self.keyword_ids: 68 | if output_ids[0, -keyword_id.shape[0]:] == keyword_id: 69 | return True 70 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 71 | for keyword in self.keywords: 72 | if keyword in outputs: 73 | return True 74 | return False 75 | -------------------------------------------------------------------------------- /llava/model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/model/.DS_Store -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_qwen import LlavaQWenForCausalLM 3 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 4 | from .language_model.qwen.modeling_qwen import QWenLMHeadModel 5 | from .language_model.qwen.configuration_qwen import QWenConfig 6 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, 
delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | from dataclasses import dataclass 22 | from transformers.utils import ModelOutput 23 | 24 | from transformers import AutoConfig, AutoModelForCausalLM, \ 25 | LlamaConfig, LlamaModel, LlamaForCausalLM 26 | 27 | 28 | from transformers.modeling_outputs import CausalLMOutputWithPast 29 | 30 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 31 | 32 | @dataclass 33 | class LlavaCausalLMOutputWithPast(ModelOutput): 34 | loss: Optional[torch.FloatTensor] = None 35 | logits: torch.FloatTensor = None 36 | labels: torch.LongTensor = None 37 | past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None 38 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 39 | attentions: Optional[Tuple[torch.FloatTensor]] = None 40 | 41 | class LlavaConfig(LlamaConfig): 42 | model_type = "llava" 43 | 44 | 45 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 46 | config_class = LlavaConfig 47 | 48 | def __init__(self, config: LlamaConfig): 49 | super(LlavaLlamaModel, self).__init__(config) 50 | 51 | 52 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 53 | config_class = LlavaConfig 54 | 55 | def __init__(self, config): 56 | super(LlamaForCausalLM, self).__init__(config) 57 | self.model = LlavaLlamaModel(config) 58 | 59 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 60 | 61 | # Initialize weights and apply final processing 62 | self.post_init() 63 | 64 | def get_model(self): 65 | return self.model 66 | 67 | def forward( 68 | self, 69 | input_ids: torch.LongTensor = None, 70 | attention_mask: Optional[torch.Tensor] = None, 71 | past_key_values: Optional[List[torch.FloatTensor]] = None, 72 | inputs_embeds: Optional[torch.FloatTensor] = None, 73 | labels: Optional[torch.LongTensor] = None, 74 | use_cache: Optional[bool] = None, 75 | output_attentions: Optional[bool] = None, 76 | output_hidden_states: Optional[bool] = None, 77 | images: Optional[torch.FloatTensor] = None, 78 | return_dict: Optional[bool] = None, 79 | num_dataset: Optional[int] = None, 80 | ) -> Union[Tuple, CausalLMOutputWithPast, LlavaCausalLMOutputWithPast]: 81 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 82 | output_hidden_states = ( 83 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 84 | ) 85 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 86 | 87 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, num_dataset) 88 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 89 | outputs = self.model( 90 | input_ids=input_ids, 91 | attention_mask=attention_mask, 92 | past_key_values=past_key_values, 93 | inputs_embeds=inputs_embeds, 94 | use_cache=use_cache, 95 | output_attentions=output_attentions, 96 | output_hidden_states=output_hidden_states, 97 | return_dict=return_dict 98 | ) 99 | 100 | hidden_states = outputs[0] 101 | logits = self.lm_head(hidden_states) 102 | 103 | loss = None 104 | if labels is not None: 105 | # Shift so that tokens < n predict n 106 | shift_logits = logits[..., :-1, :].contiguous() 107 | 
shift_labels = labels[..., 1:].contiguous() 108 | # Flatten the tokens 109 | loss_fct = CrossEntropyLoss() 110 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 111 | shift_labels = shift_labels.view(-1) 112 | # Enable model/pipeline parallelism 113 | shift_labels = shift_labels.to(shift_logits.device) 114 | loss = loss_fct(shift_logits, shift_labels) 115 | 116 | if not return_dict: 117 | output = (logits,) + outputs[1:] 118 | return (loss,) + output if loss is not None else output 119 | return LlavaCausalLMOutputWithPast( 120 | loss=loss, 121 | logits=logits, 122 | labels=labels, 123 | past_key_values=outputs.past_key_values, 124 | hidden_states=outputs.hidden_states, 125 | attentions=outputs.attentions, 126 | ) 127 | 128 | def prepare_inputs_for_generation( 129 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 130 | ): 131 | if past_key_values: 132 | input_ids = input_ids[:, -1:] 133 | 134 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 135 | if inputs_embeds is not None and past_key_values is None: 136 | model_inputs = {"inputs_embeds": inputs_embeds} 137 | else: 138 | model_inputs = {"input_ids": input_ids} 139 | 140 | model_inputs.update( 141 | { 142 | "past_key_values": past_key_values, 143 | "use_cache": kwargs.get("use_cache"), 144 | "attention_mask": attention_mask, 145 | "images": kwargs.get("images", None), 146 | } 147 | ) 148 | return model_inputs 149 | 150 | AutoConfig.register("llava", LlavaConfig) 151 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 152 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. 
This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('', special_tokens=True) 19 | tokenizer.pad_token = '' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., ), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 
34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) 
-------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, 
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/language_model/qwen/configuration_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from transformers import PretrainedConfig 7 | 8 | 9 | class QWenConfig(PretrainedConfig): 10 | model_type = "qwen" 11 | keys_to_ignore_at_inference = ["past_key_values"] 12 | 13 | def __init__( 14 | self, 15 | vocab_size=151936, 16 | hidden_size=4096, 17 | num_hidden_layers=32, 18 | num_attention_heads=32, 19 | emb_dropout_prob=0.0, 20 | attn_dropout_prob=0.0, 21 | layer_norm_epsilon=1e-6, 22 | initializer_range=0.02, 23 | max_position_embeddings=8192, 24 | scale_attn_weights=True, 25 | use_cache=True, 26 | bf16=False, 27 | fp16=False, 28 | fp32=False, 29 | kv_channels=128, 30 | rotary_pct=1.0, 31 | rotary_emb_base=10000, 32 | use_dynamic_ntk=True, 33 | use_logn_attn=True, 34 | use_flash_attn="auto", 35 | intermediate_size=22016, 36 | no_bias=True, 37 | tie_word_embeddings=False, 38 | **kwargs, 39 | ): 40 | self.vocab_size = vocab_size 41 | self.hidden_size = hidden_size 42 | self.intermediate_size = intermediate_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.num_attention_heads = num_attention_heads 45 | self.emb_dropout_prob = emb_dropout_prob 46 | self.attn_dropout_prob = attn_dropout_prob 47 | self.layer_norm_epsilon = layer_norm_epsilon 48 | self.initializer_range = initializer_range 49 | self.scale_attn_weights = scale_attn_weights 50 | self.use_cache = use_cache 51 | self.max_position_embeddings = max_position_embeddings 52 | self.bf16 = bf16 53 | self.fp16 = fp16 54 | self.fp32 = fp32 55 | self.kv_channels = kv_channels 56 | self.rotary_pct = rotary_pct 57 | self.rotary_emb_base = rotary_emb_base 58 | self.use_dynamic_ntk = use_dynamic_ntk 59 | self.use_logn_attn = use_logn_attn 60 | self.use_flash_attn = use_flash_attn 61 | self.no_bias = no_bias 62 | super().__init__( 63 | tie_word_embeddings=tie_word_embeddings, 64 | **kwargs 65 | ) 66 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target 
~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower, ChineseCLIPVisionTower, QWenCLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if "cn_clip" in vision_tower: 9 | return ChineseCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | elif is_absolute_path_exists and "Qwen" in vision_tower: 11 | if 'delay_load' in kwargs: 12 | kwargs.pop('delay_load') 13 | return QWenCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 14 | elif is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 16 | 17 | raise ValueError(f'Unknown vision tower: {vision_tower}') 18 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 
| import torch.nn as nn 3 | import re 4 | from functools import partial 5 | from ..multimodal_encoder.visual import Resampler 6 | import math 7 | 8 | 9 | class IdentityMap(nn.Module): 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def forward(self, x, *args, **kwargs): 14 | return x 15 | 16 | @property 17 | def config(self): 18 | return {"mm_projector_type": 'identity'} 19 | 20 | 21 | class SimpleResBlock(nn.Module): 22 | def __init__(self, channels): 23 | super().__init__() 24 | self.pre_norm = nn.LayerNorm(channels) 25 | 26 | self.proj = nn.Sequential( 27 | nn.Linear(channels, channels), 28 | nn.GELU(), 29 | nn.Linear(channels, channels) 30 | ) 31 | def forward(self, x): 32 | x = self.pre_norm(x) 33 | return x + self.proj(x) 34 | 35 | 36 | class VLCrossAttention(nn.Module): 37 | def __init__(self, config, vision_tower): 38 | super().__init__() 39 | n_queries = 256 40 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 41 | self.attn_pool = Resampler( 42 | grid_size=int(math.sqrt(n_queries)), 43 | embed_dim=config.hidden_size, 44 | num_heads=config.hidden_size // 128, 45 | kv_dim=vision_tower.hidden_size, 46 | norm_layer=norm_layer, 47 | ) 48 | self.ln_post = norm_layer(config.hidden_size) 49 | self.proj = nn.Parameter((config.hidden_size** -0.5) * torch.randn(config.hidden_size, config.hidden_size)) 50 | 51 | def forward(self, x): 52 | x = self.attn_pool(x) 53 | x = self.ln_post(x) 54 | x = x @ self.proj 55 | 56 | return x 57 | 58 | 59 | def build_vision_projector(config, delay_load=False, vision_tower=None, **kwargs): 60 | projector_type = getattr(config, 'mm_projector_type', 'linear') 61 | print("PROJECTOR TYPE: ", projector_type) 62 | 63 | if projector_type == 'linear': 64 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 65 | 66 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 67 | if mlp_gelu_match: 68 | mlp_depth = int(mlp_gelu_match.group(1)) 69 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 70 | for _ in range(1, mlp_depth): 71 | modules.append(nn.GELU()) 72 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 73 | return nn.Sequential(*modules) 74 | elif "cross_attn" in projector_type: 75 | vl_cross_attn = VLCrossAttention(config, vision_tower) 76 | return vl_cross_attn 77 | 78 | 79 | if projector_type == 'identity': 80 | return IdentityMap() 81 | 82 | raise ValueError(f'Unknown projector type: {projector_type}') -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/.DS_Store -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/assets/android-dsl-mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | "opening-tag": "{", 3 | "closing-tag": "}", 4 | "body": "\n\n {}\n\n", 5 | "stack": "\n \n {}\n \n", 6 | "row": "\n{}\n", 7 | "label": "\n", 8 | "btn": "", 12 | "footer": "\n \n \n \n \n {}\n \n", 13 | "btn-search": "", 14 | "btn-contact": "", 15 | "btn-download": "", 16 | "btn-more": "" 17 | } -------------------------------------------------------------------------------- /llava/serve/assets/web-dsl-mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | "opening-tag": "{", 3 | "closing-tag": "}", 4 | "body": "\n
\n \n \n \n\n\n Scaffold\n
\n \n
\n {}\n
\n

© Tony Beltramelli 2017

\n
\n
\n \n \n \n\n", 5 | "header": "
\n \n
\n", 6 | "btn-active": "
  • []
  • \n", 7 | "btn-inactive": "
  • []
  • \n", 8 | "row": "
    {}
    \n", 9 | "single": "
    \n{}\n
    \n", 10 | "double": "
    \n{}\n
    \n", 11 | "quadruple": "
    \n{}\n
    \n", 12 | "btn-green": "[]\n", 13 | "btn-orange": "[]\n", 14 | "btn-red": "[]", 15 | "big-title": "

    []

    ", 16 | "small-title": "

    []

    ", 17 | "text": "

    []

    \n" 18 | } -------------------------------------------------------------------------------- /llava/serve/classes/Compiler.py: -------------------------------------------------------------------------------- 1 | import json 2 | from llava.serve.classes.Node import * 3 | 4 | 5 | class Compiler: 6 | def __init__(self, dsl_mapping_file_path): 7 | with open(dsl_mapping_file_path) as data_file: 8 | self.dsl_mapping = json.load(data_file) 9 | 10 | self.opening_tag = self.dsl_mapping["opening-tag"] 11 | self.closing_tag = self.dsl_mapping["closing-tag"] 12 | self.content_holder = self.opening_tag + self.closing_tag 13 | 14 | self.root = Node("body", None, self.content_holder) 15 | 16 | def compile(self, input_text, output_file_path, rendering_function=None): 17 | current_parent = self.root 18 | input_lns = input_text.split("\n") 19 | 20 | for token in input_lns: 21 | token = token.replace(" ", "").replace("\n", "") 22 | 23 | if token.find(self.opening_tag) != -1: 24 | token = token.replace(self.opening_tag, "") 25 | 26 | element = Node(token, current_parent, self.content_holder) 27 | current_parent.add_child(element) 28 | current_parent = element 29 | elif token.find(self.closing_tag) != -1: 30 | current_parent = current_parent.parent 31 | else: 32 | tokens = token.split(",") 33 | for t in tokens: 34 | element = Node(t, current_parent, self.content_holder) 35 | current_parent.add_child(element) 36 | 37 | output_html = self.root.render(self.dsl_mapping, rendering_function=rendering_function) 38 | 39 | return output_html -------------------------------------------------------------------------------- /llava/serve/classes/Node.py: -------------------------------------------------------------------------------- 1 | class Node: 2 | def __init__(self, key, parent_node, content_holder): 3 | self.key = key 4 | self.parent = parent_node 5 | self.children = [] 6 | self.content_holder = content_holder 7 | 8 | def add_child(self, child): 9 | self.children.append(child) 10 | 11 | def show(self): 12 | print(self.key) 13 | for child in self.children: 14 | child.show() 15 | 16 | def render(self, mapping, rendering_function=None): 17 | content = "" 18 | for child in self.children: 19 | content += child.render(mapping, rendering_function) 20 | 21 | value = mapping[self.key] 22 | if rendering_function is not None: 23 | value = rendering_function(self.key, value) 24 | 25 | if len(self.children) != 0: 26 | value = value.replace(self.content_holder, content) 27 | 28 | return value -------------------------------------------------------------------------------- /llava/serve/classes/Utils.py: -------------------------------------------------------------------------------- 1 | import string 2 | import random 3 | 4 | 5 | class Utils: 6 | @staticmethod 7 | def get_random_text(length_text=10, space_number=1, with_upper_case=True): 8 | results = [] 9 | while len(results) < length_text: 10 | char = random.choice(string.ascii_letters[:26]) 11 | results.append(char) 12 | if with_upper_case: 13 | results[0] = results[0].upper() 14 | 15 | current_spaces = [] 16 | while len(current_spaces) < space_number: 17 | space_pos = random.randint(2, length_text - 3) 18 | if space_pos in current_spaces: 19 | break 20 | results[space_pos] = " " 21 | if with_upper_case: 22 | results[space_pos + 1] = results[space_pos - 1].upper() 23 | 24 | current_spaces.append(space_pos) 25 | 26 | return ''.join(results) 27 | 28 | @staticmethod 29 | def get_ios_id(length=10): 30 | results = [] 31 | 32 | while len(results) < length: 33 | char = 
random.choice(string.digits + string.ascii_letters) 34 | results.append(char) 35 | 36 | results[3] = "-" 37 | results[6] = "-" 38 | 39 | return ''.join(results) 40 | 41 | @staticmethod 42 | def get_android_id(length=10): 43 | results = [] 44 | 45 | while len(results) < length: 46 | char = random.choice(string.ascii_letters) 47 | results.append(char) 48 | 49 | return ''.join(results) -------------------------------------------------------------------------------- /llava/serve/classes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/classes/__init__.py -------------------------------------------------------------------------------- /llava/serve/classes/__pycache__/Compiler.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/classes/__pycache__/Compiler.cpython-38.pyc -------------------------------------------------------------------------------- /llava/serve/classes/__pycache__/Node.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/classes/__pycache__/Node.cpython-38.pyc -------------------------------------------------------------------------------- /llava/serve/classes/__pycache__/Utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/classes/__pycache__/Utils.cpython-38.pyc -------------------------------------------------------------------------------- /llava/serve/classes/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/classes/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http') or image_file.startswith('https'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, 
model_name, args.load_8bit, args.load_4bit) 33 | 34 | if 'llama-2' in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ('user', 'assistant') 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 56 | 57 | while True: 58 | try: 59 | inp = input(f"{roles[0]}: ") 60 | except EOFError: 61 | inp = "" 62 | if not inp: 63 | print("exit...") 64 | break 65 | 66 | print(f"{roles[1]}: ", end="") 67 | 68 | if image is not None: 69 | # first message 70 | if model.config.mm_use_im_start_end: 71 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 72 | else: 73 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 74 | conv.append_message(conv.roles[0], inp) 75 | image = None 76 | else: 77 | # later messages 78 | conv.append_message(conv.roles[0], inp) 79 | conv.append_message(conv.roles[1], None) 80 | prompt = conv.get_prompt() 81 | 82 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 83 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 84 | keywords = [stop_str] 85 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 86 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 87 | 88 | with torch.inference_mode(): 89 | output_ids = model.generate( 90 | input_ids, 91 | images=image_tensor, 92 | do_sample=True, 93 | temperature=0.2, 94 | max_new_tokens=1024, 95 | streamer=streamer, 96 | use_cache=True, 97 | stopping_criteria=[stopping_criteria]) 98 | 99 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 100 | conv.messages[-1][-1] = outputs 101 | 102 | if args.debug: 103 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 104 | 105 | 106 | if __name__ == "__main__": 107 | parser = argparse.ArgumentParser() 108 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 109 | parser.add_argument("--model-base", type=str, default=None) 110 | parser.add_argument("--image-file", type=str, required=True) 111 | parser.add_argument("--num-gpus", type=int, default=1) 112 | parser.add_argument("--conv-mode", type=str, default=None) 113 | parser.add_argument("--temperature", type=float, default=0.2) 114 | parser.add_argument("--max-new-tokens", type=int, default=512) 115 | parser.add_argument("--load-8bit", action="store_true") 116 | parser.add_argument("--load-4bit", action="store_true") 117 | parser.add_argument("--debug", action="store_true") 118 | args = parser.parse_args() 119 | main(args) 120 | -------------------------------------------------------------------------------- /llava/serve/demo.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from __future__ import print_function 3 | 4 | import sys 5 | import argparse 6 | 7 | from os.path import basename 8 | from 
classes.Utils import * 9 | from classes.Compiler import * 10 | import os.path as osp 11 | 12 | FILL_WITH_RANDOM_TEXT = True 13 | TEXT_PLACE_HOLDER = "[]" 14 | 15 | 16 | def render_content_with_text(key, value): 17 | if FILL_WITH_RANDOM_TEXT: 18 | if key.find("btn") != -1: 19 | value = value.replace(TEXT_PLACE_HOLDER, Utils.get_random_text()) 20 | elif key.find("title") != -1: 21 | value = value.replace(TEXT_PLACE_HOLDER, Utils.get_random_text(length_text=5, space_number=0)) 22 | elif key.find("text") != -1: 23 | value = value.replace(TEXT_PLACE_HOLDER, 24 | Utils.get_random_text(length_text=56, space_number=7, with_upper_case=False)) 25 | return value 26 | 27 | def render_dsls(input_fp, args): 28 | i = 0 29 | for ln in input_fp: 30 | compiler = Compiler(args.dsl_path) 31 | j_dict = json.loads(ln) 32 | input_text = j_dict['prediction'] 33 | input_text = input_text.replace("{", "{\n").replace("}", "\n}\n").replace("\n\n", '\n').rstrip("\n") 34 | 35 | label = j_dict['label'][0] 36 | label = label.rstrip("\n") 37 | 38 | output_file_path = osp.join(args.output_folder, j_dict['fn'].split("/")[-1].replace(".gui", ".html")) 39 | pred_output_html = compiler.compile(input_text, output_file_path, rendering_function=render_content_with_text) 40 | #label_output_html = compiler.compile(label, output_file_path, rendering_function=render_content_with_text) 41 | i += 1 42 | #print(pred_output_html) 43 | 44 | if __name__=="__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--input-file", type=str, default="input.jsonl") 47 | parser.add_argument("--output-folder", type=str, default="output/") 48 | parser.add_argument("--dsl-path", type=str, default="dsl.json") 49 | args = parser.parse_args() 50 | 51 | input_fp = open(args.input_file) 52 | render_dsls(input_fp, args) -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/serve/.dockerenv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/serve/.dockerenv -------------------------------------------------------------------------------- /llava/serve/serve/.pouch.first: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/serve/.pouch.first -------------------------------------------------------------------------------- /llava/serve/serve/.pouch_runc_init: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/serve/serve/.pouch_runc_init -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | 
parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/train/__pycache__/llama_flash_attn_monkey_patch.cpython-38.pyc -------------------------------------------------------------------------------- /llava/train/__pycache__/llava_trainer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/train/__pycache__/llava_trainer.cpython-38.pyc -------------------------------------------------------------------------------- /llava/train/__pycache__/train.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/llava/train/__pycache__/train.cpython-38.pyc -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | import logging 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import transformers 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | from einops import rearrange 11 | 12 | try: 13 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 14 | except ImportError: 15 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 16 | from flash_attn.bert_padding import unpad_input, pad_input 17 | 18 | 19 | def forward( 20 | self, 21 | hidden_states: torch.Tensor, 22 | attention_mask: Optional[torch.Tensor] = None, 23 | position_ids: Optional[torch.Tensor] = None, 24 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 25 | output_attentions: bool = False, 26 | use_cache: bool = False, 27 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 28 | """Input shape: Batch x Time x Channel 29 | 30 | attention_mask: [bsz, q_len] 31 | """ 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | key_states = ( 40 | self.k_proj(hidden_states) 41 | .view(bsz, q_len, self.num_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) 44 | value_states = ( 45 | self.v_proj(hidden_states) 46 | .view(bsz, q_len, self.num_heads, self.head_dim) 47 | .transpose(1, 2) 48 | ) 49 | # [bsz, q_len, nh, hd] 50 | # [bsz, nh, q_len, hd] 51 | 52 | kv_seq_len = key_states.shape[-2] 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 56 | 
query_states, key_states = apply_rotary_pos_emb( 57 | query_states, key_states, cos, sin, position_ids 58 | ) 59 | # [bsz, nh, t, hd] 60 | assert not output_attentions, "output_attentions is not supported" 61 | assert not use_cache, "use_cache is not supported" 62 | 63 | # Flash attention codes from 64 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 65 | 66 | # transform the data into the format required by flash attention 67 | qkv = torch.stack( 68 | [query_states, key_states, value_states], dim=2 69 | ) # [bsz, nh, 3, q_len, hd] 70 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 71 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 72 | # the attention_mask should be the same as the key_padding_mask 73 | key_padding_mask = attention_mask 74 | 75 | if key_padding_mask is None: 76 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 77 | max_s = q_len 78 | cu_q_lens = torch.arange( 79 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 80 | ) 81 | output = flash_attn_unpadded_qkvpacked_func( 82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 83 | ) 84 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 85 | else: 86 | nheads = qkv.shape[-2] 87 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 88 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 89 | x_unpad = rearrange( 90 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 91 | ) 92 | output_unpad = flash_attn_unpadded_qkvpacked_func( 93 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 94 | ) 95 | output = rearrange( 96 | pad_input( 97 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 98 | ), 99 | "b s (h d) -> b s h d", 100 | h=nheads, 101 | ) 102 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 103 | 104 | 105 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 106 | # requires the attention mask to be the same as the key_padding_mask 107 | def _prepare_decoder_attention_mask( 108 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 109 | ): 110 | # [bsz, seq_len] 111 | return attention_mask 112 | 113 | 114 | def replace_llama_attn_with_flash_attn(): 115 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 116 | if cuda_major < 8: 117 | logging.warning( 118 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 119 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 120 | ) 121 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 122 | _prepare_decoder_attention_mask 123 | ) 124 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 125 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 
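# A minimal, illustrative sketch, not an assertion about the released training recipe:
# the two commented-out lines below are the intended hook for the monkey patch defined in
# llama_flash_attn_monkey_patch.py above. Guarding the call on compute capability mirrors
# the warning inside replace_llama_attn_with_flash_attn, which targets A100/H100-class
# GPUs (compute capability >= 8.0), and the patch must run before the LLaMA modules are used:
#
#   import torch
#   from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
#   if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
#       replace_llama_attn_with_flash_attn()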
6 | #from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | #replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 
96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | -------------------------------------------------------------------------------- /playground/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/playground/.DS_Store -------------------------------------------------------------------------------- /playground/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/codefuse-ai/CodeFuse-MFT-VLM/46e9a3c0049275110b96cca58ce78e6ebe05ffe6/playground/data/.DS_Store -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/000_caps.txt: -------------------------------------------------------------------------------- 1 | A man wearing multiple neck ties making a goofy face. 2 | A man in a white shirt wearing very many ties. 3 | a man with ties on poses for a picture 4 | A man wearing multiple ties on his neck. 5 | A young man smiles while wearing several ties. 6 | 7 | tie: [0.574, 0.298, 0.752, 0.704] 8 | tie: [0.464, 0.339, 0.639, 0.789] 9 | tie: [0.349, 0.363, 0.563, 0.732] 10 | tie: [0.259, 0.255, 0.668, 0.805] 11 | person: [0.019, 0.065, 0.962, 0.988] 12 | person: [0.0, 0.24, 0.214, 1.0] 13 | tie: [0.316, 0.778, 0.443, 0.867] 14 | tie: [0.386, 0.707, 0.496, 0.801] 15 | tie: [0.251, 0.354, 0.402, 0.747] 16 | tie: [0.44, 0.318, 0.63, 0.397] 17 | tie: [0.566, 0.33, 0.655, 0.543] 18 | tie: [0.25, 0.409, 0.359, 0.556] -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/000_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What is unusual about this photo? 3 | === 4 | Answer: 5 | In the photo, the man is wearing a total of ten ties around his neck. This is considered unusual because it is not common for someone to wear multiple ties at once. In a typical setting, a person would wear only one tie. The man's goofy facial expression also adds to the unconventional and humorous nature of the image. -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/001_caps.txt: -------------------------------------------------------------------------------- 1 | A group of people standing outside of a black vehicle with various luggage. 2 | Luggage surrounds a vehicle in an underground parking area 3 | People try to fit all of their luggage in an SUV. 4 | The sport utility vehicle is parked in the public garage, being packed for a trip 5 | Some people with luggage near a van that is transporting it. 
6 | 7 | person: [0.681, 0.242, 0.774, 0.694] 8 | person: [0.63, 0.222, 0.686, 0.516] 9 | person: [0.444, 0.233, 0.487, 0.34] 10 | backpack: [0.384, 0.696, 0.485, 0.914] 11 | backpack: [0.755, 0.413, 0.846, 0.692] 12 | suitcase: [0.758, 0.413, 0.845, 0.69] 13 | suitcase: [0.1, 0.497, 0.173, 0.579] 14 | bicycle: [0.282, 0.363, 0.327, 0.442] 15 | car: [0.786, 0.25, 0.848, 0.322] 16 | car: [0.783, 0.27, 0.827, 0.335] 17 | car: [0.86, 0.254, 0.891, 0.3] 18 | car: [0.261, 0.101, 0.787, 0.626] -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/001_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What challenges do these people face? 3 | === 4 | Answer: 5 | In the image, a group of people is standing outside a black SUV in a parking area, surrounded by various pieces of luggage, including suitcases and backpacks. They are facing the challenge of fitting all their luggage into the black SUV. There are multiple suitcases and backpacks to be packed, which suggests that the group has a significant amount of belongings to accommodate. They might have to strategize and arrange the luggage efficiently to ensure that everything fits properly into the vehicle. Additionally, they need to consider the comfort of the passengers and visibility while driving, so the placement of the luggage must not obstruct the driver's view or make the passengers uncomfortable during the trip. -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/002_caps.txt: -------------------------------------------------------------------------------- 1 | There is a movie theater that displays the show times above the doors. 2 | A red fire hydrant is deep in the snow. 3 | The fire hydrant is in the snow near a recently plowed sidewalk. 4 | This city has had a very hard winter with snow. 5 | A hotel for dogs in the snow in winter. 6 | 7 | fire hydrant: [0.326, 0.612, 0.426, 0.72] -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/002_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What challenges might this city face? 3 | === 4 | Answer: 5 | The city faces challenges due to the harsh winter conditions and heavy snowfall. In the image, a red fire hydrant is almost buried deep in the snow, which indicates the significant amount of snow the city has experienced. This can lead to various challenges such as difficulties in transportation, increased risk of accidents, and disruptions to daily life. For example, the recently plowed sidewalk near the fire hydrant shows that the city has to constantly clear snow from roads and sidewalks to maintain access and safety for pedestrians and vehicles. Moreover, emergency services, like firefighters, might face challenges accessing crucial equipment, such as fire hydrants, during emergencies due to the snow accumulation. This highlights the importance of effective snow management strategies and preparedness in such cities to minimize the impact of harsh winter conditions on residents and essential services. 
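The *_caps.txt and *_conv.txt files above are paired few-shot examples for GPT-assisted instruction-data generation: each caption file is a user turn (image captions, plus normalized bounding boxes for this prompt set) and the matching conversation file is the reference assistant turn, with system_message.txt supplying the task instructions. No generation script appears in this section, so the sketch below is only an illustration of how such a prompt could be assembled; the build_fewshot_messages helper and its argument are assumptions for illustration, not existing code in this repository.

import glob
import os


def build_fewshot_messages(prompt_dir):
    # The task description becomes the system turn.
    with open(os.path.join(prompt_dir, "system_message.txt")) as f:
        messages = [{"role": "system", "content": f.read()}]
    # Each NNN_caps.txt (captions + boxes) pairs with NNN_conv.txt (target Q&A).
    for caps_path in sorted(glob.glob(os.path.join(prompt_dir, "*_caps.txt"))):
        conv_path = caps_path.replace("_caps.txt", "_conv.txt")
        with open(caps_path) as f_caps, open(conv_path) as f_conv:
            messages.append({"role": "user", "content": f_caps.read()})
            messages.append({"role": "assistant", "content": f_conv.read()})
    return messages


# Few-shot context for the complex-reasoning prompt set; a new caption block would be
# appended as the final user turn before querying the teacher model.
msgs = build_fewshot_messages("playground/data/prompts/complex_reasoning")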
-------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/system_message.txt: -------------------------------------------------------------------------------- 1 | You are an AI visual assistant that can analyze a single image. You receive five sentences, each describing the same image you are observing. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. 2 | 3 | The task is to use the provided caption and bounding box information, create a plausible question about the image, and provide the answer in detail. 4 | 5 | Create complex questions beyond describing the scene. 6 | To answer such questions, one should require first understanding the visual content, then based on the background knowledge or reasoning, either explain why the things are happening that way, or provide guides and help to user's request. Make the question challenging by not including the visual content details in the question so that the user needs to reason about that first. 7 | 8 | Instead of directly mentioning the bounding box coordinates, utilize this data to explain the scene using natural language. Include details like object counts, position of the objects, relative position between the objects. 9 | 10 | When using the information from the caption and coordinates, directly explain the scene, and do not mention that the information source is the caption or the bounding box. Always answer as if you are directly looking at the image. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/000_caps.txt: -------------------------------------------------------------------------------- 1 | There is a movie theater that displays the show times above the doors. 2 | A red fire hydrant is deep in the snow. 3 | The fire hydrant is in the snow near a recently plowed sidewalk. 4 | This city has had a very hard winter with snow. 5 | A hotel for dogs in the snow in winter. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/000_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What color is the fire hydrant in the image? 3 | === 4 | Answer: 5 | The fire hydrant in the image is red. 6 | === 7 | Question: 8 | Is the fire hydrant fully covered by snow or partially visible? 9 | === 10 | Answer: 11 | The fire hydrant is partially visible, with a significant portion of it covered by the deep snow. 12 | === 13 | Question: 14 | What is the condition of the sidewalk near the fire hydrant? 15 | === 16 | Answer: 17 | The sidewalk near the fire hydrant has been recently plowed, so it is relatively clear of snow compared to the surrounding area. 18 | === 19 | Question: 20 | What does the image suggest about the winter season in the city? 21 | === 22 | Answer: 23 | The image suggests that the winter season in the city has been quite harsh and challenging. The significant snow accumulation seen in the image indicates that the city has experienced heavy snowfall, which can lead to various difficulties and disruptions for its residents. 
24 | 25 | For instance, the deep snow covering the fire hydrant could potentially hinder its accessibility in case of emergencies, which poses safety concerns. The recently plowed sidewalk implies that the city's maintenance crews have been actively working to keep the walkways clear and safe for pedestrians, but the sheer amount of snowfall might make it difficult to keep up with the continuous snow removal efforts. 26 | 27 | Furthermore, such extreme winter conditions can impact transportation, leading to traffic issues, delays in public transit, and increased risks of accidents due to slippery roads. It can also cause problems for businesses and schools, as people might face challenges commuting to work or attending classes. Additionally, the heavy snow can put extra strain on infrastructure, such as roofs and power lines, increasing the likelihood of structural damage or power outages. 28 | 29 | In conclusion, the image of the red fire hydrant deep in the snow and the recently plowed sidewalk suggest that the city has faced a particularly severe winter season, with substantial snowfall that has likely caused various challenges and disruptions for its residents and infrastructure. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/001_caps.txt: -------------------------------------------------------------------------------- 1 | A man is skiing in the open snow covered hills 2 | A skier is making his way into the snow. 3 | A skier on the foothills of a huge mountain range. 4 | A skier looks at mountains while standing near a trail sign. 5 | a single person skiing by an area with a lot of bushes -------------------------------------------------------------------------------- /playground/data/prompts/conversation/001_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What is the skier doing? 3 | === 4 | Answer: 5 | The skier is skiing in the snow-covered hills, making his way through the snow and enjoying the beautiful mountain scenery. They are also standing near a trail sign, which indicates their location on the mountain. 6 | === 7 | Question: 8 | Are there any other people or skiers visible in the image? 9 | === 10 | Answer: 11 | No, there is only a single person skiing in the image, and they are surrounded by an area with a lot of bushes. 12 | === 13 | Question: 14 | What kind of skiing is the person doing? Cross-country or downhill? 15 | === 16 | Answer: 17 | Based on what I can see in the photo, it is not possible to definitively determine whether the skier is doing cross-country or downhill skiing. However, given the presence of a trail sign and the fact that they are skiing in the foothills, it is more likely that the skier is engaged in cross-country skiing. Downhill skiing typically takes place on more steep slopes and involves ski lifts, whereas cross-country skiing often occurs on relatively flatter terrain with marked trails. 18 | === 19 | Question: 20 | Can you explain cross-country and downhill skiing a bit more? 21 | === 22 | Answer: 23 | Cross-country skiing and downhill skiing are two distinct types of skiing, each with its own techniques, equipment, and terrain preferences. 24 | 25 | 1. Cross-Country Skiing: 26 | Also known as Nordic skiing, cross-country skiing involves traversing relatively flat terrain or gentle slopes using skis and poles. 
It is a popular recreational and competitive sport that focuses on endurance, balance, and aerobic fitness. Cross-country skiers use their own locomotion to move across snow-covered terrain, rather than relying on ski lifts or other forms of assistance. 27 | There are two main techniques in cross-country skiing: 28 | - Classic: This technique involves skiing in parallel tracks, using a forward-and-backward gliding motion while keeping the skis parallel to each other. The skier uses their poles to propel themselves forward. 29 | - Skate skiing: This technique is similar to ice skating, where the skier pushes off one ski at an angle and glides on the other. It requires a smooth, groomed surface and is typically faster than the classic technique. 30 | Cross-country ski equipment consists of lightweight skis, boots, bindings, and poles. The boots are more flexible than those used in downhill skiing, allowing for better ankle movement and control. 31 | 32 | 2. Downhill Skiing: 33 | Also known as alpine skiing, downhill skiing involves descending slopes at high speeds using skis and poles for balance and control. This sport is more focused on speed, technique, and navigating challenging terrain, including steep slopes, moguls, and even jumps. 34 | Downhill skiing can be further categorized into several disciplines, such as slalom, giant slalom, super-G, and downhill racing. Each discipline has its own set of rules, courses, and ski equipment. 35 | Downhill ski equipment includes heavier and stiffer skis, boots, bindings, and poles than those used in cross-country skiing. The boots are more rigid to provide better support and control during high-speed descents and sharp turns. 36 | 37 | In summary, cross-country skiing is an endurance-based sport that involves traveling across flat or gently sloping terrain, while downhill skiing is focused on speed and technique as skiers navigate steeper slopes and challenging terrain. Both sports require specialized equipment and techniques, but they offer different experiences and challenges to participants. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/system_message.txt: -------------------------------------------------------------------------------- 1 | You are an AI visual assistant, and you are seeing a single image. What you see are provided with five sentences, describing the same image you are looking at. Answer all questions as you are seeing the image. 2 | 3 | Design a conversation between you and a person asking about this photo. The answers should be in a tone that a visual AI assistant is seeing the image and answering the question. 4 | Ask diverse questions and give corresponding answers. 5 | 6 | Include questions asking about the visual content of the image, including the object types, counting the objects, object actions, object locations, relative positions between objects, etc. Only include questions that have definite answers: 7 | (1) one can see the content in the image that the question asks about and can answer confidently; 8 | (2) one can determine confidently from the image that it is not in the image. 9 | Do not ask any question that cannot be answered confidently. 10 | 11 | Also include complex questions that are relevant to the content in the image, for example, asking about background knowledge of the objects in the image, asking to discuss about events happening in the image, etc. Again, do not ask about uncertain details. 
12 | Provide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. You can include multiple paragraphs if necessary. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/000_caps.txt: -------------------------------------------------------------------------------- 1 | A harbor filled with lots of boats next to a building. 2 | A bicycle parked in front of several boats at a dock. 3 | A red bicycle in front of a line of docked white yachts 4 | A bike sits before boats which sit before a long building. 5 | A bicycle is a convenient means of land transportation when you live on a boat. 6 | 7 | bicycle: [0.287, 0.641, 0.507, 0.874] 8 | bicycle: [0.566, 0.667, 0.63, 0.731] 9 | boat: [0.318, 0.579, 0.575, 0.724] 10 | boat: [0.704, 0.607, 0.818, 0.727] 11 | boat: [0.818, 0.601, 0.942, 0.744] 12 | boat: [0.002, 0.53, 0.243, 0.71] 13 | boat: [0.541, 0.611, 0.668, 0.731] 14 | person: [0.778, 0.527, 0.797, 0.57] 15 | cup: [0.708, 0.733, 0.724, 0.758] 16 | boat: [0.236, 0.532, 0.404, 0.64] 17 | boat: [0.81, 0.632, 0.836, 0.676] 18 | boat: [0.957, 0.526, 1.0, 0.752] -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/000_conv.txt: -------------------------------------------------------------------------------- 1 | It is a harbor filled with numerous boats of various sizes docked next to a long building. Among the boats, there are a few white yachts lined up, standing out from the rest. There is a red bicycle prominently parked in front of the line of docked boats, serving as a convenient means of land transportation for those living on the boats. Another bicycle can be seen further back in the scene, near the middle of the harbor. 2 | 3 | A person is visible near the right side of the harbor, possibly enjoying the view or attending to their boat. Additionally, there is a cup placed on a surface near the middle of the scene. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/001_caps.txt: -------------------------------------------------------------------------------- 1 | A group of people standing outside of a black vehicle with various luggage. 2 | Luggage surrounds a vehicle in an underground parking area 3 | People try to fit all of their luggage in an SUV. 4 | The sport utility vehicle is parked in the public garage, being packed for a trip 5 | Some people with luggage near a van that is transporting it. 6 | 7 | person: [0.681, 0.242, 0.774, 0.694] 8 | person: [0.63, 0.222, 0.686, 0.516] 9 | person: [0.444, 0.233, 0.487, 0.34] 10 | backpack: [0.384, 0.696, 0.485, 0.914] 11 | backpack: [0.755, 0.413, 0.846, 0.692] 12 | suitcase: [0.758, 0.413, 0.845, 0.69] 13 | suitcase: [0.1, 0.497, 0.173, 0.579] 14 | bicycle: [0.282, 0.363, 0.327, 0.442] 15 | car: [0.786, 0.25, 0.848, 0.322] 16 | car: [0.783, 0.27, 0.827, 0.335] 17 | car: [0.86, 0.254, 0.891, 0.3] 18 | car: [0.261, 0.101, 0.787, 0.626] -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/001_conv.txt: -------------------------------------------------------------------------------- 1 | The image is an underground parking area with a black sport utility vehicle (SUV) parked. 
There are three people in the scene, with one person standing closer to the left side of the vehicle, another person in the middle, and the third person on the right side. They are all working together to pack their luggage into the SUV for a trip. 2 | 3 | In the parking area, various luggage items are scattered around the vehicle. There are two backpacks, one located near the left rear wheel and the other closer to the right side of the vehicle. Additionally, there are two suitcases, one on the right side of the car and another further away near the center of the parking area. A bicycle can also be seen on the left side of the vehicle. 4 | 5 | Other cars are parked around the main SUV, with one car positioned behind it and slightly to the left, another behind and slightly to the right, and the third car further behind on the right side. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/002_caps.txt: -------------------------------------------------------------------------------- 1 | A man holds a Wii-mote above his head while another looks on. 2 | A guy and his friend are playing Nintendo Wii. 3 | A young man is holding a video game remote over his head. 4 | two men standing in a room while one plays with a wii mote 5 | Some guys standing and playing a video game. 6 | 7 | couch: [0.697, 0.759, 0.995, 1.0] 8 | dining table: [0.426, 0.755, 1.0, 0.987] 9 | person: [0.082, 0.252, 0.342, 1.0] 10 | person: [0.399, 0.085, 0.742, 0.982] 11 | remote: [0.477, 0.135, 0.516, 0.187] 12 | sink: [0.016, 0.501, 0.063, 0.52] 13 | potted plant: [0.798, 0.384, 0.888, 0.645] 14 | refrigerator: [0.305, 0.389, 0.414, 0.547] 15 | chair: [0.72, 0.509, 0.858, 0.725] -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/002_conv.txt: -------------------------------------------------------------------------------- 1 | The image shows two men standing in a room, engaged in playing a video game on a Nintendo Wii console. One of the men is holding a Wii remote above his head with enthusiasm, while the other man looks on, likely enjoying the friendly competition. 2 | 3 | The room appears to be a living space with a couch located in the background and a dining table nearby. A potted plant can be seen placed close to the couch, and a chair is situated in the middle of the room. The room also features a kitchen area with a sink and a refrigerator visible in the background. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/system_message.txt: -------------------------------------------------------------------------------- 1 | You are an AI visual assistant that can analyze a single image. You receive five sentences, each describing the same image you are observing. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. 2 | 3 | Using the provided caption and bounding box information, describe the scene in a detailed manner. 4 | 5 | Instead of directly mentioning the bounding box coordinates, utilize this data to explain the scene using natural language. 
Include details like object counts, position of the objects, relative position between the objects. 6 | 7 | When using the information from the caption and coordinates, directly explain the scene, and do not mention that the information source is the caption or the bounding box. Always answer as if you are directly looking at the image. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.0.1" 8 | description = "Towards GPT-4 like large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy", 17 | "requests", "sentencepiece", "tokenizers>=0.12.1", 18 | "torch", "torchvision", "uvicorn", "wandb", 19 | "shortuuid", "httpx==0.24.0", 20 | "deepspeed==0.9.5", 21 | "peft==0.4.0", 22 | "transformers==4.31.0", 23 | "accelerate==0.21.0", 24 | "bitsandbytes==0.41.0", 25 | "scikit-learn==1.2.2", 26 | "sentencepiece==0.1.99", 27 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 28 | "gradio_client==0.2.9" 29 | ] 30 | 31 | [project.urls] 32 | "Homepage" = "https://llava-vl.github.io" 33 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" 34 | 35 | [tool.setuptools.packages.find] 36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 37 | 38 | [tool.wheel] 39 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 40 | -------------------------------------------------------------------------------- /scripts/acc_ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 3, 34 | "offload_optimizer": { 35 | "device": "cpu", 36 | "pin_memory": true 37 | }, 38 | "offload_param": { 39 | "device": "cpu", 40 | "pin_memory": true 41 | }, 42 | "overlap_comm": true, 43 | "contiguous_gradients": true, 44 | "sub_group_size": 1e9, 45 | "reduce_bucket_size": "auto", 46 | "stage3_prefetch_bucket_size": "auto", 47 | "stage3_param_persistence_threshold": "auto", 48 | "stage3_max_live_parameters": 1e9, 49 | "stage3_max_reuse_distance": 1e9, 50 | "stage3_gather_16bit_weights_on_model_save": true 51 | }, 52 | 53 | "gradient_accumulation_steps": 1, 54 | "gradient_clipping": "auto", 55 | "steps_per_print": 100, 56 | "train_batch_size": "auto", 57 | "train_micro_batch_size_per_gpu": "auto", 58 | "wall_clock_breakdown": false 59 | } -------------------------------------------------------------------------------- /scripts/convert_sqa_to_llava.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import fire 4 | import re 5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot 6 | 7 | 8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): 9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 10 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 11 | 12 | split_problems = build_prompt_chatbot( 13 | problems, split_indices, prompt_format, 14 | use_caption=False, is_test=False) 15 | 16 | target_format = [] 17 | for prob_id, (input, output) in split_problems.items(): 18 | if input.startswith('Question: '): 19 | input = input.replace('Question: ', '') 20 | if output.startswith('Answer: '): 21 | output = output.replace('Answer: ', '') 22 | 23 | raw_prob_data = problems[prob_id] 24 | if raw_prob_data['image'] is None: 25 | target_format.append({ 26 | "id": prob_id, 27 | "conversations": [ 28 | {'from': 'human', 'value': f"{input}"}, 29 | {'from': 'gpt', 'value': f"{output}"}, 30 | ], 31 | }) 32 | 33 | else: 34 | target_format.append({ 35 | "id": prob_id, 36 | "image": os.path.join(prob_id, raw_prob_data['image']), 37 | "conversations": [ 38 | {'from': 'human', 'value': f"{input}\n"}, 39 | {'from': 'gpt', 'value': f"{output}"}, 40 | ], 41 | }) 42 | 43 | print(f'Number of samples: {len(target_format)}') 44 | 45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: 46 | json.dump(target_format, f, indent=2) 47 | 48 | 49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): 50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 51 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 52 | 53 | split_problems = build_prompt_chatbot( 54 | problems, split_indices, prompt_format, 55 | use_caption=False, is_test=False) 56 | 57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") 58 | for prob_id, (input, output) in split_problems.items(): 59 | if input.startswith('Question: '): 60 | input = input.replace('Question: ', '') 61 | if output.startswith('Answer: '): 62 | output = output.replace('Answer: ', '') 63 | 64 | raw_prob_data = problems[prob_id] 65 | if raw_prob_data['image'] is None: 66 | data = { 67 | "id": prob_id, 68 | "instruction": f"{input}", 69 | "output": f"{output}", 70 | } 71 | 72 | else: 73 | data = { 74 | "id": prob_id, 75 | "image": os.path.join(prob_id, raw_prob_data['image']), 76 | "instruction": f"{input}\n", 77 | "output": f"{output}", 78 | } 79 | writer.write(json.dumps(data) + '\n') 80 | writer.close() 81 | 82 | 83 | def main(task, **kwargs): 84 | globals()[task](**kwargs) 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(main) 89 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | 
--model_name_or_path ./checkpoints/$MODEL_VERSION \ 18 | --version $PROMPT_VERSION \ 19 | --data_path ./playground/data/llava_instruct_80k.json \ 20 | --image_folder /path/to/coco/train2017 \ 21 | --vision_tower openai/clip-vit-large-patch14 \ 22 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 23 | --mm_vision_select_layer -2 \ 24 | --mm_use_im_start_end False \ 25 | --mm_use_im_patch_token False \ 26 | --bf16 True \ 27 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \ 28 | --num_train_epochs 1 \ 29 | --per_device_train_batch_size 16 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gradient_accumulation_steps 1 \ 32 | --evaluation_strategy "no" \ 33 | --save_strategy "steps" \ 34 | --save_steps 50000 \ 35 | --save_total_limit 1 \ 36 | --learning_rate 2e-5 \ 37 | --weight_decay 0. \ 38 | --warmup_ratio 0.03 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 1 \ 41 | --tf32 True \ 42 | --model_max_length 2048 \ 43 | --gradient_checkpointing True \ 44 | --dataloader_num_workers 4 \ 45 | --lazy_preprocess True \ 46 | --report_to wandb 47 | -------------------------------------------------------------------------------- /scripts/finetune_full_schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 18 | --version $PROMPT_VERSION \ 19 | --data_path ./playground/data/llava_instruct_158k.json \ 20 | --image_folder /path/to/coco/train2017 \ 21 | --vision_tower openai/clip-vit-large-patch14 \ 22 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 23 | --mm_vision_select_layer -2 \ 24 | --mm_use_im_start_end False \ 25 | --mm_use_im_patch_token False \ 26 | --bf16 True \ 27 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \ 28 | --num_train_epochs 3 \ 29 | --per_device_train_batch_size 16 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gradient_accumulation_steps 1 \ 32 | --evaluation_strategy "no" \ 33 | --save_strategy "steps" \ 34 | --save_steps 50000 \ 35 | --save_total_limit 1 \ 36 | --learning_rate 2e-5 \ 37 | --weight_decay 0. 
\ 38 | --warmup_ratio 0.03 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 1 \ 41 | --tf32 True \ 42 | --model_max_length 2048 \ 43 | --gradient_checkpointing True \ 44 | --dataloader_num_workers 4 \ 45 | --lazy_preprocess True \ 46 | --report_to wandb 47 | -------------------------------------------------------------------------------- /scripts/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | --lora_enable True \ 18 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 19 | --version $PROMPT_VERSION \ 20 | --data_path ./playground/data/llava_instruct_80k.json \ 21 | --image_folder /path/to/coco/train2017 \ 22 | --vision_tower openai/clip-vit-large-patch14 \ 23 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 24 | --mm_vision_select_layer -2 \ 25 | --mm_use_im_start_end False \ 26 | --mm_use_im_patch_token False \ 27 | --bf16 True \ 28 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ 29 | --num_train_epochs 1 \ 30 | --per_device_train_batch_size 16 \ 31 | --per_device_eval_batch_size 4 \ 32 | --gradient_accumulation_steps 1 \ 33 | --evaluation_strategy "no" \ 34 | --save_strategy "steps" \ 35 | --save_steps 50000 \ 36 | --save_total_limit 1 \ 37 | --learning_rate 2e-5 \ 38 | --weight_decay 0. 
\ 39 | --warmup_ratio 0.03 \ 40 | --lr_scheduler_type "cosine" \ 41 | --logging_steps 1 \ 42 | --tf32 True \ 43 | --model_max_length 2048 \ 44 | --gradient_checkpointing True \ 45 | --lazy_preprocess True \ 46 | --dataloader_num_workers 4 \ 47 | --report_to wandb 48 | -------------------------------------------------------------------------------- /scripts/finetune_multinode.sh: -------------------------------------------------------------------------------- 1 | accelerate launch \ 2 | --num_machines 4 \ 3 | --num_processes 32 \ 4 | --use_deepspeed \ 5 | --deepspeed_multinode_launcher 'standard' \ 6 | --zero_stage 2 \ 7 | --offload_optimizer_device 'cpu' \ 8 | --offload_param_device 'none' \ 9 | --gradient_accumulation_steps 1 \ 10 | --gradient_clipping 1.0 \ 11 | --zero3_init_flag false \ 12 | --zero3_save_16bit_model false \ 13 | --main_training_function 'main' \ 14 | --mixed_precision 'bf16' \ 15 | --dynamo_backend 'no' \ 16 | --same_network \ 17 | --machine_rank $RANK \ 18 | --main_process_ip $MASTER_ADDR \ 19 | --main_process_port $MASTER_PORT \ 20 | --rdzv_backend 'static' \ 21 | llava/train/train_mem.py \ 22 | --data_path /your/train/data/path \ 23 | --eval_path /your/eval/data/path \ 24 | --is_parquet True \ 25 | --output_dir /path/to/save/weights \ 26 | --image_folder /path/to/store/images/ \ 27 | --image_aspect_ratio None \ 28 | --lora_enable True \ 29 | --lora_r 16 \ 30 | --lora_alpha 64 \ 31 | --model_name_or_path /path/to/model/ \ 32 | --version qwen \ 33 | --vision_tower /path/to/visiontower/ \ 34 | --mm_projector_type cross_attn \ 35 | --pretrain_mm_mlp_adapter /path/to/mmprojector/mm_projector.bin \ 36 | --mm_vision_select_layer -2 \ 37 | --mm_use_im_start_end False \ 38 | --use_im_start_end False \ 39 | --mm_use_im_patch_token False \ 40 | --bf16 True \ 41 | --num_train_epochs 5 \ 42 | --per_device_train_batch_size 4 \ 43 | --per_device_eval_batch_size 2 \ 44 | --gradient_accumulation_steps 1 \ 45 | --evaluation_strategy "steps" \ 46 | --eval_steps 1000 \ 47 | --load_best_model_at_end True \ 48 | --save_strategy "steps" \ 49 | --save_steps 1000 \ 50 | --save_total_limit 10 \ 51 | --freeze_backbone True \ 52 | --learning_rate 2e-5 \ 53 | --weight_decay 0. 
\ 54 | --warmup_ratio 0.03 \ 55 | --lr_scheduler_type "cosine" \ 56 | --logging_steps 1 \ 57 | --tf32 True \ 58 | --model_max_length 2048 \ 59 | --gradient_checkpointing True \ 60 | --lazy_preprocess True \ 61 | --dataloader_num_workers 4 \ 62 | --report_to none -------------------------------------------------------------------------------- /scripts/finetune_qlora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | ################## VICUNA ################## 6 | # PROMPT_VERSION=v1 7 | # MODEL_VERSION="vicuna-v1-3-7b" 8 | ################## VICUNA ################## 9 | 10 | ################## LLaMA-2 ################## 11 | # PROMPT_VERSION="llava_llama_2" 12 | # MODEL_VERSION="llama-2-7b-chat" 13 | ################## LLaMA-2 ################## 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | --lora_enable True \ 18 | --bits 4 \ 19 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 20 | --version $PROMPT_VERSION \ 21 | --data_path ./playground/data/llava_instruct_80k.json \ 22 | --image_folder /path/to/coco/train2017 \ 23 | --vision_tower openai/clip-vit-large-patch14 \ 24 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --bf16 True \ 29 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ 30 | --num_train_epochs 1 \ 31 | --per_device_train_batch_size 16 \ 32 | --per_device_eval_batch_size 4 \ 33 | --gradient_accumulation_steps 1 \ 34 | --evaluation_strategy "no" \ 35 | --save_strategy "steps" \ 36 | --save_steps 50000 \ 37 | --save_total_limit 1 \ 38 | --learning_rate 2e-5 \ 39 | --weight_decay 0. 
\ 40 | --warmup_ratio 0.03 \ 41 | --lr_scheduler_type "cosine" \ 42 | --logging_steps 1 \ 43 | --tf32 True \ 44 | --model_max_length 2048 \ 45 | --gradient_checkpointing True \ 46 | --lazy_preprocess True \ 47 | --dataloader_num_workers 4 \ 48 | --report_to wandb 49 | -------------------------------------------------------------------------------- /scripts/finetune_yuque_qwen.sh: -------------------------------------------------------------------------------- 1 | deepspeed llava/train/train_mem.py \ 2 | --deepspeed ./scripts/new_ds_config.json \ 3 | --data_path /mnt/project/LLAVA/image2schema_data/sft_data/sharegpt4v_data/ \ 4 | --eval_path /mnt/project/LLAVA/image2schema_data/dataset_img2ir_v2/dataset/converted/yuque_5w_zh/datasets/ir.json \ 5 | --is_parquet True \ 6 | --output_dir /mnt/project/LLAVA/visual_tuning_weights/qwen-vl-14b-lora-sft-1205/ \ 7 | --image_folder /mnt/project/LLAVA/image2schema_data/dataset_img2ir_v2/dataset/converted/yuque_5w_zh/train/ \ 8 | --image_aspect_ratio None \ 9 | --lora_enable True \ 10 | --lora_r 16 \ 11 | --lora_alpha 64 \ 12 | --model_name_or_path /mnt/user/laiyan/salesgpt/model/Qwen-14B-Chat-VL \ 13 | --version qwen \ 14 | --vision_tower /mnt/project/LLAVA/Qwen-VL-visual/ \ 15 | --mm_projector_type cross_attn \ 16 | --pretrain_mm_mlp_adapter /mnt/project/LLAVA/pretrain_weights/qwen-vl-14b-cross-attn-1126-multigpu/checkpoint-9000/mm_projector.bin \ 17 | --mm_vision_select_layer -2 \ 18 | --mm_use_im_start_end False \ 19 | --use_im_start_end False \ 20 | --mm_use_im_patch_token False \ 21 | --bf16 True \ 22 | --num_train_epochs 5 \ 23 | --per_device_train_batch_size 4 \ 24 | --per_device_eval_batch_size 2 \ 25 | --gradient_accumulation_steps 1 \ 26 | --evaluation_strategy "steps" \ 27 | --eval_steps 1000 \ 28 | --load_best_model_at_end True \ 29 | --save_strategy "steps" \ 30 | --save_steps 1000 \ 31 | --save_total_limit 10 \ 32 | --freeze_backbone True \ 33 | --learning_rate 2e-5 \ 34 | --weight_decay 0. 
\ 35 | --warmup_ratio 0.03 \ 36 | --lr_scheduler_type "cosine" \ 37 | --logging_steps 1 \ 38 | --tf32 True \ 39 | --model_max_length 2048 \ 40 | --gradient_checkpointing True \ 41 | --lazy_preprocess True \ 42 | --dataloader_num_workers 4 \ 43 | --report_to none -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava.model.builder import load_pretrained_model, load_mixed_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | #tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | tokenizer, model, image_processor, context_len = load_mixed_pretrained_model(args.model_path, args.model_base, model_name, args.vision_tower_path, args.mm_projector_type, args.mm_projector_path ,device_map='cpu') 10 | 11 | model.save_pretrained(args.save_model_path) 12 | tokenizer.save_pretrained(args.save_model_path) 13 | 14 | 15 | if __name__ == "__main__": 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--model-path", type=str, required=True) 18 | parser.add_argument("--model-base", type=str, required=True) 19 | parser.add_argument("--save-model-path", type=str, required=True) 20 | parser.add_argument("--mm-projector-type", type=str, required=True) 21 | parser.add_argument("--mm-projector-path", type=str, required=True) 22 | parser.add_argument("--vision-tower-path", type=str, required=True) 23 | 24 | args = parser.parse_args() 25 | 26 | merge_lora(args) 27 | -------------------------------------------------------------------------------- /scripts/merge_qwen_vl_weights.sh: -------------------------------------------------------------------------------- 1 | python llava/merge_pretrain_weights_to_qwenvl.py \ 2 | --LLM-path $PRETRAINED_MODEL_PATH \ 3 | --mm-projector-type cross_attn \ 4 | --mm-projector $PRETRAINED_MODEL_PATH/mm_projector/mm_projector.bin \ 5 | --vision-tower $PRETRAINED_MODEL_PATH/Qwen-VL-visual/ \ 6 | --output-path /path/to/save/model 7 | -------------------------------------------------------------------------------- /scripts/new_ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "train_batch_size": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "steps_per_print": 1, 6 | "gradient_clipping": 1.0, 7 | "zero_optimization": { 8 | "stage": 2, 9 | "offload_optimizer": { 10 | "device": "cpu", 11 | "pin_memory": true 12 | } 13 | }, 14 | "zero_allow_untested_optimizer": true, 15 | "bf16": { 16 | "enabled": true, 17 | "loss_scale": 0, 18 | "loss_scale_window": 1000, 19 | "hysteresis": 2, 20 | "min_loss_scale": 1 21 | }, 22 | "optimizer": { 23 | "type": "AdamW", 24 | "params": { 25 | "lr": "auto", 26 | "betas": "auto", 27 | "eps": "auto", 28 | "weight_decay": "auto" 29 | } 30 | }, 31 | "scheduler": { 32 | "type": "WarmupLR", 33 | "params": { 34 | "warmup_min_lr": "auto", 35 | "warmup_max_lr": "auto", 36 | "warmup_num_steps": "auto" 37 | } 38 | }, 39 | "activation_checkpointing": { 40 | "partition_activations": true, 41 | "contiguous_memory_optimization": true 42 | }, 43 | "wall_clock_breakdown": false 44 | } -------------------------------------------------------------------------------- 
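For reference, the merge_lora_weights.py script above takes a LoRA checkpoint and its base model, loads them through load_mixed_pretrained_model together with the vision tower and multimodal projector, and writes the merged model and tokenizer to --save-model-path. The invocation below is only a sketch: every path is a placeholder, and cross_attn is assumed as the projector type solely because the fine-tuning scripts in this repository use it.

python scripts/merge_lora_weights.py \
    --model-path /path/to/lora/checkpoint \
    --model-base /path/to/base/llm \
    --save-model-path /path/to/save/merged/model \
    --mm-projector-type cross_attn \
    --mm-projector-path /path/to/mm_projector/mm_projector.bin \
    --vision-tower-path /path/to/vision/tower/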
/scripts/new_ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | 23 | "scheduler": { 24 | "type": "WarmupLR", 25 | "params": { 26 | "warmup_min_lr": "auto", 27 | "warmup_max_lr": "auto", 28 | "warmup_num_steps": "auto" 29 | } 30 | }, 31 | 32 | "zero_optimization": { 33 | "stage": 3, 34 | "offload_optimizer": { 35 | "device": "none", 36 | "pin_memory": true 37 | }, 38 | "offload_param": { 39 | "device": "none", 40 | "pin_memory": true 41 | }, 42 | "overlap_comm": true, 43 | "contiguous_gradients": true, 44 | "sub_group_size": 1e9, 45 | "reduce_bucket_size": "auto", 46 | "stage3_prefetch_bucket_size": "auto", 47 | "stage3_param_persistence_threshold": "auto", 48 | "stage3_max_live_parameters": 1e9, 49 | "stage3_max_reuse_distance": 1e9, 50 | "stage3_gather_16bit_weights_on_model_save": true 51 | }, 52 | 53 | "gradient_accumulation_steps": "auto", 54 | "gradient_clipping": "auto", 55 | "steps_per_print": 100, 56 | "train_batch_size": "auto", 57 | "train_micro_batch_size_per_gpu": "auto", 58 | "wall_clock_breakdown": false 59 | } -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | # MODEL_VERSION=vicuna-v1-3-7b 6 | # MODEL_VERSION=llama-2-7b-chat 7 | 8 | ########### DO NOT CHANGE ########### 9 | ########### USE THIS FOR BOTH ########### 10 | PROMPT_VERSION=plain 11 | ########### DO NOT CHANGE ########### 12 | 13 | deepspeed llava/train/train_mem.py \ 14 | --deepspeed ./scripts/zero2.json \ 15 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 16 | --version $PROMPT_VERSION \ 17 | --data_path /path/to/data/ \ 18 | --image_folder /path/to/images \ 19 | --vision_tower openai/clip-vit-large-patch14 \ 20 | --tune_mm_mlp_adapter True \ 21 | --mm_vision_select_layer -2 \ 22 | --mm_use_im_start_end False \ 23 | --mm_use_im_patch_token False \ 24 | --bf16 True \ 25 | --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \ 26 | --num_train_epochs 1 \ 27 | --per_device_train_batch_size 16 \ 28 | --per_device_eval_batch_size 4 \ 29 | --gradient_accumulation_steps 1 \ 30 | --evaluation_strategy "no" \ 31 | --save_strategy "steps" \ 32 | --save_steps 24000 \ 33 | --save_total_limit 1 \ 34 | --learning_rate 2e-3 \ 35 | --weight_decay 0. 
\ 36 | --warmup_ratio 0.03 \ 37 | --lr_scheduler_type "cosine" \ 38 | --logging_steps 1 \ 39 | --tf32 True \ 40 | --model_max_length 2048 \ 41 | --gradient_checkpointing True \ 42 | --dataloader_num_workers 4 \ 43 | --lazy_preprocess True \ 44 | --report_to wandb 45 | -------------------------------------------------------------------------------- /scripts/pretrain_llava_qwen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | # MODEL_VERSION=vicuna-v1-3-7b 6 | #MODEL_VERSION=llama-2-7b-chat 7 | 8 | ########### DO NOT CHANGE ########### 9 | ########### USE THIS FOR BOTH ########### 10 | PROMPT_VERSION=plain 11 | ########### DO NOT CHANGE ########### 12 | 13 | torchrun --nproc_per_node 8 --nnodes 1 --node_rank 0 llava/train/train_mem.py \ 14 | --model_name_or_path /mnt/user/laiyan/salesgpt/model/Qwen-14B-Chat \ 15 | --deepspeed ./scripts/new_ds_config.json \ 16 | --version $PROMPT_VERSION \ 17 | --data_path /mnt/project/LLAVA/image2schema_data/dataset_img2ir_v2/dataset/converted/yuque_5w_zh/datasets/ \ 18 | --eval_path /mnt/project/LLAVA/image2schema_data/dataset_img2ir_v2/dataset/converted/yuque_5w_zh/datasets/ir.json \ 19 | --image_folder /mnt/project/LLAVA/image2schema_data/dataset_img2ir_v2/dataset/converted/yuque_5w_zh/train/ \ 20 | --vision_tower /mnt/project/LLAVA/cn_clip_336/ \ 21 | --mm_projector_type mlp3x_gelu \ 22 | --tune_mm_mlp_adapter True \ 23 | --mm_vision_select_layer -2 \ 24 | --mm_use_im_start_end False \ 25 | --use_im_start_end False \ 26 | --mm_use_im_patch_token False \ 27 | --bf16 True \ 28 | --output_dir /mnt/project/LLAVA/pretrain_weights/llava_qwen-14b-multigpu-10xlr/ \ 29 | --num_train_epochs 5 \ 30 | --per_device_train_batch_size 2 \ 31 | --per_device_eval_batch_size 2 \ 32 | --gradient_accumulation_steps 1 \ 33 | --evaluation_strategy "steps" \ 34 | --eval_steps 1000 \ 35 | --load_best_model_at_end True \ 36 | --save_strategy "steps" \ 37 | --save_steps 1000 \ 38 | --save_total_limit 1 \ 39 | --learning_rate 2e-2 \ 40 | --weight_decay 0. 
\ 41 | --warmup_ratio 0.03 \ 42 | --lr_scheduler_type "cosine" \ 43 | --logging_steps 1 \ 44 | --model_max_length 2048 \ 45 | --gradient_checkpointing True \ 46 | --dataloader_num_workers 4 \ 47 | --lazy_preprocess True \ 48 | --report_to none 49 | -------------------------------------------------------------------------------- /scripts/pretrain_multinode.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | # MODEL_VERSION=vicuna-v1-3-7b 6 | #MODEL_VERSION=llama-2-7b-chat 7 | 8 | ########### DO NOT CHANGE ########### 9 | ########### USE THIS FOR BOTH ########### 10 | PROMPT_VERSION=qwen 11 | ########### DO NOT CHANGE ########### 12 | 13 | accelerate launch \ 14 | --num_machines 8 \ 15 | --num_processes 64 \ 16 | --use_deepspeed \ 17 | --deepspeed_multinode_launcher 'standard' \ 18 | --zero_stage 2 \ 19 | --offload_optimizer_device 'cpu' \ 20 | --offload_param_device 'none' \ 21 | --gradient_accumulation_steps 1 \ 22 | --gradient_clipping 1.0 \ 23 | --zero3_init_flag false \ 24 | --zero3_save_16bit_model false \ 25 | --main_training_function 'main' \ 26 | --mixed_precision 'bf16' \ 27 | --dynamo_backend 'no' \ 28 | --same_network \ 29 | --machine_rank $RANK \ 30 | --main_process_ip $MASTER_ADDR \ 31 | --main_process_port $MASTER_PORT \ 32 | --rdzv_backend 'static' \ 33 | llava/train/train_mem.py \ 34 | --model_name_or_path /path/to/model/ \ 35 | --version $PROMPT_VERSION \ 36 | --data_path /path/to/train/data/ \ 37 | --eval_path /path/to/eval/data/ \ 38 | --is_parquet True \ 39 | --image_folder /path/to/images/ \ 40 | --image_aspect_ratio None \ 41 | --vision_tower /path/to/visiontower/ \ 42 | --mm_projector_type cross_attn \ 43 | --tune_mm_mlp_adapter True \ 44 | --mm_vision_select_layer -2 \ 45 | --mm_use_im_start_end False \ 46 | --use_im_start_end False \ 47 | --mm_use_im_patch_token False \ 48 | --bf16 True \ 49 | --output_dir /path/to/save/weights/ \ 50 | --num_train_epochs 10 \ 51 | --per_device_train_batch_size 8 \ 52 | --per_device_eval_batch_size 2 \ 53 | --gradient_accumulation_steps 1 \ 54 | --evaluation_strategy "steps" \ 55 | --eval_steps 1000 \ 56 | --load_best_model_at_end True \ 57 | --save_strategy "steps" \ 58 | --save_steps 1000 \ 59 | --save_total_limit 20 \ 60 | --learning_rate 1e-4 \ 61 | --weight_decay 0. 
\ 62 | --adam_beta2 0.95 \ 63 | --warmup_ratio 0.03 \ 64 | --lr_scheduler_type "cosine" \ 65 | --logging_steps 1 \ 66 | --model_max_length 2048 \ 67 | --gradient_checkpointing True \ 68 | --dataloader_num_workers 4 \ 69 | --lazy_preprocess True \ 70 | --report_to none 71 | -------------------------------------------------------------------------------- /scripts/render_sketch2code.sh: -------------------------------------------------------------------------------- 1 | python llava/serve/demo.py --input-file /mnt/project/LLAVA/image2schema_data/sft_data/sketch2code/results/sketch2code_ans_v1.jsonl \ 2 | --output-folder /mnt/project/LLAVA/image2schema_data/sft_data/sketch2code/render_results/ \ 3 | --dsl-path /ossfs/node_32682640/workspace/MFTCoderVL/MFTCoderVL/llava/serve/assets/web-dsl-mapping.json -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | 
"stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } --------------------------------------------------------------------------------