├── .devcontainer ├── Dockerfile ├── devcontainer.env ├── devcontainer.json └── postCreateCommand.sh ├── .dockerignore ├── .editorconfig ├── .gitattributes ├── .github └── ISSUE_TEMPLATE │ ├── 1-usage.yaml │ ├── 2-feature-request.yaml │ ├── 3-question.yaml │ └── 4-discussion.yaml ├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── docs ├── Customize_Component.md ├── Data.md ├── Evaluation.md ├── Evaluation_image.md ├── Finetune_Custom_Data.md ├── Intel.md ├── MODEL_ZOO.md ├── Windows.md ├── macOS.md └── study_llm_backbone.md ├── images ├── all-model-compare.png ├── demo_cli.gif ├── llava-compare.png ├── llava_example_cmp.png └── vip-llava_arch.png ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_pope.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── eval_textvqa.py │ ├── generate_webpage_data_from_table.py │ ├── m4c_evaluator.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_loader.py │ ├── model_vqa_loader_vip.py │ ├── model_vqa_mmbench.py │ ├── model_vqa_qbench.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── configuration_phi3.py │ │ ├── llava_llama.py │ │ ├── llava_mpt.py │ │ ├── llava_phi3.py │ │ └── modeling_phi3.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_4layer_encoder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── cli_vip.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ └── train_xformers.py ├── utils.py ├── visual_prompt_generator.py └── visual_prompt_organizer.py ├── playground └── data │ ├── eval │ └── vip-bench-example-results │ │ └── vip-llava-7b-human.json │ └── prompts │ ├── refcocog.text │ └── vg.text ├── predict.py ├── pyproject.toml └── scripts ├── convert_vipbench_for_eval.py ├── eval ├── pointQA.sh ├── v7w.sh ├── vcr_qa.sh ├── vcr_qar.sh ├── vip-bench_evaluator.py └── vipbench.sh ├── finetune_llava_1_5_llama3.sh ├── finetune_llava_1_5_phi3.sh ├── finetune_stage2.sh ├── finetune_stage2_lora.sh ├── finetune_stage3.sh ├── finetune_task.sh ├── 
finetune_task_lora.sh ├── finetune_vip_llava_llama3_stage2.sh ├── finetune_vip_llava_llama3_stage3.sh ├── finetune_vip_llava_phi3_stage2.sh ├── finetune_vip_llava_phi3_stage3.sh ├── pretrain.sh ├── pretrain_llava_1_5_llama3.sh ├── pretrain_llava_1_5_phi3.sh ├── pretrain_vip_llava_llama3.sh ├── pretrain_vip_llava_phi3.sh ├── v1_5 └── eval │ ├── gqa.sh │ ├── llavabench.sh │ ├── mmbench.sh │ ├── mmbench_cn.sh │ ├── mme.sh │ ├── mmvet.sh │ ├── pope.sh │ ├── qbench.sh │ ├── qbench_zh.sh │ ├── seed-img.sh │ ├── seed-process-anno.py │ ├── seed.sh │ ├── sqa.sh │ ├── textvqa.sh │ ├── vizwiz.sh │ └── vqav2.sh ├── zero2.json ├── zero3.json └── zero3_offload.json /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mcr.microsoft.com/devcontainers/base:ubuntu-20.04 2 | 3 | SHELL [ "bash", "-c" ] 4 | 5 | # update apt and install packages 6 | RUN apt update && \ 7 | apt install -yq \ 8 | ffmpeg \ 9 | dkms \ 10 | build-essential 11 | 12 | # add user tools 13 | RUN sudo apt install -yq \ 14 | jq \ 15 | jp \ 16 | tree \ 17 | tldr 18 | 19 | # add git-lfs and install 20 | RUN curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash && \ 21 | sudo apt-get install -yq git-lfs && \ 22 | git lfs install 23 | 24 | ############################################ 25 | # Setup user 26 | ############################################ 27 | 28 | USER vscode 29 | 30 | # install azcopy, a tool to copy to/from blob storage 31 | # for more info: https://learn.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-blobs-upload#upload-a-file 32 | RUN cd /tmp && \ 33 | wget https://azcopyvnext.azureedge.net/release20230123/azcopy_linux_amd64_10.17.0.tar.gz && \ 34 | tar xvf azcopy_linux_amd64_10.17.0.tar.gz && \ 35 | mkdir -p ~/.local/bin && \ 36 | mv azcopy_linux_amd64_10.17.0/azcopy ~/.local/bin && \ 37 | chmod +x ~/.local/bin/azcopy && \ 38 | rm -rf azcopy_linux_amd64* 39 | 40 | # Setup conda 41 | RUN cd /tmp && \ 42 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ 43 | bash ./Miniconda3-latest-Linux-x86_64.sh -b && \ 44 | rm ./Miniconda3-latest-Linux-x86_64.sh 45 | 46 | # Install dotnet 47 | RUN cd /tmp && \ 48 | wget https://dot.net/v1/dotnet-install.sh && \ 49 | chmod +x dotnet-install.sh && \ 50 | ./dotnet-install.sh --channel 7.0 && \ 51 | ./dotnet-install.sh --channel 3.1 && \ 52 | rm ./dotnet-install.sh 53 | 54 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.env: -------------------------------------------------------------------------------- 1 | SAMPLE_ENV_VAR1="Sample Value" 2 | SAMPLE_ENV_VAR2=332431bf-68bf -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "ViP-LLaVA", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "context": "..", 6 | "args": {} 7 | }, 8 | "features": { 9 | "ghcr.io/devcontainers/features/docker-in-docker:2": {}, 10 | "ghcr.io/devcontainers/features/azure-cli:1": {}, 11 | "ghcr.io/azure/azure-dev/azd:0": {}, 12 | "ghcr.io/devcontainers/features/powershell:1": {}, 13 | "ghcr.io/devcontainers/features/common-utils:2": {}, 14 | "ghcr.io/devcontainers-contrib/features/zsh-plugins:0": {}, 15 | }, 16 | // "forwardPorts": [], 17 | "postCreateCommand": "bash ./.devcontainer/postCreateCommand.sh", 18 | "customizations": { 19 | "vscode": { 20 | 
"settings": { 21 | "python.analysis.autoImportCompletions": true, 22 | "python.analysis.autoImportUserSymbols": true, 23 | "python.defaultInterpreterPath": "~/miniconda3/envs/vip-llava/bin/python", 24 | "python.formatting.provider": "yapf", 25 | "python.linting.enabled": true, 26 | "python.linting.flake8Enabled": true, 27 | "isort.check": true, 28 | "dev.containers.copyGitConfig": true, 29 | "terminal.integrated.defaultProfile.linux": "zsh", 30 | "terminal.integrated.profiles.linux": { 31 | "zsh": { 32 | "path": "/usr/bin/zsh" 33 | }, 34 | } 35 | }, 36 | "extensions": [ 37 | "aaron-bond.better-comments", 38 | "eamodio.gitlens", 39 | "EditorConfig.EditorConfig", 40 | "foxundermoon.shell-format", 41 | "GitHub.copilot-chat", 42 | "GitHub.copilot-labs", 43 | "GitHub.copilot", 44 | "lehoanganh298.json-lines-viewer", 45 | "mhutchie.git-graph", 46 | "ms-azuretools.vscode-docker", 47 | "ms-dotnettools.dotnet-interactive-vscode", 48 | "ms-python.flake8", 49 | "ms-python.isort", 50 | "ms-python.python", 51 | "ms-python.vscode-pylance", 52 | "njpwerner.autodocstring", 53 | "redhat.vscode-yaml", 54 | "stkb.rewrap", 55 | "yzhang.markdown-all-in-one", 56 | ] 57 | } 58 | }, 59 | "mounts": [], 60 | "runArgs": [ 61 | "--gpus", 62 | "all", 63 | // "--ipc", 64 | // "host", 65 | "--ulimit", 66 | "memlock=-1", 67 | "--env-file", 68 | ".devcontainer/devcontainer.env" 69 | ], 70 | // "remoteUser": "root" 71 | } 72 | -------------------------------------------------------------------------------- /.devcontainer/postCreateCommand.sh: -------------------------------------------------------------------------------- 1 | git config --global safe.directory '*' 2 | git config --global core.editor "code --wait" 3 | git config --global pager.branch false 4 | 5 | # Set AZCOPY concurrency to auto 6 | echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.zshrc 7 | echo "export AZCOPY_CONCURRENCY_VALUE=AUTO" >> ~/.bashrc 8 | 9 | # Activate conda by default 10 | echo ". /home/vscode/miniconda3/bin/activate" >> ~/.zshrc 11 | echo ". /home/vscode/miniconda3/bin/activate" >> ~/.bashrc 12 | 13 | # Use vip-llava environment by default 14 | echo "conda activate vip-llava" >> ~/.zshrc 15 | echo "conda activate vip-llava" >> ~/.bashrc 16 | 17 | # Add dotnet to PATH 18 | echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.bashrc 19 | echo 'export PATH="$PATH:$HOME/.dotnet"' >> ~/.zshrc 20 | 21 | # Create and activate vip-llava environment 22 | source /home/vscode/miniconda3/bin/activate 23 | conda create -y -q -n vip-llava python=3.10 24 | conda activate vip-llava 25 | 26 | # Install Nvidia Cuda Compiler 27 | conda install -y -c nvidia cuda-compiler 28 | 29 | pip install pre-commit==3.0.2 30 | 31 | # Install package locally 32 | pip install --upgrade pip # enable PEP 660 support 33 | pip install -e . 34 | 35 | # Install additional packages for training 36 | pip install -e ".[train]" 37 | pip install flash-attn --no-build-isolation 38 | 39 | # Download checkpoints to location outside of the repo 40 | git clone https://huggingface.co/mucai/vip-llava-7b ~/vip-llava-7b 41 | 42 | # Commented because it is unlikely for users to have enough local GPU memory to load the model 43 | # git clone https://huggingface.co/mucai/vip-llava-13b ~/vip-llava-13b 44 | 45 | echo "postCreateCommand.sh COMPLETE!" 
46 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | 19 | # Exclude some weights 20 | /openai 21 | /mucai 22 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | # Unix-style newlines with a newline ending every file 4 | [*] 5 | end_of_line = lf 6 | insert_final_newline = true 7 | trim_trailing_whitespace = true 8 | charset = utf-8 9 | 10 | # 4 space indentation 11 | [*.{py,json}] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | # 2 space indentation 16 | [*.{md,sh,yaml,yml}] 17 | indent_style = space 18 | indent_size = 2 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1-usage.yaml: -------------------------------------------------------------------------------- 1 | name: Usage issues 2 | description: Report issues in usage. 3 | title: "[Usage] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to fill out this form. Please give as detailed description as possible for us to better assist with the issue :) 9 | - type: textarea 10 | id: what-happened 11 | attributes: 12 | label: Describe the issue 13 | description: Please give as detailed description as possible for us to better assist with the issue. Please paste the **FULL** error log here, so that we can better understand the issue. Wrap the log with ``` for better readability in GitHub. 14 | placeholder: Issue 15 | value: | 16 | Issue: 17 | 18 | Command: 19 | ``` 20 | PASTE THE COMMANDS HERE. 21 | ``` 22 | 23 | Log: 24 | ``` 25 | PASTE THE LOGS HERE. 26 | ``` 27 | 28 | Screenshots: 29 | You may attach screenshots if it better explains the issue. 
30 | validations: 31 | required: true 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Request for a new feature 3 | title: "[Feature request] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for your interest in our work. Please share your thoughts of the new features below. 9 | - type: textarea 10 | id: feature 11 | attributes: 12 | label: feature 13 | placeholder: Start your thoughts here... -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/3-question.yaml: -------------------------------------------------------------------------------- 1 | name: Questions 2 | description: General questions about the work 3 | title: "[Question] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for your interest in our work. For this type of question, it may be more suitable to go to [discussion](https://github.com/mu-cai/ViP-LLaVA/discussions) sections. If you believe an issue would be better for your request, please continue your post below :) 9 | - type: textarea 10 | id: question 11 | attributes: 12 | label: Question 13 | placeholder: Start question here... -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/4-discussion.yaml: -------------------------------------------------------------------------------- 1 | name: Discussions 2 | description: General discussions about the work 3 | title: "[Discussion] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for your interest in our work. For this type of question, it may be more suitable to go to [discussion](https://github.com/mu-cai/ViP-LLaVA/discussions) sections. If you believe an issue would be better for your request, please continue your post below :) 9 | - type: textarea 10 | id: discussion 11 | attributes: 12 | label: Discussion 13 | placeholder: Start discussion here... 
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | *.json 11 | *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | 16 | # Editor 17 | .idea 18 | *.swp 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | 25 | checkpoints 26 | ckpts* 27 | 28 | .ipynb_checkpoints 29 | *.ipynb 30 | 31 | # DevContainer 32 | !.devcontainer/* 33 | 34 | # Demo 35 | serve_images/ 36 | 37 | playground/data/ 38 | images/gradio* 39 | tmp* 40 | dataset 41 | *chtc* 42 | *debug* 43 | llava/model/language_model/mpt_old -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.11" 8 | 9 | python_packages: 10 | - "torch==2.0.1" 11 | - "accelerate==0.21.0" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.9.5" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.40.0" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.26.0" 21 | - "peft==0.4.0" 22 | - "scikit-learn==1.2.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.6.13" 26 | - "tokenizers==0.13.3" 27 | - "torch==2.0.1" 28 | - "torchvision==0.15.2" 29 | - "transformers==4.31.0" 30 | - "wandb==0.15.12" 31 | - "wavedrom==2.0.3.post3" 32 | - "Pygments==2.16.1" 33 | - "openai==0.28" 34 | run: 35 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 36 | 37 | # predict.py defines how predictions are run on your model 38 | predict: "predict.py:Predictor" 39 | -------------------------------------------------------------------------------- /docs/Customize_Component.md: -------------------------------------------------------------------------------- 1 | # Customize Components in ViP-LLaVA 2 | 3 | This is an initial guide on how to replace the LLMs, visual encoders, etc. with your choice of components. 4 | 5 | ## LLM 6 | 7 | It is quite simple to swap out LLaMA to any other LLMs. You can refer to our implementation of [`llava_llama.py`](https://raw.githubusercontent.com/mu-cai/ViP-LLaVA/main/llava/model/language_model/llava_llama.py) for an example of how to replace the LLM. 8 | 9 | Although it may seem that it still needs ~100 lines of code, most of them are copied from the original `llama.py` from HF. The only part that is different is to insert some lines for processing the multimodal inputs. 10 | 11 | In `forward` function, you can see that we call `self.prepare_inputs_labels_for_multimodal` to process the multimodal inputs. This function is defined in `LlavaMetaForCausalLM` and you just need to insert it into the `forward` function of your LLM. 12 | 13 | In `prepare_inputs_for_generation` function, you can see that we add `images` to the `model_inputs`. This is because we need to pass the images to the LLM during generation. 14 | 15 | These are basically all the changes you need to make to replace the LLM. 
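As a rough illustration of those two hooks, here is a minimal sketch of what such a wrapper can look like. It mirrors the structure of `llava_llama.py` but is simplified: the class name `MyLlavaForCausalLM` is made up for this example, and the exact argument list and return values of `prepare_inputs_labels_for_multimodal` should be taken from `llava/model/llava_arch.py` rather than from this sketch.

```python
# Minimal sketch (not the exact ViP-LLaVA implementation) of wrapping an LLM
# with the two multimodal hooks described above.
from transformers import LlamaConfig, LlamaForCausalLM

from llava.model.llava_arch import LlavaMetaForCausalLM  # provides the multimodal helpers


class MyLlavaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    config_class = LlamaConfig  # swap in the config class of your chosen LLM

    def forward(self, input_ids=None, attention_mask=None, past_key_values=None,
                labels=None, images=None, **kwargs):
        # The only multimodal-specific step: fold the image features into the
        # inputs before delegating to the plain language-model forward pass.
        # (Check llava/model/llava_arch.py for the authoritative signature.)
        input_ids, attention_mask, past_key_values, inputs_embeds, labels = \
            self.prepare_inputs_labels_for_multimodal(
                input_ids, attention_mask, past_key_values, labels, images)
        return super().forward(
            input_ids=input_ids, attention_mask=attention_mask,
            past_key_values=past_key_values, inputs_embeds=inputs_embeds,
            labels=labels, **kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values,
            inputs_embeds=inputs_embeds, **kwargs)
        # Keep the images in the inputs so every decoding step can attend to them.
        model_inputs["images"] = kwargs.get("images", None)
        return model_inputs
```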
16 | 
17 | ## Visual Encoder
18 | 
19 | You can check out [`clip_encoder.py`](https://github.com/mu-cai/ViP-LLaVA/blob/main/llava/model/multimodal_encoder/clip_encoder.py) to see how we implement the CLIP visual encoder.
20 | 
--------------------------------------------------------------------------------
/docs/Data.md:
--------------------------------------------------------------------------------
1 | ## Data
2 | 
3 | 
4 | ### Pretraining Dataset
5 | The pretraining dataset used in this release is a subset of the CC-3M dataset, filtered for a more balanced concept coverage distribution. Please see [here](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K) for a detailed description of the dataset structure and how to download the images.
6 | 
7 | If you already have the CC-3M dataset on your disk, the image names follow this format: `GCC_train_000000000.jpg`. You may edit the `image` field accordingly if necessary.
8 | 
9 | | Data | Chat File | Meta Data | Size |
10 | | --- | --- | --- | ---: |
11 | | CC-3M Concept-balanced 595K | [chat.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) | [metadata.json](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/metadata.json) | 211 MB |
12 | | LAION/CC/SBU BLIP-Caption Concept-balanced 558K | [blip_laion_cc_sbu_558k.json](https://huggingface.co/datasets/liuhaotian/LLaVA-Pretrain/blob/main/blip_laion_cc_sbu_558k.json) | [metadata.json](#) | 181 MB |
13 | 
14 | **Important notice**: At the community's request, since ~15% of the images in the original CC-3M dataset are no longer accessible, we upload [`images.zip`](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/images.zip) so that the research community can better reproduce our work. It must not be used for any other purpose. The use of these images must comply with the CC-3M license. It may be taken down at any time when requested by the original CC-3M dataset owner or the owners of the referenced images.
--------------------------------------------------------------------------------
/docs/Evaluation.md:
--------------------------------------------------------------------------------
1 | # Evaluation
2 | 
3 | # ViP-Bench
4 | 
5 | 1. Extract the contents of [`ViP-Bench`](https://huggingface.co/datasets/mucai/ViP-Bench) to `./playground/data/eval/ViP-Bench`.
6 | 2. Run single-GPU inference and evaluation for bounding-box and human-drawn visual prompts, respectively.
7 | ```Shell
8 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/vipbench.sh bbox
9 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/vipbench.sh human
10 | ```
11 | Optionally, change the model name from `vip-llava-7b` to another LLaVA or ViP-LLaVA model.
12 | 
13 | 3. Submit the results to the [evaluation server](https://huggingface.co/spaces/mucai/ViP-Bench_Evaluator): `./playground/data/eval/ViP-Bench/results/vip-llava-7b-human.json`.
14 | 
15 | 
16 | Optionally, see [here](https://github.com/mu-cai/ViP-LLaVA/blob/main/scripts/eval/vip-bench_evaluator.py) for an evaluation script that uses your own OpenAI key.
17 | 
18 | ## Source annotation
19 | 
20 | In `source_image`, we provide the source plain images along with the bounding box/mask annotations. Researchers can use such grounding information to match special tokens such as `<obj>` in the `"question"` entry of `vip-bench-meta-data.json`. For example, `<obj>` can be replaced by textual coordinates to evaluate region-level multimodal models.
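As a concrete (hypothetical) example of that substitution, a region-level model that consumes textual coordinates could preprocess each question as in the sketch below. The `[x1, y1, x2, y2]` box format is an assumption for illustration, so adapt it to the released annotation format.

```python
# Sketch: replace each <obj> placeholder in a ViP-Bench question with the
# corresponding bounding box written out as text, in order of appearance.
def fill_obj_tokens(question: str, boxes: list) -> str:
    for x1, y1, x2, y2 in boxes:
        question = question.replace("<obj>", f"[{x1}, {y1}, {x2}, {y2}]", 1)
    return question


# Made-up sample for illustration:
print(fill_obj_tokens("What is <obj> doing near <obj>?",
                      [[25, 30, 120, 150], [35, 23, 134, 213]]))
# -> "What is [25, 30, 120, 150] doing near [35, 23, 134, 213]?"
```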
21 | 
22 | 
23 | 
24 | 
25 | 
26 | # Academic Benchmarks
27 | 
28 | Please download the evaluation `json` dataset [here](https://huggingface.co/datasets/mucai/ViP-LLaVA-Instruct/tree/main).
29 | 
30 | ## Visual7W
31 | 
32 | ```Shell
33 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/v7w.sh
34 | ```
35 | 
36 | 
37 | ## PointQA-LookTwice
38 | 
39 | ```Shell
40 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/pointQA.sh
41 | ```
42 | 
43 | 
44 | ## Visual Commonsense Reasoning
45 | 
46 | For Q -> A:
47 | ```Shell
48 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/vcr_qa.sh
49 | ```
50 | 
51 | For QA -> R:
52 | ```Shell
53 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/vcr_qar.sh
54 | ```
55 | 
56 | 
57 | 
58 | 
59 | 
60 | 
--------------------------------------------------------------------------------
/docs/Finetune_Custom_Data.md:
--------------------------------------------------------------------------------
1 | # Finetune ViP-LLaVA on Custom Datasets
2 | 
3 | ## Dataset Format
4 | 
5 | Convert your data to a JSON file containing a list of all samples. Sample metadata should contain `id` (a unique identifier), `image` (the path to the image), `conversations` (the conversation data between human and AI), `bboxes` (a list of bounding boxes), and optionally `segmentations` (a list of segmentation masks). Note that the image can be annotated with arbitrary visual prompts. Also try using different visual prompts.
6 | 
7 | A sample JSON for finetuning ViP-LLaVA to generate tag-style captions for Stable Diffusion:
8 | 
9 | ```json
10 | [
11 |   {
12 |     "id": "997bb945-628d-4724-b370-b84de974a19f",
13 |     "image": "part-000001/997bb945-628d-4724-b370-b84de974a19f.jpg",
14 |     "bbox": [ [25, 30, 120, 150], [35, 23, 134, 213] ],
15 |     "conversations": [
16 |       {
17 |         "from": "human",
18 |         "value": "<image>\nWrite a prompt for Stable Diffusion to generate this image, focusing on <obj> and <obj>."
19 |       },
20 |       {
21 |         "from": "gpt",
22 |         "value": "a beautiful painting of chernobyl by nekro, pascal blanche, john harris, greg rutkowski, sin jong hun, moebius, simon stalenhag. in style of cg art. ray tracing. cel shading. hyper detailed. realistic. ue 5. maya. octane render. "
23 |       }
24 |     ]
25 |   },
26 |   ...
27 | ]
28 | ```
29 | 
30 | ## Command
31 | 
32 | If you have limited task-specific data, we recommend finetuning from ViP-LLaVA checkpoints with LoRA, or fully finetuning with a small learning rate such as 2e-6 or 2e-7.
33 | 
34 | If the amount of task-specific data is sufficient, you can also fully finetune the model from ViP-LLaVA checkpoints.
35 | 
36 | You may need to adjust the hyperparameters to fit each specific dataset and your hardware constraints.
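Before launching a run, it can also help to sanity-check that your JSON matches the format above. The sketch below is one possible checker, not part of the official training code; the path `custom_data.json` and the assumption that `<obj>` tokens line up one-to-one with `bbox` entries are illustrative only.

```python
# Sketch: structural check for a custom ViP-LLaVA finetuning JSON file.
import json

REQUIRED_KEYS = {"id", "image", "conversations"}


def check_dataset(path: str) -> None:
    with open(path) as f:
        samples = json.load(f)
    assert isinstance(samples, list), "top level must be a list of samples"
    for i, sample in enumerate(samples):
        missing = REQUIRED_KEYS - sample.keys()
        assert not missing, f"sample {i} is missing keys: {missing}"
        turns = sample["conversations"]
        assert turns and turns[0]["from"] == "human", f"sample {i} must start with a human turn"
        # Optional: warn when the number of <obj> tokens and bounding boxes disagree.
        n_obj = sum(t["value"].count("<obj>") for t in turns if t["from"] == "human")
        n_box = len(sample.get("bbox", []))
        if n_obj and n_obj != n_box:
            print(f"sample {i}: {n_obj} <obj> tokens but {n_box} boxes")


if __name__ == "__main__":
    check_dataset("custom_data.json")  # hypothetical path to your converted dataset
```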
37 | 
38 | 
39 | 
--------------------------------------------------------------------------------
/docs/Intel.md:
--------------------------------------------------------------------------------
1 | # Intel Platforms
2 | 
3 | * Supports [Intel GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html)
4 | * Supports [Intel CPU Sapphire Rapids](https://ark.intel.com/content/www/us/en/ark/products/codename/126212/products-formerly-sapphire-rapids.html)
5 | * Based on [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch)
6 | 
7 | More details are in the [**intel branch**](https://github.com/mu-cai/ViP-LLaVA/tree/intel/docs/intel).
--------------------------------------------------------------------------------
/docs/MODEL_ZOO.md:
--------------------------------------------------------------------------------
1 | # Model Zoo
2 | 
3 | **To use ViP-LLaVA checkpoints, your `llava` package version must be newer than 1.1.0.**
4 | 
5 | If you are interested in including any other details in the Model Zoo, please open an issue :)
6 | 
7 | The model weights below are *merged* weights. You do not need to apply a delta. The usage of ViP-LLaVA checkpoints should comply with the base LLM's model license: [Llama 2](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md).
8 | 
9 | ## ViP-LLaVA
10 | 
11 | | Version | Size | Schedule | Checkpoint | Visual7W | PointQA-LookTwice | RegionBench@Box | RegionBench@Human |
12 | |----------|----------|-----------|-----------|---|---|---|---|
13 | | ViP-LLaVA | 7B | full_ft-1e | [mucai/vip-llava-7b](https://huggingface.co/mucai/vip-llava-7b) | 86.09 | 71.31 | 48.4 | 48.3 |
14 | | ViP-LLaVA | 13B | full_ft-1e | [mucai/vip-llava-13b](https://huggingface.co/mucai/vip-llava-13b) | 88.28 | 71.77 | 48.3 | 48.2 |
15 | 
16 | Base model: Vicuna v1.5.
17 | 
18 | 
19 | ## ViP-LLaVA Stage 2 Checkpoints
20 | 
21 | [mucai/vip-llava-7b-base](https://huggingface.co/mucai/vip-llava-7b-base)
22 | 
23 | [mucai/vip-llava-13b-base](https://huggingface.co/mucai/vip-llava-13b-base)
24 | 
25 | 
26 | ## Projector weights
27 | 
28 | These are the projector weights we have pretrained. You can use them for visual instruction tuning. They are pretrained only on image-text pairs and are NOT instruction-tuned, which means they do NOT follow instructions as well as our official models and can produce repetitive, lengthy, and garbled outputs.
29 | 
30 | 
31 | NOTE: When you use our pretrained projector for visual instruction tuning, it is very important to use the same base LLM and vision encoder as the ones we used for pretraining the projector. Otherwise, the performance will be very poor.
32 | 
33 | When using these projector weights to instruction-tune your LMM, please make sure that these options are set correctly as follows:
34 | 
35 | ```Shell
36 | --mm_use_im_start_end False
37 | --mm_use_im_patch_token False
38 | ```
39 | 
40 | [Projector for Vicuna-1.5 7B](https://huggingface.co/mucai/vip-llava-7b-pretrain)
41 | 
42 | [Projector for Vicuna-1.5 13B](https://huggingface.co/mucai/vip-llava-13b-pretrain)
43 | 
44 | 
45 | ## VCR checkpoint
46 | 
47 | [Checkpoint](https://huggingface.co/mucai/vip-llava-7b-base-vcr-ft)
48 | 
49 | 
50 | ## RefCOCOg Region Captioning checkpoint
51 | 
52 | [Checkpoint](https://huggingface.co/mucai/vip-llava-7b-refcocog-ft)
53 | 
--------------------------------------------------------------------------------
/docs/Windows.md:
--------------------------------------------------------------------------------
1 | # Run ViP-LLaVA on Windows
2 | 
3 | *NOTE: ViP-LLaVA on Windows is not fully supported. Currently we only support 16-bit inference. For more complete support, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install) for now. More functionality on Windows will be added soon; stay tuned.*
4 | 
5 | ## Installation
6 | 
7 | 1. Clone this repository and navigate to the ViP-LLaVA folder
8 | ```bash
9 | git clone https://github.com/mu-cai/ViP-LLaVA.git
10 | cd ViP-LLaVA
11 | ```
12 | 
13 | 2. Install the package
14 | ```Shell
15 | conda create -n vip-llava python=3.10 -y
16 | conda activate vip-llava
17 | python -mpip install --upgrade pip # enable PEP 660 support
18 | pip install torch==2.0.1+cu117 torchvision==0.15.2+cu117 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117
19 | pip install -e .
20 | pip uninstall bitsandbytes
21 | ```
22 | 
23 | ## Run demo
24 | 
25 | See instructions [here](https://github.com/mu-cai/ViP-LLaVA#demo).
26 | 
27 | Note that quantization (4-bit, 8-bit) is *NOT* supported on Windows. Stay tuned for the 4-bit support on Windows!
28 | 
--------------------------------------------------------------------------------
/docs/macOS.md:
--------------------------------------------------------------------------------
1 | # Run ViP-LLaVA on macOS
2 | 
3 | *NOTE: ViP-LLaVA on macOS is not fully supported. Currently we only support 16-bit inference. More functionality on macOS will be added soon; stay tuned.*
4 | 
5 | ## Installation
6 | 
7 | 1. Clone this repository and navigate to the ViP-LLaVA folder
8 | ```bash
9 | git clone https://github.com/mu-cai/ViP-LLaVA.git
10 | cd ViP-LLaVA
11 | ```
12 | 
13 | 2. Install the package
14 | ```Shell
15 | conda create -n vip-llava python=3.10 -y
16 | conda activate vip-llava
17 | python -mpip install --upgrade pip # enable PEP 660 support
18 | pip install -e .
19 | pip install torch==2.1.0 torchvision==0.16.0
20 | pip uninstall bitsandbytes
21 | ```
22 | 
23 | ## Run demo
24 | 
25 | Specify `--device mps` when launching the model worker or the CLI.
26 | 
27 | See instructions [here](https://github.com/mu-cai/ViP-LLaVA#demo).
28 | 
29 | Note that quantization (4-bit, 8-bit) is *NOT* supported on macOS. Stay tuned for the 4-bit support on macOS!
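If the demo does not start on Apple silicon, one quick (unofficial) sanity check is to confirm that PyTorch can actually see the Metal backend before passing `--device mps`:

```python
# Verify that the PyTorch MPS (Metal) backend is built and available on this Mac.
import torch

print("MPS built:", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())
if torch.backends.mps.is_available():
    x = torch.ones(2, 2, device="mps")  # allocate a small tensor on the Apple GPU
    print("Tensor placed on:", x.device)
```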
30 | -------------------------------------------------------------------------------- /images/all-model-compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/images/all-model-compare.png -------------------------------------------------------------------------------- /images/demo_cli.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/images/demo_cli.gif -------------------------------------------------------------------------------- /images/llava-compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/images/llava-compare.png -------------------------------------------------------------------------------- /images/llava_example_cmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/images/llava_example_cmp.png -------------------------------------------------------------------------------- /images/vip-llava_arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/images/vip-llava_arch.png -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM, LlavaPhi3ForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 
96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, 
ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /llava/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 | text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", type=str) 67 | 
parser.add_argument("--question-file", type=str) 68 | parser.add_argument("--result-file", type=str) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | return random.choice(range(len(choices))) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | else: 62 | pred = predictions[prob_id] 63 | pred_text = pred['text'] 64 | 65 | if pred_text in args.options: 66 | answer = pred_text 67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 68 | answer = pred_text[0] 69 | else: 70 | pattern = re.compile(r'The answer is ([A-Z]).') 71 | res = pattern.findall(pred_text) 72 | if len(res) == 1: 73 | answer = res[0] # 'A', 'B', ... 
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 20 | 
21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | ===================================== 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /llava/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', 
key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 35 | disable_torch_init() 36 | model_name = 
os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria, process_image 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | if "llama-3" in model_name.lower(): 35 | args.conv_mode = "llava_llama_3" 36 | elif 'phi-3' in model_name.lower(): 37 | args.conv_mode = "llava_phi_3" 38 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 39 | 40 | questions = 
[json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 41 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 42 | answers_file = os.path.expanduser(args.answers_file) 43 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 44 | ans_file = open(answers_file, "w") 45 | if "llama-3" in model_name.lower(): 46 | terminators = [ 47 | tokenizer.eos_token_id, 48 | tokenizer.convert_tokens_to_ids("<|eot_id|>") 49 | ] 50 | elif 'phi-3' in model_name.lower(): 51 | terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")] 52 | else: 53 | terminators = [tokenizer.eos_token_id,] 54 | for line in tqdm(questions): 55 | idx = line["question_id"] 56 | image_file = line["image"] 57 | qs = line["text"] 58 | cur_prompt = qs 59 | if model.config.mm_use_im_start_end: 60 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 61 | else: 62 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 63 | 64 | conv = conv_templates[args.conv_mode].copy() 65 | conv.append_message(conv.roles[0], qs) 66 | conv.append_message(conv.roles[1], None) 67 | prompt = conv.get_prompt() 68 | 69 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 70 | 71 | image = Image.open(os.path.join(args.image_folder, image_file)) 72 | image = process_image(image, args.image_preprocess , image_processor) 73 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 74 | 75 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 76 | keywords = [stop_str] 77 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 78 | if 'Phi-3' in model_name: 79 | stop_str= 'xsafs' 80 | with torch.inference_mode(): 81 | output_ids = model.generate( 82 | input_ids, 83 | images=image_tensor.unsqueeze(0).half().cuda(), 84 | do_sample=True if args.temperature > 0 else False, 85 | temperature=args.temperature, 86 | top_p=args.top_p, 87 | num_beams=args.num_beams, 88 | # no_repeat_ngram_size=3, 89 | eos_token_id=terminators, 90 | max_new_tokens=1024, 91 | use_cache=True) 92 | 93 | input_token_len = input_ids.shape[1] 94 | # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 95 | # if n_diff_input_output > 0: 96 | # print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 97 | # outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 98 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 99 | outputs = outputs.strip() 100 | if outputs.endswith(stop_str): 101 | outputs = outputs[:-len(stop_str)] 102 | outputs = outputs.strip() 103 | ans_id = shortuuid.uuid() 104 | ans_file.write(json.dumps({"question_id": idx, 105 | "prompt": cur_prompt, 106 | "text": outputs, 107 | "answer_id": ans_id, 108 | "model_id": model_name, 109 | "metadata": {}}) + "\n") 110 | ans_file.flush() 111 | ans_file.close() 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 116 | parser.add_argument("--model-base", type=str, default=None) 117 | parser.add_argument("--image-folder", type=str, default="") 118 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 119 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 120 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 121 | 
parser.add_argument("--num-chunks", type=int, default=1) 122 | parser.add_argument("--chunk-idx", type=int, default=0) 123 | parser.add_argument("--temperature", type=float, default=0.2) 124 | parser.add_argument("--top_p", type=float, default=None) 125 | parser.add_argument("--num_beams", type=int, default=1) 126 | parser.add_argument("--image_preprocess", type=str, default="pad") 127 | args = parser.parse_args() 128 | 129 | eval_model(args) 130 | -------------------------------------------------------------------------------- /llava/eval/model_vqa_qbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import json 5 | 6 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 7 | from llava.conversation import conv_templates, SeparatorStyle 8 | from llava.model.builder import load_pretrained_model 9 | from llava.utils import disable_torch_init 10 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 11 | 12 | from PIL import Image 13 | 14 | import requests 15 | from PIL import Image 16 | from io import BytesIO 17 | 18 | 19 | def load_image(image_file): 20 | if image_file.startswith('http') or image_file.startswith('https'): 21 | response = requests.get(image_file) 22 | image = Image.open(BytesIO(response.content)).convert('RGB') 23 | else: 24 | image = Image.open(image_file).convert('RGB') 25 | return image 26 | 27 | 28 | def eval_model(args): 29 | # Model 30 | disable_torch_init() 31 | 32 | model_name = get_model_name_from_path(args.model_path) 33 | if "llama-3" in model_name.lower(): 34 | args.conv_mode = "llava_llama_3" 35 | elif 'phi-3' in model_name.lower(): 36 | args.conv_mode = "llava_phi_3" 37 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True) 38 | 39 | 40 | 41 | 42 | with open(args.questions_file) as f: 43 | llvqa_data = json.load(f) 44 | if "llama-3" in model_name.lower(): 45 | terminators = [ 46 | tokenizer.eos_token_id, 47 | tokenizer.convert_tokens_to_ids("<|eot_id|>") 48 | ] 49 | elif 'phi-3' in model_name.lower(): 50 | terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")] 51 | else: 52 | terminators = [tokenizer.eos_token_id,] 53 | for i, llddata in enumerate(tqdm(llvqa_data)): 54 | filename = llddata["img_path"] 55 | if args.lang == "en": 56 | message = llddata["question"] + "\nChoose between one of the options as follows:\n" 57 | elif args.lang == "zh": 58 | message = llddata["question"] + "\在下列选项中选择一个:\n" 59 | else: 60 | raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. 
Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.") 61 | for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]): 62 | message += f"{choice} {ans}\n" 63 | qs = message 64 | 65 | if model.config.mm_use_im_start_end: 66 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 67 | else: 68 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 69 | 70 | if 'llama-2' in model_name.lower(): 71 | conv_mode = "llava_llama_2" 72 | elif "v1" in model_name.lower(): 73 | conv_mode = "llava_v1" 74 | elif "mpt" in model_name.lower(): 75 | conv_mode = "mpt" 76 | else: 77 | conv_mode = "llava_v0" 78 | 79 | if args.conv_mode is not None and conv_mode != args.conv_mode: 80 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 81 | else: 82 | args.conv_mode = conv_mode 83 | 84 | conv = conv_templates[args.conv_mode].copy() 85 | conv.append_message(conv.roles[0], qs) 86 | conv.append_message(conv.roles[1], None) 87 | prompt = conv.get_prompt() 88 | 89 | image = load_image(args.image_folder + filename) 90 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 91 | 92 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 93 | 94 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 95 | keywords = [stop_str] 96 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 97 | 98 | 99 | with torch.inference_mode(): 100 | output_ids = model.generate( 101 | input_ids, 102 | images=image_tensor, 103 | num_beams=1, 104 | do_sample=False, 105 | temperature=0, 106 | max_new_tokens=1024, 107 | use_cache=True, 108 | eos_token_id=terminators, 109 | stopping_criteria=[stopping_criteria]) 110 | 111 | input_token_len = input_ids.shape[1] 112 | # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 113 | # if n_diff_input_output > 0: 114 | # print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 115 | # outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 116 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 117 | outputs = outputs.strip() 118 | if outputs.endswith(stop_str): 119 | outputs = outputs[:-len(stop_str)] 120 | outputs = outputs.strip() 121 | llddata["response"] = outputs 122 | with open(args.answers_file, "a") as wf: 123 | json.dump(llddata, wf) 124 | 125 | if __name__ == "__main__": 126 | parser = argparse.ArgumentParser() 127 | parser.add_argument("--model-path", type=str, default="vip-llava-v1.5") 128 | parser.add_argument("--model-base", type=str, default=None) 129 | parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa") 130 | parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json") 131 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 132 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 133 | parser.add_argument("--lang", type=str, default="en") 134 | args = parser.parse_args() 135 | 136 | eval_model(args) 137 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers 
with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import ( 5 | IMAGE_TOKEN_INDEX, 6 | DEFAULT_IMAGE_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | IMAGE_PLACEHOLDER, 10 | ) 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import ( 15 | process_images, 16 | tokenizer_image_token, 17 | get_model_name_from_path, 18 | KeywordsStoppingCriteria, 19 | ) 20 | 21 | from PIL import Image 22 | 23 | import requests 24 | from PIL import Image 25 | from io import BytesIO 26 | import re 27 | 28 | 29 | def image_parser(args): 30 | out = args.image_file.split(args.sep) 31 | return out 32 | 33 | 34 | def load_image(image_file): 35 | if image_file.startswith("http") or image_file.startswith("https"): 36 | response = requests.get(image_file) 37 | image = Image.open(BytesIO(response.content)).convert("RGB") 38 | else: 39 | image = Image.open(image_file).convert("RGB") 40 | return image 41 | 42 | 43 | def load_images(image_files): 44 | out = [] 45 | for image_file in image_files: 46 | image = 
load_image(image_file) 47 | out.append(image) 48 | return out 49 | 50 | 51 | def eval_model(args): 52 | # Model 53 | disable_torch_init() 54 | 55 | model_name = get_model_name_from_path(args.model_path) 56 | tokenizer, model, image_processor, context_len = load_pretrained_model( 57 | args.model_path, args.model_base, model_name 58 | ) 59 | 60 | qs = args.query 61 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 62 | if IMAGE_PLACEHOLDER in qs: 63 | if model.config.mm_use_im_start_end: 64 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 65 | else: 66 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 67 | else: 68 | if model.config.mm_use_im_start_end: 69 | qs = image_token_se + "\n" + qs 70 | else: 71 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 72 | 73 | if "llama-2" in model_name.lower(): 74 | conv_mode = "llava_llama_2" 75 | elif "v1" in model_name.lower(): 76 | conv_mode = "llava_v1" 77 | elif "mpt" in model_name.lower(): 78 | conv_mode = "mpt" 79 | else: 80 | conv_mode = "llava_v0" 81 | 82 | if args.conv_mode is not None and conv_mode != args.conv_mode: 83 | print( 84 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 85 | conv_mode, args.conv_mode, args.conv_mode 86 | ) 87 | ) 88 | else: 89 | args.conv_mode = conv_mode 90 | 91 | conv = conv_templates[args.conv_mode].copy() 92 | conv.append_message(conv.roles[0], qs) 93 | conv.append_message(conv.roles[1], None) 94 | prompt = conv.get_prompt() 95 | 96 | image_files = image_parser(args) 97 | images = load_images(image_files) 98 | images_tensor = process_images( 99 | images, 100 | image_processor, 101 | model.config 102 | ).to(model.device, dtype=torch.float16) 103 | 104 | input_ids = ( 105 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 106 | .unsqueeze(0) 107 | .cuda() 108 | ) 109 | 110 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 111 | keywords = [stop_str] 112 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 113 | 114 | with torch.inference_mode(): 115 | output_ids = model.generate( 116 | input_ids, 117 | images=images_tensor, 118 | do_sample=True if args.temperature > 0 else False, 119 | temperature=args.temperature, 120 | top_p=args.top_p, 121 | num_beams=args.num_beams, 122 | max_new_tokens=args.max_new_tokens, 123 | use_cache=True, 124 | stopping_criteria=[stopping_criteria], 125 | ) 126 | 127 | input_token_len = input_ids.shape[1] 128 | # n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 129 | # if n_diff_input_output > 0: 130 | # print( 131 | # f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" 132 | # ) 133 | # outputs = tokenizer.batch_decode( 134 | # output_ids[:, input_token_len:], skip_special_tokens=True 135 | # )[0] 136 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 137 | outputs = outputs.strip() 138 | if outputs.endswith(stop_str): 139 | outputs = outputs[: -len(stop_str)] 140 | outputs = outputs.strip() 141 | print(outputs) 142 | 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 147 | parser.add_argument("--model-base", type=str, default=None) 148 | parser.add_argument("--image-file", type=str, required=True) 149 | parser.add_argument("--query", type=str, required=True) 150 | parser.add_argument("--conv-mode", type=str, default=None) 
151 | parser.add_argument("--sep", type=str, default=",") 152 | parser.add_argument("--temperature", type=float, default=0.2) 153 | parser.add_argument("--top_p", type=float, default=None) 154 | parser.add_argument("--num_beams", type=int, default=1) 155 | parser.add_argument("--max_new_tokens", type=int, default=512) 156 | args = parser.parse_args() 157 | 158 | eval_model(args) 159 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 
20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. 
Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/llava/eval/webpage/figures/bard.jpg 
-------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- 
/llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_phi3 import LlavaPhi3ForCausalLM, LlavaPhi3Config 5 | except: 6 | pass -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 
| parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | 
use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_llama", LlavaConfig) 158 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
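# Module overview: the LLaVA wrapper for MPT backbones, mirroring llava_llama.py.
# LlavaMptModel mixes LlavaMetaModel into transformers' MptModel (exposing embed_tokens via wte),
# and LlavaMptForCausalLM calls prepare_inputs_labels_for_multimodal before delegating to
# MptForCausalLM.forward; the "llava_mpt" config/model pair is registered with AutoConfig and
# AutoModelForCausalLM at the end of the file.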
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) -------------------------------------------------------------------------------- /llava/model/language_model/llava_phi3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM 22 | from .modeling_phi3 import Phi3Config, Phi3Model, Phi3ForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaPhi3Config(Phi3Config): 31 | model_type = "llava_phi3" 32 | 33 | 34 | class LlavaPhi3Model(LlavaMetaModel, Phi3Model): 35 | config_class = LlavaPhi3Config 36 | 37 | def __init__(self, config: Phi3Config): 38 | super(LlavaPhi3Model, self).__init__(config) 39 | 40 | 41 | class LlavaPhi3ForCausalLM(Phi3ForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaPhi3Config 43 | 44 | def __init__(self, config): 45 | super(Phi3ForCausalLM, self).__init__(config) 46 | self.model = LlavaPhi3Model(config) 47 | self.vocab_size = config.vocab_size 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | image_sizes: Optional[List[List[int]]] = None, 69 | return_dict: Optional[bool] = None, 70 | ) -> Union[Tuple, CausalLMOutputWithPast]: 71 | 72 | if inputs_embeds is None: 73 | ( 74 | input_ids, 75 | position_ids, 76 | attention_mask, 77 | past_key_values, 78 | inputs_embeds, 79 | labels 80 | ) = self.prepare_inputs_labels_for_multimodal( 81 | input_ids, 82 | position_ids, 83 | attention_mask, 84 | past_key_values, 85 | labels, 86 | images, 87 | image_sizes 88 | ) 89 | 90 | return super().forward( 91 | input_ids=input_ids, 92 | attention_mask=attention_mask, 93 | position_ids=position_ids, 94 | past_key_values=past_key_values, 95 | inputs_embeds=inputs_embeds, 96 | labels=labels, 97 | use_cache=use_cache, 98 | output_attentions=output_attentions, 99 | output_hidden_states=output_hidden_states, 100 | return_dict=return_dict 101 | ) 102 | 103 | @torch.no_grad() 104 | def generate( 105 | self, 106 | inputs: Optional[torch.Tensor] = None, 107 | images: Optional[torch.Tensor] = None, 108 | image_sizes: Optional[torch.Tensor] = None, 109 | **kwargs, 110 | ) -> Union[GenerateOutput, torch.LongTensor]: 111 | position_ids = kwargs.pop("position_ids", None) 112 | attention_mask = kwargs.pop("attention_mask", None) 113 | if "inputs_embeds" in kwargs: 114 | raise NotImplementedError("`inputs_embeds` is not supported") 115 | 116 | if images is not None: 117 | ( 118 | inputs, 119 | position_ids, 120 | attention_mask, 121 | _, 122 | inputs_embeds, 123 | _ 124 | ) = self.prepare_inputs_labels_for_multimodal( 125 | inputs, 126 | position_ids, 127 | attention_mask, 128 | None, 129 | None, 130 | images, 131 | image_sizes=image_sizes 132 | ) 133 | 
else: 134 | inputs_embeds = self.get_model().embed_tokens(inputs) 135 | 136 | return super().generate( 137 | position_ids=position_ids, 138 | attention_mask=attention_mask, 139 | inputs_embeds=inputs_embeds, 140 | **kwargs 141 | ) 142 | 143 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 144 | inputs_embeds=None, **kwargs): 145 | images = kwargs.pop("images", None) 146 | image_sizes = kwargs.pop("image_sizes", None) 147 | inputs = super().prepare_inputs_for_generation( 148 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 149 | ) 150 | if images is not None: 151 | inputs['images'] = images 152 | if image_sizes is not None: 153 | inputs['image_sizes'] = image_sizes 154 | return inputs 155 | 156 | AutoConfig.register("llava_phi3", LlavaPhi3Config) 157 | AutoModelForCausalLM.register(LlavaPhi3Config, LlavaPhi3ForCausalLM) 158 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2023 Cruise LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from .clip_encoder import CLIPVisionTower 17 | 18 | 19 | def build_vision_tower(vision_tower_cfg, **kwargs): 20 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 21 | is_absolute_path_exists = os.path.exists(vision_tower) 22 | if vision_tower == 'clip_4layers_336': 23 | from .clip_4layer_encoder import CLIPVisionTowerMultilayer 24 | return CLIPVisionTowerMultilayer(vision_tower, args=vision_tower_cfg, **kwargs) 25 | elif is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 26 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 27 | 28 | raise ValueError(f'Unknown vision tower: {vision_tower}') 29 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_4layer_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Cruise LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 21 | 22 | # Copyright 2023 Cruise LLC 23 | # 24 | # Licensed under the Apache License, Version 2.0 (the "License"); 25 | # you may not use this file except in compliance with the License. 26 | # You may obtain a copy of the License at 27 | # 28 | # http://www.apache.org/licenses/LICENSE-2.0 29 | # 30 | # Unless required by applicable law or agreed to in writing, software 31 | # distributed under the License is distributed on an "AS IS" BASIS, 32 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 33 | # See the License for the specific language governing permissions and 34 | # limitations under the License. 
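# Note: feature_select() below takes the patch tokens (the CLS token is dropped via
# [:, 1:]) from five CLIP hidden-state layers -- indices [-2, -5, -8, -11, 6] -- and
# concatenates them along the channel dimension, which is why the hidden_size property
# reports config.hidden_size * 5. Illustrative shapes for the hard-coded
# openai/clip-vit-large-patch14-336 backbone (hidden size 1024, 24x24 patches):
#
#     forward(images)            # images: (B, 3, 336, 336)
#       -> per-layer features    # 5 x (B, 576, 1024)
#       -> torch.cat(..., -1)    # (B, 576, 5120)
#
# This 5x-wide feature is what the multimodal projector sees as mm_hidden_size
# (see multimodal_projector/builder.py).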
35 | 36 | 37 | 38 | class CLIPVisionTowerMultilayer(nn.Module): 39 | def __init__(self, vision_tower, args, delay_load=False): 40 | super().__init__() 41 | 42 | self.is_loaded = False 43 | 44 | self.vision_tower_name = vision_tower 45 | self.select_layer = args.mm_vision_select_layer 46 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 47 | 48 | if not delay_load: 49 | self.load_model() 50 | else: 51 | self.cfg_only = CLIPVisionConfig.from_pretrained('openai/clip-vit-large-patch14-336') 52 | 53 | def load_model(self): 54 | self.image_processor = CLIPImageProcessor.from_pretrained('openai/clip-vit-large-patch14-336') 55 | self.vision_tower = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336') 56 | self.vision_tower.requires_grad_(False) 57 | 58 | self.is_loaded = True 59 | 60 | def feature_select(self, image_forward_outs): 61 | image_features = [image_forward_outs['hidden_states'][index][:, 1:] for index in [-2, -5, -8, -11, 6]] 62 | image_features = torch.cat(image_features, dim=-1) 63 | return image_features 64 | 65 | @torch.no_grad() 66 | def forward(self, images): 67 | if type(images) is list: 68 | image_features = [] 69 | for image in images: 70 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 71 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 72 | image_features.append(image_feature) 73 | else: 74 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 75 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 76 | 77 | return image_features 78 | 79 | @property 80 | def dummy_feature(self): 81 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 82 | 83 | @property 84 | def dtype(self): 85 | return self.vision_tower.dtype 86 | 87 | @property 88 | def device(self): 89 | return self.vision_tower.device 90 | 91 | @property 92 | def config(self): 93 | if self.is_loaded: 94 | return self.vision_tower.config 95 | else: 96 | return self.cfg_only 97 | 98 | @property 99 | def hidden_size(self): 100 | return self.config.hidden_size * 5 101 | 102 | @property 103 | def num_patches(self): 104 | return (self.config.image_size // self.config.patch_size) ** 2 105 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = 
image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return (self.config.image_size // self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Cruise LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
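# Note: build_vision_projector() below maps config.mm_projector_type to a module --
# 'linear', 'mlp<N>x_gelu' (parsed by regex), or 'identity'. When mm_vision_tower is
# 'clip_4layers_336', the MLP is additionally wrapped in SimpleFeatureSingleModel, which
# applies a LayerNorm over the 5x-concatenated CLIP features before projecting them.
# A minimal sketch; mm_hidden_size=5120 and hidden_size=4096 are illustrative
# assumptions, not values read from a checkpoint:
#
#     from types import SimpleNamespace
#     cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu',
#                           mm_hidden_size=5120, hidden_size=4096,
#                           mm_vision_tower='clip_4layers_336')
#     projector = build_vision_projector(cfg)
#     # -> SimpleFeatureSingleModel(LayerNorm(5120), Linear(5120, 4096) + GELU + Linear(4096, 4096))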
14 | 15 | 16 | 17 | import torch 18 | import torch.nn as nn 19 | import re 20 | 21 | 22 | class IdentityMap(nn.Module): 23 | def __init__(self): 24 | super().__init__() 25 | 26 | def forward(self, x, *args, **kwargs): 27 | return x 28 | 29 | @property 30 | def config(self): 31 | return {"mm_projector_type": 'identity'} 32 | 33 | 34 | class SimpleResBlock(nn.Module): 35 | def __init__(self, channels): 36 | super().__init__() 37 | self.pre_norm = nn.LayerNorm(channels) 38 | 39 | self.proj = nn.Sequential( 40 | nn.Linear(channels, channels), 41 | nn.GELU(), 42 | nn.Linear(channels, channels) 43 | ) 44 | def forward(self, x): 45 | x = self.pre_norm(x) 46 | return x + self.proj(x) 47 | 48 | 49 | class SimpleFeatureSingleModel(nn.Module): 50 | def __init__(self, num_clip_layers_by_feature_dim, final_linear): 51 | super(SimpleFeatureSingleModel, self).__init__() 52 | 53 | self.clip_layernorm = nn.LayerNorm(num_clip_layers_by_feature_dim) 54 | self.final_linear = final_linear 55 | 56 | def forward(self, clip_features): 57 | v1_sum = self.clip_layernorm(clip_features) 58 | v_hat = self.final_linear(v1_sum) 59 | return v_hat 60 | 61 | 62 | 63 | def build_vision_projector(config, delay_load=False, **kwargs): 64 | projector_type = getattr(config, 'mm_projector_type', 'linear') 65 | 66 | if projector_type == 'linear': 67 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 68 | 69 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 70 | if mlp_gelu_match: 71 | mlp_depth = int(mlp_gelu_match.group(1)) 72 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 73 | for _ in range(1, mlp_depth): 74 | modules.append(nn.GELU()) 75 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 76 | mlp = nn.Sequential(*modules) 77 | if config.mm_vision_tower == 'clip_4layers_336': 78 | return SimpleFeatureSingleModel(config.mm_hidden_size, mlp) 79 | else: 80 | return mlp 81 | if projector_type == 'identity': 82 | return IdentityMap() 83 | 84 | raise ValueError(f'Unknown projector type: {projector_type}') 85 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if 'llama-2' in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower() or 'vip-llava' in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ('user', 'assistant') 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | # Similar operation in model_worker.py 56 | image_tensor = process_images([image], image_processor, model.config) 57 | if type(image_tensor) is list: 58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 59 | else: 60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 61 | 62 | while True: 63 | try: 64 | inp = input(f"{roles[0]}: ") 65 | except EOFError: 66 | inp = "" 67 | if not inp: 68 | print("exit...") 69 | break 70 | 71 | print(f"{roles[1]}: ", end="") 72 | 73 | if image is not None: 74 | # first message 75 | if model.config.mm_use_im_start_end: 76 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + 
inp 77 | else: 78 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 79 | conv.append_message(conv.roles[0], inp) 80 | image = None 81 | else: 82 | # later messages 83 | conv.append_message(conv.roles[0], inp) 84 | conv.append_message(conv.roles[1], None) 85 | prompt = conv.get_prompt() 86 | 87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 89 | keywords = [stop_str] 90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 92 | 93 | with torch.inference_mode(): 94 | output_ids = model.generate( 95 | input_ids, 96 | images=image_tensor, 97 | do_sample=True if args.temperature > 0 else False, 98 | temperature=args.temperature, 99 | max_new_tokens=args.max_new_tokens, 100 | streamer=streamer, 101 | use_cache=True, 102 | stopping_criteria=[stopping_criteria]) 103 | 104 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 105 | conv.messages[-1][-1] = outputs 106 | 107 | if args.debug: 108 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 114 | parser.add_argument("--model-base", type=str, default=None) 115 | parser.add_argument("--image-file", type=str, required=True) 116 | parser.add_argument("--device", type=str, default="cuda") 117 | parser.add_argument("--conv-mode", type=str, default=None) 118 | parser.add_argument("--temperature", type=float, default=0.2) 119 | parser.add_argument("--max-new-tokens", type=int, default=512) 120 | parser.add_argument("--load-8bit", action="store_true") 121 | parser.add_argument("--load-4bit", action="store_true") 122 | parser.add_argument("--debug", action="store_true") 123 | args = parser.parse_args() 124 | main(args) 125 | -------------------------------------------------------------------------------- /llava/serve/cli_vip.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | from llava.visual_prompt_generator import image_blending 17 | 18 | 19 | def load_image(image_file): 20 | if image_file.startswith('http://') or image_file.startswith('https://'): 21 | response = requests.get(image_file) 22 | image = Image.open(BytesIO(response.content)).convert('RGB') 23 | else: 24 | image = Image.open(image_file).convert('RGB') 25 | return image 26 | 27 | 28 | def main(args): 29 | # Model 30 | disable_torch_init() 31 | 32 | model_name = get_model_name_from_path(args.model_path) 33 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 34 | 35 | if 'llama-2' in 
model_name.lower(): 36 | conv_mode = "llava_llama_2" 37 | elif "v1" in model_name.lower() or 'vip-llava' in model_name.lower(): 38 | conv_mode = "llava_v1" 39 | elif "mpt" in model_name.lower(): 40 | conv_mode = "mpt" 41 | else: 42 | conv_mode = "llava_v0" 43 | 44 | if args.conv_mode is not None and conv_mode != args.conv_mode: 45 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 46 | else: 47 | args.conv_mode = conv_mode 48 | 49 | conv = conv_templates[args.conv_mode].copy() 50 | if "mpt" in model_name.lower(): 51 | roles = ('user', 'assistant') 52 | else: 53 | roles = conv.roles 54 | 55 | image = load_image(args.image_file) 56 | image = image_blending(image, args.shape, args.bbox, args.seg, image_processor.crop_size['height'], args.rgb, alpha = args.alpha, width = args.width) 57 | image.save('save_cli.png') 58 | print('Image saved as save_cli.png') 59 | # Similar operation in model_worker.py 60 | image_tensor = process_images([image], image_processor, model.config) 61 | if type(image_tensor) is list: 62 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 63 | else: 64 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 65 | 66 | while True: 67 | try: 68 | inp = input(f"{roles[0]}: ") 69 | except EOFError: 70 | inp = "" 71 | if not inp: 72 | print("exit...") 73 | break 74 | 75 | print(f"{roles[1]}: ", end="") 76 | 77 | if image is not None: 78 | # first message 79 | if model.config.mm_use_im_start_end: 80 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 81 | else: 82 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 83 | conv.append_message(conv.roles[0], inp) 84 | image = None 85 | else: 86 | # later messages 87 | conv.append_message(conv.roles[0], inp) 88 | conv.append_message(conv.roles[1], None) 89 | prompt = conv.get_prompt() 90 | 91 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 92 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 93 | keywords = [stop_str] 94 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 95 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 96 | 97 | with torch.inference_mode(): 98 | output_ids = model.generate( 99 | input_ids, 100 | images=image_tensor, 101 | do_sample=True if args.temperature > 0 else False, 102 | temperature=args.temperature, 103 | max_new_tokens=args.max_new_tokens, 104 | streamer=streamer, 105 | use_cache=True, 106 | stopping_criteria=[stopping_criteria]) 107 | 108 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 109 | conv.messages[-1][-1] = outputs 110 | if args.debug: 111 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 112 | 113 | 114 | def parse_list(string): 115 | if string != None: 116 | string = string.split(',') 117 | for i in range(len(string)): 118 | string[i] = int(string[i]) 119 | return string 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 124 | parser.add_argument("--model-base", type=str, default=None) 125 | parser.add_argument("--image-file", type=str, required=True) 126 | parser.add_argument("--device", type=str, default="cuda") 127 | parser.add_argument("--conv-mode", type=str, default=None) 128 | parser.add_argument("--temperature", 
type=float, default=0.2) 129 | parser.add_argument("--max-new-tokens", type=int, default=512) 130 | parser.add_argument("--load-8bit", action="store_true") 131 | parser.add_argument("--load-4bit", action="store_true") 132 | parser.add_argument("--debug", action="store_true") 133 | parser.add_argument('--bbox', type=str, default=None, help='Bounding box of the image in the format x1,y1,x2,y2, topleft first, then bottomright.') 134 | parser.add_argument('--seg', type=str, default=None, help='Segmentation mask of the image in the format x1,y1,x2,y2, ..., xn,yn.') 135 | parser.add_argument('--rgb', type=str, default='255,0,0', help='rgb value') 136 | parser.add_argument('--width', type=int, default=3, help='line width') 137 | parser.add_argument('--shape', type=str, default='rectangle', help='Options are "rectangle", "ellipse","triangle", "point", "scribble" , "mask contour" , "mask", "arrow"') 138 | parser.add_argument('--alpha', type=int, default=128, help='line width') 139 | 140 | 141 | 142 | 143 | 144 | args = parser.parse_args() 145 | args.bbox = parse_list(args.bbox) 146 | args.rgb = tuple(parse_list(args.rgb)) 147 | main(args) 148 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WisconsinAIVision/ViP-LLaVA/77079b1e1990472598f1f8a6bb3f7f470100d191/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import 
flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | 
cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | # from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | # replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | # if __name__ == "__main__": 13 | # train() 14 | 15 | if __name__ == "__main__": 16 | train(attn_implementation="flash_attention_2") -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 
4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 
96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /playground/data/prompts/refcocog.text: -------------------------------------------------------------------------------- 1 | As a visual AI, engage in an image-based dialogue about an object highlighted by a special feature without referencing the feature itself. Answer questions about the image, which comes in two versions: one original and one with a particular area of interest indicated. 2 | 3 | Inquire about the specifics of the highlighted object—its characteristics, actions, or context—ensuring clarity and precision in the questions and answers. 4 | 5 | For deeper analysis, prompt complex questions that demand additional knowledge or analytical reasoning about the scene, and provide comprehensive responses. 6 | 7 | When constructing the dialogue, use the term "" only as an identifier for the specific area of interest and avoid discussing the feature that highlights this area. Maintain strict adherence to this rule throughout the conversation. 8 | 9 | The dialogue should follow a JSON structure with an initial 'human' question and alternate with 'gpt' answers, exemplified by: 10 | [ {"human":"Describe the object ."}, {"gpt":"The object is characterized by..."}, {"human":"What behavioral patterns might be observed in the animal ?"}, {"gpt":"The animal may exhibit behaviors such as..."} (responses in the range of 50~200 words), ... ]> 11 | 12 | Avoid any mention of a base description that are not visible to the user. -------------------------------------------------------------------------------- /playground/data/prompts/vg.text: -------------------------------------------------------------------------------- 1 | As a visual assistant, your task is to engage in a dialogue about specific elements within an image, referred to as "subject" and "object." These elements are indicated within the image by undisclosed visual cues, with the subject are always identified by "" and the object always "". You are to analyze and discuss these elements based on their attributes, actions, and the context in which they are placed. You will be provided with two images: the original and another with the areas of interest indicated. 2 | 3 | Your responses should focus on the characteristics and dynamics of the subject and object, their interrelationships, and their interactions with their environment. 
Craft diverse and complex questions and respond using the placeholders "" for the subject and "" for the object to reference the highlighted elements. Such placeholders **must be used with a noun** to contextualize what it is referring to. 4 | 5 | The dialogue should be structured in a JSON format, providing clear and precise questions and answers about the elements. These responses may require additional background knowledge or logical deduction to adequately inform the discussion. 6 | 7 | Here is an example of how the JSON dialogue might be structured correctly: 8 | [ 9 | {"human":"Describe the instance ."}, 10 | {"gpt":"The instance is a duck, characterized by..."}, 11 | {"human":"How does the duck relate to the object ?"}, 12 | {"gpt":"This duck is swimming to the ship ..."}, 13 | {"human":"What behavioral patterns might be observed for the this duck?"}, 14 | {"gpt": This duck could ... (responses should be within the range of 50~200 words) 15 | ] 16 | In our conversation, please avoid using the term 'subject' to describe the instance . Let's proceed with the understanding that multiple topics may be discussed. 17 | You will be given a triplet description consisting of a subject, a relationship, and an object from the Visual Genome dataset annotation. Please avoid any direct mention of the base description or visual cues that are not apparent to the user. -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "vip-llava" 7 | version = "1.2.2.post1" 8 | description = "Making Large Multimodal Models Understand Arbitrary Visual Prompts." 
9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.1.2", "torchvision==0.16.2", 17 | "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.21.0", "peft", "bitsandbytes", 19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==4.16.0", "gradio_client==0.8.1", 21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", "openai==0.28", "shapely", 23 | ] 24 | 25 | [project.optional-dependencies] 26 | train = ["deepspeed==0.12.6", "ninja", "wandb"] 27 | build = ["build", "twine"] 28 | 29 | [project.urls] 30 | "Homepage" = "https://vip-llava.github.io" 31 | "Bug Tracker" = "https://github.com/mu-cai/ViP-LLaVA/issues" 32 | 33 | [tool.setuptools.packages.find] 34 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 35 | 36 | [tool.wheel] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | -------------------------------------------------------------------------------- /scripts/convert_vipbench_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | cur_result = {} 11 | 12 | for line in open(args.src): 13 | data = json.loads(line) 14 | qid = data['question_id'] 15 | cur_result[f'v1_{qid}'] = data['text'] 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(cur_result, f, indent=2) 19 | -------------------------------------------------------------------------------- /scripts/eval/pointQA.sh: -------------------------------------------------------------------------------- 1 | model_name=mucai/vip-llava-13b 2 | eval_dataset=pointQA_twice_test 3 | python llava/eval/model_vqa_loader_vip.py \ 4 | --model-path $model_name \ 5 | --question-file ./dataset/$eval_dataset.json \ 6 | --image-folder ./dataset \ 7 | --alpha 128 \ 8 | --answers-file ./playground/data/eval/$eval_dataset-$model_name.json 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /scripts/eval/v7w.sh: -------------------------------------------------------------------------------- 1 | 2 | model_name=mucai/vip-llava-13b 3 | eval_dataset=v7w-test 4 | python llava/eval/model_vqa_loader_vip.py \ 5 | --model-path $model_name \ 6 | --question-file ./dataset/$eval_dataset.json \ 7 | --image-folder ./dataset \ 8 | --alpha 128 \ 9 | --answers-file ./playground/data/eval/$eval_dataset-$model_name-alpha128.json 10 | -------------------------------------------------------------------------------- /scripts/eval/vcr_qa.sh: -------------------------------------------------------------------------------- 1 | model_name=mucai/vip-llava-7b-base-vcr-ft 2 | eval_dataset=vcr-val 3 | python llava/eval/model_vqa_loader_vip.py \ 4 | --model-path $model_name \ 5 | --question-file ./dataset/$eval_dataset.json \ 6 | --image-folder ./dataset \ 7 | --alpha 128 \ 8 | --visual_prompt_style vcr_qa \ 9 | --image_aspect_ratio resize \ 10 | --answers-file ./playground/data/eval/$eval_dataset-qa-$model_name.json 11 | 
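For reference, scripts/convert_vipbench_for_eval.py above reshapes an answers .jsonl (one JSON object per line, as written by the evaluation scripts in this tree) into a single JSON dictionary, keying each answer by "v1_" plus its question_id. A minimal usage sketch is shown below; the file names are illustrative, and the invocation this repository actually uses appears in scripts/eval/vipbench.sh further down.

# A --src line such as {"question_id": 7, "text": "A red mug."} becomes the
# --dst entry "v1_7": "A red mug." in the output dictionary.
python scripts/convert_vipbench_for_eval.py \
    --src ./playground/data/eval/ViP-Bench/answers/example-answers.jsonl \
    --dst ./playground/data/eval/ViP-Bench/results/example-results.json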
-------------------------------------------------------------------------------- /scripts/eval/vcr_qar.sh: -------------------------------------------------------------------------------- 1 | model_name=mucai/vip-llava-7b-base-vcr-ft 2 | eval_dataset=vcr-val 3 | python llava/eval/model_vqa_loader_vip.py \ 4 | --model-path $model_name \ 5 | --question-file ./dataset/$eval_dataset.json \ 6 | --image-folder ./dataset \ 7 | --alpha 128 \ 8 | --visual_prompt_style vcr_qar \ 9 | --image_aspect_ratio resize \ 10 | --answers-file ./playground/data/eval/$eval_dataset-qar-$model_name.json 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /scripts/eval/vipbench.sh: -------------------------------------------------------------------------------- 1 | model_name=vip-llava-7b 2 | model_path=mucai/$model_name 3 | folder=ViP-Bench 4 | split=$1 5 | mkdir -p ./playground/data/eval/$folder/answers 6 | python -m llava.eval.model_vqa \ 7 | --model-path $model_path \ 8 | --question-file ./playground/data/eval/$folder/$split/questions.jsonl \ 9 | --image-folder ./playground/data/eval/$folder/$split/images \ 10 | --answers-file ./playground/data/eval/$folder/answers/$model_name-$split.jsonl \ 11 | --temperature 0 12 | 13 | mkdir -p ./playground/data/eval/$folder/results 14 | 15 | python scripts/convert_vipbench_for_eval.py \ 16 | --src ./playground/data/eval/$folder/answers/$model_name-$split.jsonl \ 17 | --dst ./playground/data/eval/$folder/results/$model_name-$split.json -------------------------------------------------------------------------------- /scripts/finetune_llava_1_5_llama3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_llama_3 3 | DATA_ROOT=./dataset 4 | MODEL_VERSION='Meta-Llama-3-8B-Instruct' 5 | 6 | 7 | deepspeed --master_port 12347 llava/train/train_mem.py \ 8 | --deepspeed ./scripts/zero3.json \ 9 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 10 | --version $PROMPT_VERSION \ 11 | --data_path $DATA_ROOT/llava_v1_5_mix665k.json \ 12 | --image_folder $DATA_ROOT \ 13 | --vision_tower openai/clip-vit-large-patch14-336 \ 14 | --pretrain_mm_mlp_adapter ./checkpoints/pretrain_llava_1_5_$MODEL_VERSION/mm_projector.bin \ 15 | --mm_projector_type mlp2x_gelu \ 16 | --mm_vision_select_layer -2 \ 17 | --mm_use_im_start_end False \ 18 | --mm_use_im_patch_token False \ 19 | --image_aspect_ratio pad \ 20 | --bf16 True \ 21 | --output_dir ./checkpoints/llava-1.5-llama-3-8b \ 22 | --num_train_epochs 1 \ 23 | --per_device_train_batch_size 8 \ 24 | --per_device_eval_batch_size 4 \ 25 | --gradient_accumulation_steps 2 \ 26 | --evaluation_strategy "no" \ 27 | --save_strategy "steps" \ 28 | --save_steps 50000 \ 29 | --save_total_limit 1 \ 30 | --learning_rate 2e-5 \ 31 | --weight_decay 0. 
\ 32 | --warmup_ratio 0.03 \ 33 | --lr_scheduler_type "cosine" \ 34 | --logging_steps 1 \ 35 | --tf32 True \ 36 | --model_max_length 2048 \ 37 | --gradient_checkpointing True \ 38 | --dataloader_num_workers 4 \ 39 | --lazy_preprocess True \ 40 | --report_to wandb -------------------------------------------------------------------------------- /scripts/finetune_llava_1_5_phi3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_VERSION=Phi-3-mini-4k-instruct 3 | 4 | deepspeed llava/train/train_mem.py --deepspeed ./scripts/zero3.json \ 5 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 6 | --version llava_phi_3 \ 7 | --data_path ./dataset/llava_v1_5_mix665k.json \ 8 | --image_folder ./dataset \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --pretrain_mm_mlp_adapter ./checkpoints/pretrain_llava_1_5_$MODEL_VERSION/mm_projector.bin \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length True \ 17 | --bf16 True \ 18 | --output_dir ./checkpoints/llava-1.5-phi-3-mini-3.8B \ 19 | --num_train_epochs 1 \ 20 | --per_device_train_batch_size 16 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 1 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | -------------------------------------------------------------------------------- /scripts/finetune_stage2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_v1 3 | DATA_ROOT=./playground/data 4 | model_size=7b 5 | MODEL_VERSION=vicuna-$model_size-v1-5 6 | 7 | 8 | deepspeed --master_port 12347 llava/train/train_mem.py \ 9 | --deepspeed ./scripts/zero2.json \ 10 | --model_name_or_path lmsys/$MODEL_VERSION \ 11 | --version $PROMPT_VERSION \ 12 | --data_path $DATA_ROOT/vip-llava_stage2_mix.json \ 13 | --image_folder $DATA_ROOT \ 14 | --vision_tower clip_4layers_336 \ 15 | --pretrain_mm_mlp_adapter ./checkpoints/vip-llava-$model_size-pretrain/mm_projector.bin \ 16 | --mm_projector_type mlp2x_gelu \ 17 | --mm_vision_select_layer -2 \ 18 | --mm_use_im_start_end False \ 19 | --mm_use_im_patch_token False \ 20 | --image_aspect_ratio pad \ 21 | --bf16 True \ 22 | --output_dir ./checkpoints/vip-llava-$model_size-stage2-ft \ 23 | --num_train_epochs 1 \ 24 | --per_device_train_batch_size 16 \ 25 | --per_device_eval_batch_size 4 \ 26 | --gradient_accumulation_steps 1 \ 27 | --evaluation_strategy "no" \ 28 | --save_strategy "steps" \ 29 | --save_steps 50000 \ 30 | --save_total_limit 1 \ 31 | --learning_rate 2e-5 \ 32 | --weight_decay 0. 
\ 33 | --warmup_ratio 0.03 \ 34 | --lr_scheduler_type "cosine" \ 35 | --logging_steps 1 \ 36 | --tf32 True \ 37 | --model_max_length 2048 \ 38 | --gradient_checkpointing True \ 39 | --dataloader_num_workers 4 \ 40 | --lazy_preprocess True \ 41 | --report_to wandb 42 | -------------------------------------------------------------------------------- /scripts/finetune_stage2_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_v1 3 | DATA_ROOT=./playground/data 4 | model_size=7b 5 | MODEL_VERSION=vicuna-$model_size-v1-5 6 | 7 | deepspeed llava/train/train_mem.py \ 8 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 9 | --deepspeed ./scripts/zero2.json \ 10 | --model_name_or_path lmsys/$MODEL_VERSION \ 11 | --version $PROMPT_VERSION \ 12 | --data_path $DATA_ROOT/vip-llava_stage2_mix.json \ 13 | --image_folder $DATA_ROOT \ 14 | --vision_tower clip_4layers_336 \ 15 | --pretrain_mm_mlp_adapter ./checkpoints/vip-llava-$model_size-pretrain/mm_projector.bin \ 16 | --mm_projector_type mlp2x_gelu \ 17 | --mm_vision_select_layer -2 \ 18 | --mm_use_im_start_end False \ 19 | --mm_use_im_patch_token False \ 20 | --image_aspect_ratio pad \ 21 | --bf16 True \ 22 | --output_dir ./checkpoints/vip-llava-$model_size-stage2-lora \ 23 | --num_train_epochs 1 \ 24 | --per_device_train_batch_size 16 \ 25 | --per_device_eval_batch_size 4 \ 26 | --gradient_accumulation_steps 1 \ 27 | --evaluation_strategy "no" \ 28 | --save_strategy "steps" \ 29 | --save_steps 50000 \ 30 | --save_total_limit 1 \ 31 | --learning_rate 2e-4 \ 32 | --weight_decay 0. \ 33 | --warmup_ratio 0.03 \ 34 | --lr_scheduler_type "cosine" \ 35 | --logging_steps 1 \ 36 | --tf32 True \ 37 | --model_max_length 2048 \ 38 | --gradient_checkpointing True \ 39 | --dataloader_num_workers 4 \ 40 | --lazy_preprocess True \ 41 | --report_to wandb 42 | -------------------------------------------------------------------------------- /scripts/finetune_stage3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_VERSION="vicuna-7b-v1-5" 3 | PROMPT_VERSION=llava_v1 4 | DATA_ROOT=./playground/data 5 | model_size=7b 6 | 7 | 8 | deepspeed --master_port 12347 llava/train/train_mem.py \ 9 | --deepspeed ./scripts/zero2.json \ 10 | --model_name_or_path ./checkpoints/vip-llava-$model_size-stage2-ft \ 11 | --version $PROMPT_VERSION \ 12 | --data_path $DATA_ROOT/vip-llava_stage3_mix.json \ 13 | --image_folder $DATA_ROOT \ 14 | --vision_tower clip_4layers_336 \ 15 | --mm_projector_type mlp2x_gelu \ 16 | --mm_vision_select_layer -2 \ 17 | --mm_use_im_start_end False \ 18 | --mm_use_im_patch_token False \ 19 | --image_aspect_ratio pad \ 20 | --bf16 True \ 21 | --output_dir ./checkpoints/vip-llava-$model_size-stage3-ft \ 22 | --num_train_epochs 1 \ 23 | --per_device_train_batch_size 16 \ 24 | --per_device_eval_batch_size 4 \ 25 | --gradient_accumulation_steps 1 \ 26 | --evaluation_strategy "no" \ 27 | --save_strategy "steps" \ 28 | --save_steps 50000 \ 29 | --save_total_limit 1 \ 30 | --learning_rate 2e-7 \ 31 | --weight_decay 0. 
\ 32 | --warmup_ratio 0.03 \ 33 | --lr_scheduler_type "cosine" \ 34 | --logging_steps 1 \ 35 | --tf32 True \ 36 | --model_max_length 2048 \ 37 | --gradient_checkpointing True \ 38 | --dataloader_num_workers 4 \ 39 | --lazy_preprocess True \ 40 | --report_to wandb 41 | -------------------------------------------------------------------------------- /scripts/finetune_task.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_v1 3 | DATA_ROOT=./playground/data 4 | model_size=7b 5 | 6 | 7 | deepspeed --master_port 12347 llava/train/train_mem.py \ 8 | --deepspeed ./scripts/zero2.json \ 9 | --model_name_or_path mucai/vip-llava-$model_size \ 10 | --version $PROMPT_VERSION \ 11 | --data_path $DATA_ROOT/vip-llava_stage3_mix-task.json \ 12 | --image_folder $DATA_ROOT \ 13 | --vision_tower clip_4layers_336 \ 14 | --mm_projector_type mlp2x_gelu \ 15 | --mm_vision_select_layer -2 \ 16 | --mm_use_im_start_end False \ 17 | --mm_use_im_patch_token False \ 18 | --image_aspect_ratio pad \ 19 | --bf16 True \ 20 | --output_dir ./checkpoints/vip-llava-$model_size-task-ft \ 21 | --num_train_epochs 1 \ 22 | --per_device_train_batch_size 16 \ 23 | --per_device_eval_batch_size 4 \ 24 | --gradient_accumulation_steps 1 \ 25 | --evaluation_strategy "no" \ 26 | --save_strategy "steps" \ 27 | --save_steps 50000 \ 28 | --save_total_limit 1 \ 29 | --learning_rate 2e-5 \ 30 | --weight_decay 0. \ 31 | --warmup_ratio 0.03 \ 32 | --lr_scheduler_type "cosine" \ 33 | --logging_steps 1 \ 34 | --tf32 True \ 35 | --model_max_length 2048 \ 36 | --gradient_checkpointing True \ 37 | --dataloader_num_workers 4 \ 38 | --lazy_preprocess True \ 39 | --report_to wandb 40 | -------------------------------------------------------------------------------- /scripts/finetune_task_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_v1 3 | DATA_ROOT=./playground/data 4 | model_size=7b 5 | 6 | deepspeed llava/train/train_mem.py \ 7 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 8 | --deepspeed ./scripts/zero2.json \ 9 | --model_name_or_path mucai/vip-llava-$model_size \ 10 | --version $PROMPT_VERSION \ 11 | --data_path $DATA_ROOT/vip-llava_stage3_mix-task.json \ 12 | --image_folder $DATA_ROOT \ 13 | --vision_tower clip_4layers_336 \ 14 | --mm_projector_type mlp2x_gelu \ 15 | --mm_vision_select_layer -2 \ 16 | --mm_use_im_start_end False \ 17 | --mm_use_im_patch_token False \ 18 | --image_aspect_ratio pad \ 19 | --bf16 True \ 20 | --output_dir ./checkpoints/vip-llava-$model_size-task-lora \ 21 | --num_train_epochs 1 \ 22 | --per_device_train_batch_size 16 \ 23 | --per_device_eval_batch_size 4 \ 24 | --gradient_accumulation_steps 1 \ 25 | --evaluation_strategy "no" \ 26 | --save_strategy "steps" \ 27 | --save_steps 50000 \ 28 | --save_total_limit 1 \ 29 | --learning_rate 2e-4 \ 30 | --weight_decay 0. 
\ 31 | --warmup_ratio 0.03 \ 32 | --lr_scheduler_type "cosine" \ 33 | --logging_steps 1 \ 34 | --tf32 True \ 35 | --model_max_length 2048 \ 36 | --gradient_checkpointing True \ 37 | --dataloader_num_workers 4 \ 38 | --lazy_preprocess True \ 39 | --report_to wandb 40 | -------------------------------------------------------------------------------- /scripts/finetune_vip_llava_llama3_stage2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_llama_3 3 | DATA_ROOT=./dataset 4 | MODEL_VERSION='Meta-Llama-3-8B-Instruct' 5 | 6 | 7 | deepspeed --master_port 12347 llava/train/train_mem.py \ 8 | --deepspeed ./scripts/zero3.json \ 9 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 10 | --version $PROMPT_VERSION \ 11 | --data_path $DATA_ROOT/vip-llava_stage2_mix.json \ 12 | --image_folder $DATA_ROOT \ 13 | --vision_tower clip_4layers_336 \ 14 | --pretrain_mm_mlp_adapter ./checkpoints/vip-llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 15 | --mm_projector_type mlp2x_gelu \ 16 | --mm_vision_select_layer -2 \ 17 | --mm_use_im_start_end False \ 18 | --mm_use_im_patch_token False \ 19 | --image_aspect_ratio pad \ 20 | --bf16 True \ 21 | --output_dir ./checkpoints/vip-llava-$MODEL_VERSION-stage2-ft \ 22 | --num_train_epochs 1 \ 23 | --per_device_train_batch_size 8 \ 24 | --per_device_eval_batch_size 4 \ 25 | --gradient_accumulation_steps 2 \ 26 | --evaluation_strategy "no" \ 27 | --save_strategy "steps" \ 28 | --save_steps 50000 \ 29 | --save_total_limit 1 \ 30 | --learning_rate 2e-5 \ 31 | --weight_decay 0. \ 32 | --warmup_ratio 0.03 \ 33 | --lr_scheduler_type "cosine" \ 34 | --logging_steps 1 \ 35 | --tf32 True \ 36 | --model_max_length 2048 \ 37 | --gradient_checkpointing True \ 38 | --dataloader_num_workers 4 \ 39 | --lazy_preprocess True \ 40 | --report_to wandb -------------------------------------------------------------------------------- /scripts/finetune_vip_llava_llama3_stage3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_llama_3 3 | DATA_ROOT=./dataset 4 | MODEL_VERSION='Meta-Llama-3-8B-Instruct' 5 | 6 | 7 | deepspeed --master_port 12347 llava/train/train_mem.py \ 8 | --deepspeed ./scripts/zero3.json \ 9 | --model_name_or_path ./checkpoints/vip-llava-$MODEL_VERSION-stage2-ft \ 10 | --version $PROMPT_VERSION \ 11 | --data_path $DATA_ROOT/vip-llava_stage3_mix.json \ 12 | --image_folder $DATA_ROOT \ 13 | --vision_tower clip_4layers_336 \ 14 | --mm_projector_type mlp2x_gelu \ 15 | --mm_vision_select_layer -2 \ 16 | --mm_use_im_start_end False \ 17 | --mm_use_im_patch_token False \ 18 | --image_aspect_ratio pad \ 19 | --bf16 True \ 20 | --output_dir ./checkpoints/vip-llava-llama-3-8b \ 21 | --num_train_epochs 1 \ 22 | --per_device_train_batch_size 8 \ 23 | --per_device_eval_batch_size 4 \ 24 | --gradient_accumulation_steps 2 \ 25 | --evaluation_strategy "no" \ 26 | --save_strategy "steps" \ 27 | --save_steps 50000 \ 28 | --save_total_limit 1 \ 29 | --learning_rate 2e-7 \ 30 | --weight_decay 0. 
\ 31 | --warmup_ratio 0.03 \ 32 | --lr_scheduler_type "cosine" \ 33 | --logging_steps 1 \ 34 | --tf32 True \ 35 | --model_max_length 2048 \ 36 | --gradient_checkpointing True \ 37 | --dataloader_num_workers 4 \ 38 | --lazy_preprocess True \ 39 | --report_to wandb -------------------------------------------------------------------------------- /scripts/finetune_vip_llava_phi3_stage2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_phi_3 3 | DATA_ROOT=./dataset 4 | MODEL_VERSION='Phi-3-mini-4k-instruct' 5 | 6 | 7 | deepspeed --master_port 12347 llava/train/train_mem.py \ 8 | --deepspeed ./scripts/zero3.json \ 9 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 10 | --version $PROMPT_VERSION \ 11 | --data_path $DATA_ROOT/vip-llava_stage2_mix.json \ 12 | --image_folder $DATA_ROOT \ 13 | --vision_tower clip_4layers_336 \ 14 | --pretrain_mm_mlp_adapter ./checkpoints/vip-llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 15 | --mm_projector_type mlp2x_gelu \ 16 | --mm_vision_select_layer -2 \ 17 | --mm_use_im_start_end False \ 18 | --mm_use_im_patch_token False \ 19 | --image_aspect_ratio pad \ 20 | --bf16 True \ 21 | --output_dir ./checkpoints/vip-llava-$MODEL_VERSION-stage2-ft \ 22 | --num_train_epochs 1 \ 23 | --per_device_train_batch_size 8 \ 24 | --per_device_eval_batch_size 4 \ 25 | --gradient_accumulation_steps 2 \ 26 | --evaluation_strategy "no" \ 27 | --save_strategy "steps" \ 28 | --save_steps 50000 \ 29 | --save_total_limit 1 \ 30 | --learning_rate 2e-5 \ 31 | --weight_decay 0. \ 32 | --warmup_ratio 0.03 \ 33 | --lr_scheduler_type "cosine" \ 34 | --logging_steps 1 \ 35 | --tf32 True \ 36 | --model_max_length 2048 \ 37 | --gradient_checkpointing True \ 38 | --dataloader_num_workers 4 \ 39 | --lazy_preprocess True \ 40 | --report_to wandb -------------------------------------------------------------------------------- /scripts/finetune_vip_llava_phi3_stage3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | PROMPT_VERSION=llava_phi_3 3 | DATA_ROOT=./dataset 4 | MODEL_VERSION='Phi-3-mini-4k-instruct' 5 | 6 | 7 | deepspeed --master_port 12347 llava/train/train_mem.py \ 8 | --deepspeed ./scripts/zero3.json \ 9 | --model_name_or_path ./checkpoints/vip-llava-$MODEL_VERSION-stage2-ft \ 10 | --version $PROMPT_VERSION \ 11 | --data_path $DATA_ROOT/vip-llava_stage3_mix.json \ 12 | --image_folder $DATA_ROOT \ 13 | --vision_tower clip_4layers_336 \ 14 | --mm_projector_type mlp2x_gelu \ 15 | --mm_vision_select_layer -2 \ 16 | --mm_use_im_start_end False \ 17 | --mm_use_im_patch_token False \ 18 | --image_aspect_ratio pad \ 19 | --bf16 True \ 20 | --output_dir ./checkpoints/vip-llava-1.5-phi-3-mini-3.8B \ 21 | --num_train_epochs 1 \ 22 | --per_device_train_batch_size 8 \ 23 | --per_device_eval_batch_size 4 \ 24 | --gradient_accumulation_steps 2 \ 25 | --evaluation_strategy "no" \ 26 | --save_strategy "steps" \ 27 | --save_steps 50000 \ 28 | --save_total_limit 1 \ 29 | --learning_rate 2e-7 \ 30 | --weight_decay 0. 
\ 31 | --warmup_ratio 0.03 \ 32 | --lr_scheduler_type "cosine" \ 33 | --logging_steps 1 \ 34 | --tf32 True \ 35 | --model_max_length 2048 \ 36 | --gradient_checkpointing True \ 37 | --dataloader_num_workers 4 \ 38 | --lazy_preprocess True \ 39 | --report_to wandb -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_size=7b 3 | 4 | deepspeed llava/train/train_mem.py \ 5 | --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path lmsys/vicuna-$model_size-v1.5 \ 7 | --version plain \ 8 | --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ 9 | --image_folder ./playground/data/LLaVA-Pretrain/images \ 10 | --vision_tower clip_4layers_336 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --tune_mm_mlp_adapter True \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --bf16 True \ 17 | --output_dir ./checkpoints/vip-llava-$model_size-pretrain \ 18 | --num_train_epochs 1 \ 19 | --per_device_train_batch_size 32 \ 20 | --per_device_eval_batch_size 4 \ 21 | --gradient_accumulation_steps 1 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "steps" \ 24 | --save_steps 24000 \ 25 | --save_total_limit 1 \ 26 | --learning_rate 1e-3 \ 27 | --weight_decay 0. \ 28 | --warmup_ratio 0.03 \ 29 | --lr_scheduler_type "cosine" \ 30 | --logging_steps 1 \ 31 | --tf32 True \ 32 | --model_max_length 2048 \ 33 | --gradient_checkpointing True \ 34 | --dataloader_num_workers 4 \ 35 | --lazy_preprocess True \ 36 | --report_to wandb 37 | -------------------------------------------------------------------------------- /scripts/pretrain_llava_1_5_llama3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_VERSION='Meta-Llama-3-8B-Instruct' 4 | 5 | deepspeed llava/train/train_mem.py --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 7 | --version plain \ 8 | --data_path ./dataset/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ 9 | --image_folder ./dataset/LLaVA-Pretrain/images \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --tune_mm_mlp_adapter True \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --bf16 True \ 17 | --output_dir ./checkpoints/pretrain_llava_1_5_$MODEL_VERSION \ 18 | --num_train_epochs 1 \ 19 | --per_device_train_batch_size 32 \ 20 | --per_device_eval_batch_size 4 \ 21 | --gradient_accumulation_steps 1 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "steps" \ 24 | --save_steps 24000 \ 25 | --save_total_limit 1 \ 26 | --learning_rate 1e-3 \ 27 | --weight_decay 0. 
\ 28 | --warmup_ratio 0.03 \ 29 | --lr_scheduler_type "cosine" \ 30 | --logging_steps 1 \ 31 | --tf32 True \ 32 | --model_max_length 2048 \ 33 | --gradient_checkpointing True \ 34 | --dataloader_num_workers 4 \ 35 | --lazy_preprocess True \ 36 | --report_to wandb 37 | -------------------------------------------------------------------------------- /scripts/pretrain_llava_1_5_phi3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_VERSION=Phi-3-mini-4k-instruct 3 | 4 | deepspeed llava/train/train_mem.py --deepspeed ./scripts/zero2.json \ 5 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 6 | --version plain \ 7 | --data_path ./dataset/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ 8 | --image_folder ./dataset/LLaVA-Pretrain/images \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --tune_mm_mlp_adapter True \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoints/pretrain_llava_1_5_$MODEL_VERSION \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 32 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 24000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 0. \ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb -------------------------------------------------------------------------------- /scripts/pretrain_vip_llava_llama3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_VERSION='Meta-Llama-3-8B-Instruct' 3 | 4 | deepspeed llava/train/train_mem.py \ 5 | --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 7 | --version plain \ 8 | --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ 9 | --image_folder ./playground/data/LLaVA-Pretrain/images \ 10 | --vision_tower clip_4layers_336 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --tune_mm_mlp_adapter True \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --bf16 True \ 17 | --output_dir ./checkpoints/vip-llava-$MODEL_VERSION-pretrain \ 18 | --num_train_epochs 1 \ 19 | --per_device_train_batch_size 32 \ 20 | --per_device_eval_batch_size 4 \ 21 | --gradient_accumulation_steps 1 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "steps" \ 24 | --save_steps 24000 \ 25 | --save_total_limit 1 \ 26 | --learning_rate 1e-3 \ 27 | --weight_decay 0. 
\ 28 | --warmup_ratio 0.03 \ 29 | --lr_scheduler_type "cosine" \ 30 | --logging_steps 1 \ 31 | --tf32 True \ 32 | --model_max_length 2048 \ 33 | --gradient_checkpointing True \ 34 | --dataloader_num_workers 4 \ 35 | --lazy_preprocess True \ 36 | --report_to wandb 37 | -------------------------------------------------------------------------------- /scripts/pretrain_vip_llava_phi3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_VERSION='Phi-3-mini-4k-instruct' 3 | 4 | 5 | 6 | deepspeed llava/train/train_mem.py \ 7 | --deepspeed ./scripts/zero2.json \ 8 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 9 | --version plain \ 10 | --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ 11 | --image_folder ./playground/data/LLaVA-Pretrain/images \ 12 | --vision_tower clip_4layers_336 \ 13 | --mm_projector_type mlp2x_gelu \ 14 | --tune_mm_mlp_adapter True \ 15 | --mm_vision_select_layer -2 \ 16 | --mm_use_im_start_end False \ 17 | --mm_use_im_patch_token False \ 18 | --bf16 True \ 19 | --output_dir ./checkpoints/vip-llava-$MODEL_VERSION-pretrain \ 20 | --num_train_epochs 1 \ 21 | --per_device_train_batch_size 32 \ 22 | --per_device_eval_batch_size 4 \ 23 | --gradient_accumulation_steps 1 \ 24 | --evaluation_strategy "no" \ 25 | --save_strategy "steps" \ 26 | --save_steps 24000 \ 27 | --save_total_limit 1 \ 28 | --learning_rate 1e-3 \ 29 | --weight_decay 0. \ 30 | --warmup_ratio 0.03 \ 31 | --lr_scheduler_type "cosine" \ 32 | --logging_steps 1 \ 33 | --tf32 True \ 34 | --model_max_length 2048 \ 35 | --gradient_checkpointing True \ 36 | --dataloader_num_workers 4 \ 37 | --lazy_preprocess True \ 38 | --report_to wandb -------------------------------------------------------------------------------- /scripts/v1_5/eval/gqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | 8 | # CKPT="llava-v1.5-13b" 9 | model_name=$1 10 | model_name_replace=${model_name//\//_} 11 | echo $model_name_replace 12 | CKPT=$model_name_replace 13 | 14 | SPLIT="llava_gqa_testdev_balanced" 15 | GQADIR="./playground/data/eval/gqa/data" 16 | 17 | for IDX in $(seq 0 $((CHUNKS-1))); do 18 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 19 | --model-path ./checkpoints/$model_name_replace \ 20 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \ 21 | --image-folder ./playground/data/eval/gqa/data/images \ 22 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 23 | --num-chunks $CHUNKS \ 24 | --chunk-idx $IDX \ 25 | --temperature 0 \ 26 | --conv-mode vicuna_v1 & 27 | done 28 | 29 | wait 30 | 31 | output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl 32 | 33 | # Clear out the output file if it exists. 34 | > "$output_file" 35 | 36 | # Loop through the indices and concatenate each file. 
37 | for IDX in $(seq 0 $((CHUNKS-1))); do 38 | cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 39 | done 40 | 41 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json 42 | 43 | cd $GQADIR 44 | python eval/eval.py --tier testdev_balanced 45 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/llavabench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | model_name=$1 5 | model_name_replace=${model_name//\//_} 6 | echo $model_name_replace 7 | 8 | python -m llava.eval.model_vqa \ 9 | --model-path ./checkpoints/$model_name_replace \ 10 | --question-file ./playground/data/eval/llava-bench-in-the-wild/questions.jsonl \ 11 | --image-folder ./playground/data/eval/llava-bench-in-the-wild/images \ 12 | --answers-file ./playground/data/eval/llava-bench-in-the-wild/answers/$model_name_replace.jsonl \ 13 | --temperature 0 \ 14 | --conv-mode vicuna_v1 15 | 16 | mkdir -p playground/data/eval/llava-bench-in-the-wild/reviews 17 | 18 | python llava/eval/eval_gpt_review_bench.py \ 19 | --question playground/data/eval/llava-bench-in-the-wild/questions.jsonl \ 20 | --context playground/data/eval/llava-bench-in-the-wild/context.jsonl \ 21 | --rule llava/eval/table/rule.json \ 22 | --answer-list \ 23 | playground/data/eval/llava-bench-in-the-wild/answers_gpt4.jsonl \ 24 | playground/data/eval/llava-bench-in-the-wild/answers/$model_name_replace.jsonl \ 25 | --output \ 26 | playground/data/eval/llava-bench-in-the-wild/reviews/$model_name_replace.jsonl 27 | 28 | python llava/eval/summarize_gpt_review.py -f playground/data/eval/llava-bench-in-the-wild/reviews/$model_name_replace.jsonl 29 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mmbench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name=$1 4 | model_name_replace=${model_name//\//_} 5 | echo $model_name_replace 6 | 7 | 8 | SPLIT="mmbench_dev_20230712" 9 | 10 | python -m llava.eval.model_vqa_mmbench \ 11 | --model-path ./checkpoints/$model_name_replace \ 12 | --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 13 | --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/$model_name_replace.jsonl \ 14 | --single-pred-prompt \ 15 | --temperature 0 \ 16 | --conv-mode vicuna_v1 17 | 18 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT 19 | 20 | python scripts/convert_mmbench_for_submission.py \ 21 | --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 22 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \ 23 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \ 24 | --experiment $model_name_replace 25 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mmbench_cn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SPLIT="mmbench_dev_cn_20231003" 4 | model_name=$1 5 | model_name_replace=${model_name//\//_} 6 | echo $model_name_replace 7 | 8 | python -m llava.eval.model_vqa_mmbench \ 9 | --model-path ./checkpoints/$model_name_replace \ 10 | --question-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \ 11 | --answers-file ./playground/data/eval/mmbench_cn/answers/$SPLIT/$model_name_replace.jsonl \ 12 | --lang cn \ 13 | --single-pred-prompt \ 14 | --temperature 0 \ 15 | 
--conv-mode vicuna_v1 16 | 17 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT 18 | 19 | python scripts/convert_mmbench_for_submission.py \ 20 | --annotation-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \ 21 | --result-dir ./playground/data/eval/mmbench_cn/answers/$SPLIT \ 22 | --upload-dir ./playground/data/eval/mmbench_cn/answers_upload/$SPLIT \ 23 | --experiment $model_name_replace 24 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mme.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name=$1 4 | model_name_replace=${model_name//\//_} 5 | echo $model_name_replace 6 | 7 | 8 | python -m llava.eval.model_vqa_loader \ 9 | --model-path ./checkpoints/$model_name \ 10 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \ 11 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \ 12 | --answers-file ./playground/data/eval/MME/answers/$model_name_replace.jsonl \ 13 | --temperature 0 \ 14 | --conv-mode vicuna_v1 15 | 16 | cd ./playground/data/eval/MME 17 | 18 | python convert_answer_to_mme.py --experiment $model_name_replace 19 | 20 | cd eval_tool 21 | 22 | python calculation.py --results_dir answers/$model_name_replace 23 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mmvet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | model_name_replace=${model_name//\//_} 4 | echo $model_name_replace 5 | python -m llava.eval.model_vqa \ 6 | --model-path ./checkpoints/$model_name_replace \ 7 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \ 8 | --image-folder ./playground/data/eval/mm-vet/images \ 9 | --answers-file ./playground/data/eval/mm-vet/answers/$model_name_replace.jsonl \ 10 | --temperature 0 \ 11 | --conv-mode vicuna_v1 12 | 13 | mkdir -p ./playground/data/eval/mm-vet/results 14 | 15 | python scripts/convert_mmvet_for_eval.py \ 16 | --src ./playground/data/eval/mm-vet/answers/$model_name_replace.jsonl \ 17 | --dst ./playground/data/eval/mm-vet/results/$model_name_replace.json 18 | 19 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/pope.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name=$1 4 | model_name_replace=${model_name//\//_} 5 | echo $model_name_replace 6 | 7 | python -m llava.eval.model_vqa_loader \ 8 | --model-path ./checkpoints/$model_name \ 9 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 10 | --image-folder ./playground/data/eval/pope/val2014 \ 11 | --answers-file ./playground/data/eval/pope/answers/$model_name_replace.jsonl \ 12 | --temperature 0 \ 13 | --conv-mode vicuna_v1 14 | 15 | python llava/eval/eval_pope.py \ 16 | --annotation-dir ./playground/data/eval/pope/coco \ 17 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 18 | --result-file ./playground/data/eval/pope/answers/$model_name_replace.jsonl 19 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/qbench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$1" = "dev" ]; then 4 | echo "Evaluating in 'dev' split." 5 | elif [ "$1" = "test" ]; then 6 | echo "Evaluating in 'test' split." 
7 | else 8 | echo "Unknown split, please choose between 'dev' and 'test'." 9 | exit 1 10 | fi 11 | 12 | python -m llava.eval.model_vqa_qbench \ 13 | --model-path liuhaotian/llava-v1.5-13b \ 14 | --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \ 15 | --questions-file ./playground/data/eval/qbench/llvisionqa_$1.json \ 16 | --answers-file ./playground/data/eval/qbench/llvisionqa_$1_answers.jsonl \ 17 | --conv-mode llava_v1 \ 18 | --lang en 19 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/qbench_zh.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$1" = "dev" ]; then 4 | ZH_SPLIT="验证集" 5 | echo "Evaluating in 'dev' split." 6 | elif [ "$1" = "test" ]; then 7 | ZH_SPLIT="测试集" 8 | echo "Evaluating in 'test' split." 9 | else 10 | echo "Unknown split, please choose between 'dev' and 'test'." 11 | exit 1 12 | fi 13 | 14 | python -m llava.eval.model_vqa_qbench \ 15 | --model-path liuhaotian/llava-v1.5-13b \ 16 | --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \ 17 | --questions-file ./playground/data/eval/qbench/质衡-问答-$ZH_SPLIT.json \ 18 | --answers-file ./playground/data/eval/qbench/llvisionqa_zh_$1_answers.jsonl \ 19 | --conv-mode llava_v1 \ 20 | --lang zh 21 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/seed-img.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | model_name_replace=${model_name//\//_} 4 | echo $model_name_replace 5 | 6 | 7 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 8 | IFS=',' read -ra GPULIST <<< "$gpu_list" 9 | 10 | CHUNKS=${#GPULIST[@]} 11 | 12 | # CKPT="llava-v1.5-13b" 13 | CKPT=$model_name_replace # "llava-v1.5-13b" 14 | 15 | for IDX in $(seq 0 $((CHUNKS-1))); do 16 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 17 | --model-path ./checkpoints/$model_name_replace \ 18 | --question-file ./playground/data/eval/seed_bench/llava-seed-bench-imgonly.jsonl \ 19 | --image-folder ./playground/data/eval/seed_bench \ 20 | --answers-file ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl \ 21 | --num-chunks $CHUNKS \ 22 | --chunk-idx $IDX \ 23 | --temperature 0 \ 24 | --conv-mode vicuna_v1 & 25 | done 26 | 27 | wait 28 | 29 | output_file=./playground/data/eval/seed_bench/answers/$CKPT/merge.jsonl 30 | 31 | # Clear out the output file if it exists. 32 | > "$output_file" 33 | 34 | # Loop through the indices and concatenate each file. 
35 | for IDX in $(seq 0 $((CHUNKS-1))); do 36 | cat ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 37 | done 38 | 39 | # Evaluate 40 | python scripts/convert_seed_for_submission.py \ 41 | --annotation-file ./playground/data/eval/seed_bench/SEED-Bench-imgonly.json \ 42 | --result-file $output_file \ 43 | --result-upload-file ./playground/data/eval/seed_bench/answers_upload/llava-v1.5-13b.jsonl 44 | 45 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/seed-process-anno.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | json_path = './playground/data/eval/seed_bench/SEED-Bench.json' 4 | with open(json_path, 'r') as f: 5 | data = json.load(f) 6 | 7 | data_new = {} 8 | data_new['question_type'] = data['question_type'] 9 | data_new['questions'] = [] 10 | # breakpoint() 11 | for question in data['questions']: 12 | if question['data_type'] == 'image': 13 | data_new['questions'].append(question) 14 | with open('./playground/data/eval/seed_bench/SEED-Bench-image.json', 'w') as f: 15 | json.dump(data_new, f, indent=4) -------------------------------------------------------------------------------- /scripts/v1_5/eval/seed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | model_name_replace=${model_name//\//_} 4 | echo $model_name_replace 5 | 6 | 7 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 8 | IFS=',' read -ra GPULIST <<< "$gpu_list" 9 | 10 | CHUNKS=${#GPULIST[@]} 11 | 12 | # CKPT="llava-v1.5-13b" 13 | CKPT=$model_name_replace # "llava-v1.5-13b" 14 | 15 | # for IDX in $(seq 0 $((CHUNKS-1))); do 16 | # CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 17 | # --model-path ./checkpoints/$model_name_replace \ 18 | # --question-file ./playground/data/eval/seed_bench/llava-seed-bench.jsonl \ 19 | # --image-folder ./playground/data/eval/seed_bench \ 20 | # --answers-file ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl \ 21 | # --num-chunks $CHUNKS \ 22 | # --chunk-idx $IDX \ 23 | # --temperature 0 \ 24 | # --conv-mode vicuna_v1 & 25 | # done 26 | 27 | # wait 28 | 29 | output_file=./playground/data/eval/seed_bench/answers/$CKPT/merge.jsonl 30 | 31 | # # Clear out the output file if it exists. 32 | # > "$output_file" 33 | 34 | # # Loop through the indices and concatenate each file. 
35 | # for IDX in $(seq 0 $((CHUNKS-1))); do 36 | # cat ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 37 | # done 38 | 39 | # Evaluate 40 | python scripts/convert_seed_for_submission.py \ 41 | --annotation-file ./playground/data/eval/seed_bench/SEED-Bench.json \ 42 | --result-file $output_file \ 43 | --result-upload-file ./playground/data/eval/seed_bench/answers_upload/llava-v1.5-13b.jsonl 44 | 45 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/sqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | model_name_replace=${model_name//\//_} 4 | echo $model_name_replace 5 | 6 | python -m llava.eval.model_vqa_science \ 7 | --model-path ./checkpoints/$model_name \ 8 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \ 9 | --image-folder ./playground/data/eval/scienceqa/images/test \ 10 | --answers-file ./playground/data/eval/scienceqa/answers/$model_name_replace.jsonl \ 11 | --single-pred-prompt \ 12 | --temperature 0 \ 13 | --conv-mode vicuna_v1 14 | 15 | python llava/eval/eval_science_qa.py \ 16 | --base-dir ./playground/data/eval/scienceqa \ 17 | --result-file ./playground/data/eval/scienceqa/answers/$model_name_replace.jsonl \ 18 | --output-file ./playground/data/eval/scienceqa/answers/${model_name_replace}_output.jsonl \ 19 | --output-result ./playground/data/eval/scienceqa/answers/${model_name_replace}_result.json 20 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/textvqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | model_name=$1 3 | model_name_replace=${model_name//\//_} 4 | echo $model_name_replace 5 | 6 | python -m llava.eval.model_vqa_loader \ 7 | --model-path ./checkpoints/$model_name_replace \ 8 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ 9 | --image-folder ./playground/data/eval/textvqa/train_images \ 10 | --answers-file ./playground/data/eval/textvqa/answers/$model_name_replace.jsonl \ 11 | --temperature 0 \ 12 | --conv-mode vicuna_v1 13 | 14 | python -m llava.eval.eval_textvqa \ 15 | --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \ 16 | --result-file ./playground/data/eval/textvqa/answers/$model_name_replace.jsonl 17 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/vizwiz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_name=$1 4 | model_name_replace=${model_name//\//_} 5 | echo $model_name_replace 6 | 7 | python -m llava.eval.model_vqa_loader \ 8 | --model-path ./checkpoints/$model_name_replace \ 9 | --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 10 | --image-folder ./playground/data/eval/vizwiz/test \ 11 | --answers-file ./playground/data/eval/vizwiz/answers/$model_name_replace.jsonl \ 12 | --temperature 0 \ 13 | --conv-mode vicuna_v1 14 | 15 | python scripts/convert_vizwiz_for_submission.py \ 16 | --annotation-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 17 | --result-file ./playground/data/eval/vizwiz/answers/$model_name_replace.jsonl \ 18 | --result-upload-file ./playground/data/eval/vizwiz/answers_upload/$model_name_replace.json 19 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/vqav2.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | 8 | model_name=$1 9 | model_name_replace=${model_name//\//_} 10 | echo $model_name_replace 11 | 12 | 13 | CKPT=$model_name_replace # "llava-v1.5-13b" 14 | SPLIT="llava_vqav2_mscoco_test-dev2015" 15 | 16 | for IDX in $(seq 0 $((CHUNKS-1))); do 17 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 18 | --model-path ./checkpoints/$model_name_replace \ 19 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \ 20 | --image-folder ./playground/data/eval/vqav2/test2015 \ 21 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 22 | --num-chunks $CHUNKS \ 23 | --chunk-idx $IDX \ 24 | --temperature 0 \ 25 | --conv-mode vicuna_v1 & 26 | done 27 | 28 | wait 29 | 30 | output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl 31 | 32 | # Clear out the output file if it exists. 33 | > "$output_file" 34 | 35 | # Loop through the indices and concatenate each file. 36 | for IDX in $(seq 0 $((CHUNKS-1))); do 37 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 38 | done 39 | 40 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT 41 | 42 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 
16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "none", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } --------------------------------------------------------------------------------