├── .gitignore ├── LICENSE ├── README.md ├── checkpoints └── .gitkeep ├── docs ├── Evaluation.md └── Model_Zoo.md ├── imp_llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_pope.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── eval_textvqa.py │ ├── generate_webpage_data_from_table.py │ ├── m4c_evaluator.py │ ├── model_merge.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_loader.py │ ├── model_vqa_mmbench.py │ ├── model_vqa_qbench.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css ├── mm_utils.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── language_model │ │ ├── imp.py │ │ ├── imp_phi3.py │ │ ├── imp_qwen1_5.py │ │ ├── llava_llama.py │ │ ├── phi2 │ │ │ ├── configuration_phi.py │ │ │ └── modeling_phi.py │ │ ├── phi3 │ │ │ ├── configuration_phi3.py │ │ │ ├── modeling_phi3.py │ │ │ └── sample_finetune.py │ │ └── qwen2 │ │ │ ├── config.json │ │ │ ├── configuration_qwen2.py │ │ │ ├── modeling_qwen2.py │ │ │ ├── tokenization_qwen2.py │ │ │ └── tokenizer_config.json │ ├── llava_arch.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ └── siglip │ │ │ ├── configuration_siglip.py │ │ │ ├── image_processing_imp.py │ │ │ └── modeling_siglip.py │ └── multimodal_projector │ │ └── builder.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ └── train_xformers.py └── utils.py ├── logs └── .gitkeep ├── playground └── data │ ├── eval │ └── .gitkeep │ └── prompts │ ├── complex_reasoning │ ├── 000_caps.txt │ ├── 000_conv.txt │ ├── 001_caps.txt │ ├── 001_conv.txt │ ├── 002_caps.txt │ ├── 002_conv.txt │ └── system_message.txt │ ├── conversation │ ├── 000_caps.txt │ ├── 000_conv.txt │ ├── 001_caps.txt │ ├── 001_conv.txt │ └── system_message.txt │ └── detail_description │ ├── 000_caps.txt │ ├── 000_conv.txt │ ├── 001_caps.txt │ ├── 001_conv.txt │ ├── 002_caps.txt │ ├── 002_conv.txt │ └── system_message.txt ├── requirements.txt ├── scripts ├── convert_gqa_for_eval.py ├── convert_mmbench_for_submission.py ├── convert_mmvet_for_eval.py ├── convert_seed_for_submission.py ├── convert_sqa_to_llava.py ├── convert_sqa_to_llava_base_prompt.py ├── convert_vizwiz_for_submission.py ├── convert_vqav2_for_submission.py ├── download_models.py ├── eval │ ├── gqa.sh │ ├── mmbench.sh │ ├── mme.sh │ ├── mmvet.sh │ ├── pope.sh │ ├── sqa.sh │ ├── textvqa.sh │ ├── vizwiz.sh │ └── vqav2.sh ├── extract_mm_projector.py ├── finetune.sh ├── 
finetune_lora.sh ├── finetune_lora_custom.sh ├── merge.sh ├── merge_lora_weights.py ├── pretrain.sh ├── zero2.json ├── zero3.json └── zero3_offload.json └── tmp └── .gitkeep /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | checkpoints/* 163 | !checkpoints/.gitkeep 164 | 165 | playground/data/eval/* 166 | !playground/data/eval/.gitkeep 167 | 168 | logs/* 169 | !logs/.gitkeep 170 | scripts/eval/aok_qwen.sh 171 | scripts/finetune_lora_custom_oy.sh 172 | tmp/* 173 | !tmp/.gitkeep 174 | scripts/test/* -------------------------------------------------------------------------------- /checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/checkpoints/.gitkeep -------------------------------------------------------------------------------- /docs/Evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | We follow the evaluation of [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA/tree/main) and conduct experiments on 9 commonly-used benchmarks, including 5 academic VQA benchmarks and 4 popular MLLM benchmarks. All evaluation scripts are placed in the `scripts/eval` folder. 3 | 4 | Before preparing task-specific data, you should download [eval.zip](https://drive.google.com/file/d/1atZSBBrAX54yYpxtVVW33zFvcnaHeFPy/view?usp=sharing) and extract to `./playground/data/eval`. For more specific instructions, please refer to [LLaVA's Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md). 5 | 6 | You can choose to use `lora eval` or `merge eval` in evaluation scripts. 7 | ## Scripts 8 | ### VQAv2 9 | 1. Download [`test2015`](http://images.cocodataset.org/zips/test2015.zip) and put it under `./playground/data/eval/vqav2`. 10 | 2. Inference. 11 | ```Shell 12 | CUDA_VISIBLE_DEVICES=0,1,2,3,4 bash scripts/eval/vqav2.sh 13 | ``` 14 | 3. The result file could be found in `./playground/data/eval/vqav2/answers_upload`. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/830/my-submission). 15 | 16 | ### VisWiz 17 | 18 | 1. Download [`test.json`](https://vizwiz.cs.colorado.edu/VizWiz_final/vqa_data/Annotations.zip) and extract [`test.zip`](https://vizwiz.cs.colorado.edu/VizWiz_final/images/test.zip) to `test`. Put them under `./playground/data/eval/vizwiz`. 19 | 2. Inference. 20 | ```Shell 21 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/eval/vizwiz.sh 22 | ``` 23 | 3. 
The result file can be found in `./playground/data/eval/vizwiz/answers_upload`. Submit the results to the [evaluation server](https://eval.ai/web/challenges/challenge-page/1911/my-submission). 24 | 25 | ### MMBench 26 | 27 | 1. Download [`mmbench_dev_20230712.tsv`](https://download.openmmlab.com/mmclassification/datasets/mmbench/mmbench_dev_20230712.tsv) and put it under `./playground/data/eval/mmbench`. 28 | 2. Inference. 29 | ```Shell 30 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/mmbench.sh 31 | ``` 32 | 3. The result file can be found in `./playground/data/eval/mmbench/answers_upload`. Submit the results to the [evaluation server](https://opencompass.org.cn/leaderboard-multimodal). 33 | 34 | ### MM-Vet 35 | 36 | 1. Extract [`mm-vet.zip`](https://github.com/yuweihao/MM-Vet/releases/download/v1/mm-vet.zip) to `./playground/data/eval/mmvet`. 37 | 2. Inference. 38 | ```Shell 39 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/mmvet.sh 40 | ``` 41 | 3. Evaluate the predictions in `./playground/data/eval/mmvet/results` using the official Jupyter notebook. 42 | 43 | ### GQA 44 | 45 | 1. Download the [data](https://cs.stanford.edu/people/dorarad/gqa/download.html) and [evaluation scripts](https://cs.stanford.edu/people/dorarad/gqa/evaluate.html) following the official instructions and put them under `./playground/data/eval/gqa/data`. You may need to modify `eval.py` as in [this gist](https://gist.github.com/haotian-liu/db6eddc2a984b4cbcc8a7f26fd523187) due to the missing assets in the GQA v1.2 release. 46 | 2. Inference and evaluate. 47 | ```Shell 48 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/eval/gqa.sh 49 | ``` 50 | 51 | ### ScienceQA 52 | 53 | 1. Under `./playground/data/eval/scienceqa`, download `images`, `pid_splits.json`, and `problems.json` from the `data/scienceqa` folder of the ScienceQA [repo](https://github.com/lupantech/ScienceQA). 54 | 2. Inference and evaluate. 55 | ```Shell 56 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/eval/sqa.sh 57 | ``` 58 | 3. (Optional) Following the multiple-choice prompt in [LLaVA's Evaluation.md](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md), we rewrite the question file as [scienceqa_multi.jsonl](https://drive.google.com/file/d/1SU7tEuXUENnvXowGkFVcEa5dsVPlMsD9/view?usp=drive_link), which yields a better result than `llava_test_CQM-A.json`. When using `scienceqa_multi.jsonl` for evaluation, replace `imp_llava.eval.model_vqa_science` with `imp_llava.eval.model_vqa_loader` in `sqa.sh`. 59 | ### TextVQA 60 | 61 | 1. Download [`TextVQA_0.5.1_val.json`](https://dl.fbaipublicfiles.com/textvqa/data/TextVQA_0.5.1_val.json) and the [images](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip), and extract them to `./playground/data/eval/textvqa`. 62 | 2. Inference and evaluate. 63 | ```Shell 64 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/eval/textvqa.sh 65 | ``` 66 | 67 | ### POPE 68 | 69 | 1. Download `coco` from [POPE](https://github.com/AoiDragon/POPE/tree/e3e39262c85a6a83f26cf5094022a782cb0df58d/output/coco) and put it under `./playground/data/eval/pope`. 70 | 2. Inference and evaluate. 71 | ```Shell 72 | CUDA_VISIBLE_DEVICES=0,1,2,3 bash scripts/eval/pope.sh 73 | ``` 74 | 75 | ### MME 76 | 77 | 1. Download the data following the official instructions [here](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation). 78 | 2. Download the images to `MME_Benchmark_release_version`. 79 | 3. Put the official `eval_tool` and `MME_Benchmark_release_version` under `./playground/data/eval/MME`. 80 | 4. Inference and evaluate. 
81 | ```Shell 82 | CUDA_VISIBLE_DEVICES=0 bash scripts/eval/mme.sh 83 | ``` -------------------------------------------------------------------------------- /docs/Model_Zoo.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | We provide a family of Imp models as follows 3 | 4 | ### Imp-v1 5 | 6 | | name | size | LLM |visual encoder |image res.| download link | 7 | |:---------|:----:|:----:|:-------------:|:--------:|:-------:| 8 | | Imp-v1-3B | 3B | phi2 | SigLip |384| [Huggingface](https://huggingface.co/MILVLG/imp-v1-3b) | 9 | 10 | ### Imp-v1.5 11 | | name | size | LLM |visual encoder |image res.| download link | 12 | |:---------|:----:|:----:|:-------------:|:--------:|:------:| 13 | | Imp-v1.5-2B-Qwen1.5 | 2B | Qwen1.5 (1.8B) | SigLip |384|[Huggingface](https://huggingface.co/MILVLG/Imp-v1.5-2B-Qwen1.5)| 14 | | Imp-v1.5-3B-Phi2| 3B | Phi2 (2.7B) | SigLip |384|[Huggingface](https://huggingface.co/MILVLG/Imp-v1.5-3B-Phi2)| 15 | | Imp-v1.5-3B-Phi2-196 | 3B | Phi2 (2.7B) | SigLip |196|[Huggingface](https://huggingface.co/MILVLG/Imp-v1.5-3B-196)| 16 | | Imp-v1.5-4B-Phi3 | 4B | Phi3 (3.8B) | SigLip |384|[Huggingface](https://huggingface.co/MILVLG/Imp-v1.5-4B-Phi3)| -------------------------------------------------------------------------------- /imp_llava/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zhenwei Shao and MILVLG team. 2 | # Licensed under the Apache License, Version 2.0. 3 | 4 | import logging, os 5 | 6 | class VeryUsefulLoggerFormatter(logging.Formatter): 7 | """ A very useful logger formatter lets you locate where a printed log is coming from. 8 | This class is written by Zhenwei (https://github.com/ParadoxZW). 9 | """ 10 | def format(self, record): 11 | pathname = record.pathname 12 | parts = pathname.split(os.sep) 13 | start_idx = max(0, len(parts) - (self.imp_log_fflevel + 1)) 14 | relevant_path = os.sep.join(parts[start_idx:]) 15 | record.custom_path = relevant_path 16 | return super().format(record) 17 | 18 | @classmethod 19 | def init_logger_help_function(cls, name, level=logging.INFO): 20 | imp_silient_others = bool(os.environ.get("IMP_SILIENT_OTHERS", False)) 21 | is_silent = imp_silient_others and os.environ.get("LOCAL_RANK", None) not in ["0", None] 22 | logger = logging.getLogger(name) 23 | logger.setLevel(logging.ERROR if is_silent else level) 24 | logger.propagate = False 25 | # customize log format 26 | log_format = "[%(asctime)s] [%(levelname)s] [%(custom_path)s:%(lineno)d] %(message)s" 27 | # log_format = "[%(asctime)s] [logger:%(name)s] [%(levelname)s] [%(custom_path)s:%(lineno)d] %(message)s" 28 | formatter = cls(log_format, datefmt="%Y-%m-%d %H:%M:%S") 29 | formatter.imp_log_fflevel = int(os.environ.get("IMP_LOG_FFLEVEL", "3")) 30 | handler = logging.StreamHandler() 31 | handler.setFormatter(formatter) 32 | logger.addHandler(handler) 33 | return logger 34 | 35 | 36 | logger = VeryUsefulLoggerFormatter.init_logger_help_function(__name__) 37 | VeryUsefulLoggerFormatter.init_logger_help_function("", level=logging.WARNING) 38 | VeryUsefulLoggerFormatter.init_logger_help_function("transformers.generation", level=logging.WARNING) 39 | VeryUsefulLoggerFormatter.init_logger_help_function("transformers.modeling_utils", level=logging.ERROR) 40 | # VeryUsefulLoggerFormatter.init_logger_help_function("deepspeed") 41 | 42 | if os.environ.get("LOCAL_RANK", None) in ["0", None]: 43 | logger.info( 44 | f"\n\n\033[95m\033[4mWelcome to Imp! 
We use a custom logger in the Imp project. You can use environment variables to control the logger:\n" 45 | " - `export IMP_LOG_FFLEVEL={number}` to set the number of parent folders to be printed.\n" 46 | " - `export IMP_SILIENT_OTHERS=true` to silence all processes except the rank-0 process, which is useful for distributed training.\n" 47 | "You are free to inspect the code where this message comes from and modify the logging behavior. The Imp team wishes you a good day:)\033[0m\033[24m\n" 48 | ) 49 | 50 | try: 51 | from .model import LlavaLlamaForCausalLM 52 | except: 53 | pass -------------------------------------------------------------------------------- /imp_llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the 
answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 
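        # NOTE: each entry in the rule file is assumed to provide at least a "role" and a
        # "prompt" field, e.g. {"role": "Assistant", "prompt": "..."} (illustrative; see
        # imp_llava/eval/table/rule.json for the actual contents). parse_score() above also
        # assumes the GPT review starts with a line holding the two scores, e.g. "8 7",
        # followed by the free-form justification; otherwise the pair (-1, -1) is recorded.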
96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, 
ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 | text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", type=str) 67 | 
parser.add_argument("--question-file", type=str) 68 | parser.add_argument("--result-file", type=str) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | return random.choice(range(len(choices))) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | is_multimodal=False 62 | else: 63 | pred = predictions[prob_id] 64 | pred_text = pred['text'] 65 | is_multimodal=True 66 | 67 | if pred_text in args.options: 68 | answer = pred_text 69 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 70 | answer = pred_text[0] 71 | else: 72 | pattern = re.compile(r'The answer is ([A-Z]).') 73 | res = pattern.findall(pred_text) 74 | if len(res) == 1: 75 | answer = res[0] # 'A', 'B', ... 
76 | else: 77 | answer = "FAILED" 78 | 79 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 80 | 81 | analysis = { 82 | 'question_id': prob_id, 83 | 'parsed_ans': answer, 84 | 'ground_truth': args.options[prob['answer']], 85 | 'question': pred['prompt'], 86 | 'pred': pred_text, 87 | 'is_multimodal': is_multimodal, #'' in pred['prompt'], 88 | } 89 | 90 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 91 | sqa_results['outputs'][prob_id] = pred_text 92 | 93 | if pred_idx == prob['answer']: 94 | results['correct'].append(analysis) 95 | else: 96 | results['incorrect'].append(analysis) 97 | 98 | correct = len(results['correct']) 99 | total = len(results['correct']) + len(results['incorrect']) 100 | 101 | ###### IMG ###### 102 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 103 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 104 | multimodal_total = multimodal_correct + multimodal_incorrect 105 | ###### IMG ###### 106 | 107 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 108 | 109 | sqa_results['acc'] = correct / total * 100 110 | sqa_results['correct'] = correct 111 | sqa_results['count'] = total 112 | 113 | with open(args.output_file, 'w') as f: 114 | json.dump(results, f, indent=2) 115 | with open(args.output_result, 'w') as f: 116 | json.dump(sqa_results, f, indent=2) 117 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 
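# NOTE: like eval_science_qa.py and eval_science_qa_gpt4.py above, this script extracts
# answers with the regex r'The answer is ([A-Z]).', so a prediction such as
# "The answer is B." parses to "B". Predictions that do not match are treated as "FAILED",
# and get_pred_idx() then falls back to a random choice among the options, so the reported
# accuracy for failed parses can vary slightly between runs.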
20 | 21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /imp_llava/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from imp_llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /imp_llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = 
read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /imp_llava/eval/model_merge.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from imp_llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from imp_llava.conversation import conv_templates, SeparatorStyle 10 | from imp_llava.model.builder import load_pretrained_model 11 | from imp_llava.utils import disable_torch_init 12 | from imp_llava.mm_utils import get_model_name_from_path 13 | 14 | 15 | 16 | def eval_model(args): 17 | # Model 18 | disable_torch_init() 19 | model_path = os.path.expanduser(args.model_path) 20 | model_name = get_model_name_from_path(model_path) 21 | tokenizer, model, _, _ = load_pretrained_model(model_path, args.model_base, model_name) 22 | model.save_pretrained(f'checkpoints/{args.save_name}/', max_shard_size="1024MB", safe_serialization=True) 23 | tokenizer.save_pretrained(f'checkpoints/{args.save_name}/') 24 | 25 | if __name__ == "__main__": 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 
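    # NOTE: for LoRA checkpoints, --model-base should point at the base LLM so that
    # load_pretrained_model() can merge the adapter weights before the merged model is
    # saved to checkpoints/<save-name>/. Illustrative invocation (paths are hypothetical):
    #   python -m imp_llava.eval.model_merge \
    #       --model-path checkpoints/imp-lora-ft --model-base <base-llm-path> --save-name imp-merged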
28 | parser.add_argument("--model-base", type=str, default=None) 29 | parser.add_argument("--save-name", type=str, default='imp-v1-3b') 30 | args = parser.parse_args() 31 | 32 | eval_model(args) 33 | -------------------------------------------------------------------------------- /imp_llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from imp_llava.conversation import default_conversation 10 | from imp_llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 35 | disable_torch_init() 36 | model_name = os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | 
-------------------------------------------------------------------------------- /imp_llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from imp_llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from imp_llava.conversation import conv_templates, SeparatorStyle 10 | from imp_llava.model.builder import load_pretrained_model 11 | from imp_llava.utils import disable_torch_init 12 | from imp_llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions): 42 | idx = line["question_id"] 43 | image_file = line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = Image.open(os.path.join(args.image_folder, image_file)) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 60 | 61 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 62 | keywords = [stop_str] 63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 64 | 65 | with torch.inference_mode(): 66 | output_ids = model.generate( 67 | input_ids, 68 | images=image_tensor.unsqueeze(0).half().cuda(), 69 | do_sample=True if args.temperature > 0 else False, 70 | temperature=args.temperature, 71 | top_p=args.top_p, 72 | num_beams=args.num_beams, 73 | # no_repeat_ngram_size=3, 74 | max_new_tokens=1024, 75 | use_cache=True) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 
85 | outputs = outputs.strip() 86 | 87 | ans_id = shortuuid.uuid() 88 | ans_file.write(json.dumps({"question_id": idx, 89 | "prompt": cur_prompt, 90 | "text": outputs, 91 | "answer_id": ans_id, 92 | "model_id": model_name, 93 | "metadata": {}}) + "\n") 94 | ans_file.flush() 95 | ans_file.close() 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 100 | parser.add_argument("--model-base", type=str, default=None) 101 | parser.add_argument("--image-folder", type=str, default="") 102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 103 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 104 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 105 | parser.add_argument("--num-chunks", type=int, default=1) 106 | parser.add_argument("--chunk-idx", type=int, default=0) 107 | parser.add_argument("--temperature", type=float, default=0.2) 108 | parser.add_argument("--top_p", type=float, default=None) 109 | parser.add_argument("--num_beams", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | eval_model(args) 113 | -------------------------------------------------------------------------------- /imp_llava/eval/model_vqa_qbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from tqdm import tqdm 4 | import json 5 | 6 | from imp_llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 7 | from imp_llava.conversation import conv_templates, SeparatorStyle 8 | from imp_llava.model.builder import load_pretrained_model 9 | from imp_llava.utils import disable_torch_init 10 | from imp_llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 11 | 12 | from PIL import Image 13 | 14 | import requests 15 | from PIL import Image 16 | from io import BytesIO 17 | 18 | 19 | def load_image(image_file): 20 | if image_file.startswith('http') or image_file.startswith('https'): 21 | response = requests.get(image_file) 22 | image = Image.open(BytesIO(response.content)).convert('RGB') 23 | else: 24 | image = Image.open(image_file).convert('RGB') 25 | return image 26 | 27 | 28 | def eval_model(args): 29 | # Model 30 | disable_torch_init() 31 | 32 | model_name = get_model_name_from_path(args.model_path) 33 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True) 34 | 35 | 36 | 37 | 38 | with open(args.questions_file) as f: 39 | llvqa_data = json.load(f) 40 | 41 | for i, llddata in enumerate(tqdm(llvqa_data)): 42 | filename = llddata["img_path"] 43 | if args.lang == "en": 44 | message = llddata["question"] + "\nChoose between one of the options as follows:\n" 45 | elif args.lang == "zh": 46 | message = llddata["question"] + "\在下列选项中选择一个:\n" 47 | else: 48 | raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. 
Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.") 49 | for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]): 50 | message += f"{choice} {ans}\n" 51 | qs = message 52 | 53 | if model.config.mm_use_im_start_end: 54 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 55 | else: 56 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 57 | 58 | if 'llama-2' in model_name.lower(): 59 | conv_mode = "llava_llama_2" 60 | elif "v1" in model_name.lower(): 61 | conv_mode = "llava_v1" 62 | elif "mpt" in model_name.lower(): 63 | conv_mode = "mpt" 64 | else: 65 | conv_mode = "llava_v0" 66 | 67 | if args.conv_mode is not None and conv_mode != args.conv_mode: 68 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 69 | else: 70 | args.conv_mode = conv_mode 71 | 72 | conv = conv_templates[args.conv_mode].copy() 73 | conv.append_message(conv.roles[0], qs) 74 | conv.append_message(conv.roles[1], None) 75 | prompt = conv.get_prompt() 76 | 77 | image = load_image(args.image_folder + filename) 78 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 79 | 80 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 81 | 82 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 83 | keywords = [stop_str] 84 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 85 | 86 | 87 | with torch.inference_mode(): 88 | output_ids = model.generate( 89 | input_ids, 90 | images=image_tensor, 91 | num_beams=1, 92 | do_sample=False, 93 | temperature=0, 94 | max_new_tokens=1024, 95 | use_cache=True, 96 | stopping_criteria=[stopping_criteria]) 97 | 98 | input_token_len = input_ids.shape[1] 99 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 100 | if n_diff_input_output > 0: 101 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 102 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 103 | outputs = outputs.strip() 104 | if outputs.endswith(stop_str): 105 | outputs = outputs[:-len(stop_str)] 106 | outputs = outputs.strip() 107 | llddata["response"] = outputs 108 | with open(args.answers_file, "a") as wf: 109 | json.dump(llddata, wf) 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--model-path", type=str, default="llava-v1.5") 114 | parser.add_argument("--model-base", type=str, default=None) 115 | parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa") 116 | parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json") 117 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 118 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 119 | parser.add_argument("--lang", type=str, default="en") 120 | args = parser.parse_args() 121 | 122 | eval_model(args) 123 | -------------------------------------------------------------------------------- /imp_llava/eval/model_vqa_science.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from imp_llava.constants import 
IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from imp_llava.conversation import conv_templates, SeparatorStyle 10 | from imp_llava.model.builder import load_pretrained_model 11 | from imp_llava.utils import disable_torch_init 12 | from imp_llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # ceiling division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | keywords = [''] 37 | 38 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 39 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 40 | answers_file = os.path.expanduser(args.answers_file) 41 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 42 | ans_file = open(answers_file, "w") 43 | for i, line in enumerate(tqdm(questions)): 44 | idx = line["id"] 45 | question = line['conversations'][0] 46 | qs = question['value'].replace('<image>', '').strip() 47 | cur_prompt = qs 48 | 49 | if 'image' in line: 50 | image_file = line["image"] 51 | image = Image.open(os.path.join(args.image_folder, image_file)) 52 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 53 | images = image_tensor.unsqueeze(0).half().cuda() 54 | if getattr(model.config, 'mm_use_im_start_end', False): 55 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 56 | else: 57 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 58 | cur_prompt = '<image>' + '\n' + cur_prompt 59 | else: 60 | images = None 61 | 62 | if args.single_pred_prompt: 63 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 64 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
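# Next, the (possibly <image>-prefixed) question is wrapped in the selected conversation
# template and tokenized with tokenizer_image_token(), which replaces the <image>
# placeholder with IMAGE_TOKEN_INDEX so the visual features can be spliced in at that position.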
65 | 66 | conv = conv_templates[args.conv_mode].copy() 67 | conv.append_message(conv.roles[0], qs) 68 | conv.append_message(conv.roles[1], None) 69 | prompt = conv.get_prompt() 70 | 71 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 72 | 73 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 74 | # keywords = [stop_str] 75 | stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] 76 | 77 | with torch.inference_mode(): 78 | output_ids = model.generate( 79 | input_ids, 80 | images=images, 81 | do_sample=True if args.temperature > 0 else False, 82 | temperature=args.temperature, 83 | max_new_tokens=1024, 84 | use_cache=True, 85 | stopping_criteria=stopping_criteria, 86 | ) 87 | 88 | input_token_len = input_ids.shape[1] 89 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 90 | if n_diff_input_output > 0: 91 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 92 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 93 | # outputs = outputs.strip() 94 | # if outputs.endswith(stop_str): 95 | # outputs = outputs[:-len(stop_str)] 96 | outputs = outputs.strip() 97 | 98 | # prompt for answer 99 | if args.answer_prompter: 100 | outputs_reasoning = outputs 101 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 102 | 103 | with torch.inference_mode(): 104 | output_ids = model.generate( 105 | input_ids, 106 | images=images, 107 | do_sample=True if args.temperature > 0 else False, 108 | temperature=args.temperature, 109 | max_new_tokens=64, 110 | use_cache=True, 111 | stopping_criteria=stopping_criteria)  # already a list of criteria; wrapping it again would break generate() 112 | 113 | input_token_len = input_ids.shape[1] 114 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 115 | if n_diff_input_output > 0: 116 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 117 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 118 | outputs = outputs.strip() 119 | if outputs.endswith(stop_str): 120 | outputs = outputs[:-len(stop_str)] 121 | outputs = outputs.strip() 122 | outputs = outputs_reasoning + '\n The answer is ' + outputs 123 | 124 | ans_id = shortuuid.uuid() 125 | ans_file.write(json.dumps({"question_id": idx, 126 | "prompt": cur_prompt, 127 | "text": outputs, 128 | "answer_id": ans_id, 129 | "model_id": model_name, 130 | "metadata": {}}) + "\n") 131 | ans_file.flush() 132 | ans_file.close() 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 137 | parser.add_argument("--model-base", type=str, default=None) 138 | parser.add_argument("--image-folder", type=str, default="") 139 | parser.add_argument("--question-file", type=str, default="tables/question.json") 140 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 141 | parser.add_argument("--conv-mode", type=str, default="llava_v0") 142 | parser.add_argument("--num-chunks", type=int, default=1) 143 | parser.add_argument("--chunk-idx", type=int, default=0) 144 | parser.add_argument("--temperature", type=float, default=0.2) 145 | parser.add_argument("--answer-prompter", action="store_true") 146 | parser.add_argument("--single-pred-prompt",
action="store_true") 147 | args = parser.parse_args() 148 | 149 | eval_model(args) 150 | -------------------------------------------------------------------------------- /imp_llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /imp_llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from imp_llava.constants import ( 5 | IMAGE_TOKEN_INDEX, 6 | DEFAULT_IMAGE_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | IMAGE_PLACEHOLDER, 10 | ) 11 | from imp_llava.conversation import conv_templates, SeparatorStyle 12 | from imp_llava.model.builder import load_pretrained_model 13 | from imp_llava.utils import disable_torch_init 14 | from imp_llava.mm_utils import ( 15 | process_images, 16 | tokenizer_image_token, 17 | get_model_name_from_path, 18 | KeywordsStoppingCriteria, 19 | ) 20 | 21 | from PIL import Image 22 | 23 | import requests 24 | from PIL import Image 25 | from io import BytesIO 26 | import re 27 | 28 | 29 | def image_parser(args): 30 | out = args.image_file.split(args.sep) 31 | return out 32 | 33 | 34 | def load_image(image_file): 35 | if 
image_file.startswith("http") or image_file.startswith("https"): 36 | response = requests.get(image_file) 37 | image = Image.open(BytesIO(response.content)).convert("RGB") 38 | else: 39 | image = Image.open(image_file).convert("RGB") 40 | return image 41 | 42 | 43 | def load_images(image_files): 44 | out = [] 45 | for image_file in image_files: 46 | image = load_image(image_file) 47 | out.append(image) 48 | return out 49 | 50 | 51 | def eval_model(args): 52 | # Model 53 | disable_torch_init() 54 | 55 | model_name = get_model_name_from_path(args.model_path) 56 | tokenizer, model, image_processor, context_len = load_pretrained_model( 57 | args.model_path, args.model_base, model_name 58 | ) 59 | 60 | qs = args.query 61 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 62 | if IMAGE_PLACEHOLDER in qs: 63 | if model.config.mm_use_im_start_end: 64 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 65 | else: 66 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 67 | else: 68 | if model.config.mm_use_im_start_end: 69 | qs = image_token_se + "\n" + qs 70 | else: 71 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 72 | 73 | if "llama-2" in model_name.lower(): 74 | conv_mode = "llava_llama_2" 75 | elif "v1" in model_name.lower(): 76 | conv_mode = "llava_v1" 77 | elif "mpt" in model_name.lower(): 78 | conv_mode = "mpt" 79 | else: 80 | conv_mode = "llava_v0" 81 | 82 | if args.conv_mode is not None and conv_mode != args.conv_mode: 83 | print( 84 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 85 | conv_mode, args.conv_mode, args.conv_mode 86 | ) 87 | ) 88 | else: 89 | args.conv_mode = conv_mode 90 | 91 | conv = conv_templates[args.conv_mode].copy() 92 | conv.append_message(conv.roles[0], qs) 93 | conv.append_message(conv.roles[1], None) 94 | prompt = conv.get_prompt() 95 | 96 | image_files = image_parser(args) 97 | images = load_images(image_files) 98 | images_tensor = process_images( 99 | images, 100 | image_processor, 101 | model.config 102 | ).to(model.device, dtype=torch.float16) 103 | 104 | input_ids = ( 105 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 106 | .unsqueeze(0) 107 | .cuda() 108 | ) 109 | 110 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 111 | keywords = [stop_str] 112 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 113 | 114 | with torch.inference_mode(): 115 | output_ids = model.generate( 116 | input_ids, 117 | images=images_tensor, 118 | do_sample=True if args.temperature > 0 else False, 119 | temperature=args.temperature, 120 | top_p=args.top_p, 121 | num_beams=args.num_beams, 122 | max_new_tokens=args.max_new_tokens, 123 | use_cache=True, 124 | stopping_criteria=[stopping_criteria], 125 | ) 126 | 127 | input_token_len = input_ids.shape[1] 128 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 129 | if n_diff_input_output > 0: 130 | print( 131 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" 132 | ) 133 | outputs = tokenizer.batch_decode( 134 | output_ids[:, input_token_len:], skip_special_tokens=True 135 | )[0] 136 | outputs = outputs.strip() 137 | if outputs.endswith(stop_str): 138 | outputs = outputs[: -len(stop_str)] 139 | outputs = outputs.strip() 140 | print(outputs) 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--model-path", type=str, 
default="facebook/opt-350m") 146 | parser.add_argument("--model-base", type=str, default=None) 147 | parser.add_argument("--image-file", type=str, required=True) 148 | parser.add_argument("--query", type=str, required=True) 149 | parser.add_argument("--conv-mode", type=str, default=None) 150 | parser.add_argument("--sep", type=str, default=",") 151 | parser.add_argument("--temperature", type=float, default=0.2) 152 | parser.add_argument("--top_p", type=float, default=None) 153 | parser.add_argument("--num_beams", type=int, default=1) 154 | parser.add_argument("--max_new_tokens", type=int, default=512) 155 | args = parser.parse_args() 156 | 157 | eval_model(args) 158 | -------------------------------------------------------------------------------- /imp_llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /imp_llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": 
"alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /imp_llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. 
Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /imp_llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /imp_llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/imp_llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /imp_llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/imp_llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /imp_llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imp_llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/imp_llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /imp_llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imp_llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/imp_llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /imp_llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px 
rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /imp_llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from imp_llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 35 | new_images.append(image) 36 | else: 37 | return image_processor(images, return_tensors='pt')['pixel_values'] 38 | if all(x.shape == new_images[0].shape for x in new_images): 39 | new_images = torch.stack(new_images, dim=0) 40 | return new_images 41 | 42 | 43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 44 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]  # tokenize the text on either side of each '<image>' placeholder 45 | 46 | def insert_separator(X, sep): 47 |
return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 48 | 49 | input_ids = [] 50 | offset = 0 51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 52 | offset = 1 53 | input_ids.append(prompt_chunks[0][0]) 54 | 55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 56 | input_ids.extend(x[offset:]) 57 | 58 | if return_tensors is not None: 59 | if return_tensors == 'pt': 60 | return torch.tensor(input_ids, dtype=torch.long) 61 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 62 | return input_ids 63 | 64 | 65 | def get_model_name_from_path(model_path): 66 | model_path = model_path.strip("/") 67 | model_paths = model_path.split("/") 68 | if model_paths[-1].startswith('checkpoint-'): 69 | return model_paths[-2] + "_" + model_paths[-1] 70 | else: 71 | return model_paths[-1] 72 | 73 | class KeywordsStoppingCriteria(StoppingCriteria): 74 | def __init__(self, keywords, tokenizer, input_ids): 75 | self.keywords = keywords 76 | self.keyword_ids = [] 77 | self.max_keyword_len = 0 78 | for keyword in keywords: 79 | cur_keyword_ids = tokenizer(keyword).input_ids 80 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 81 | cur_keyword_ids = cur_keyword_ids[1:] 82 | if len(cur_keyword_ids) > self.max_keyword_len: 83 | self.max_keyword_len = len(cur_keyword_ids) 84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 85 | self.tokenizer = tokenizer 86 | self.start_len = input_ids.shape[1] 87 | 88 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 89 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 90 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 91 | for keyword_id in self.keyword_ids: 92 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 93 | return True 94 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 95 | for keyword in self.keywords: 96 | if keyword in outputs: 97 | return True 98 | return False 99 | 100 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 101 | outputs = [] 102 | for i in range(output_ids.shape[0]): 103 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 104 | return all(outputs) 105 | -------------------------------------------------------------------------------- /imp_llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | except: 4 | pass 5 | 6 | from .language_model.imp_qwen1_5 import ImpQwen2ForCausalLM, ImpQwen2Config 7 | from .language_model.imp_phi3 import ImpPhi3Config, ImpPhi3ForCausalLM 8 | from .language_model.imp import ImpConfig, ImpForCausalLM 9 | 10 | -------------------------------------------------------------------------------- /imp_llava/model/language_model/imp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Zhenwei Shao and MILVLG team. 2 | # Licensed under the Apache License, Version 2.0. 
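# Imp couples the Phi-2 backbone (phi2/modeling_phi.py) with the LLaVA-style
# vision-language interface defined in llava_arch.py (LlavaMetaModel / LlavaMetaForCausalLM).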
3 | 4 | 5 | from typing import List, Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from transformers import AutoConfig, AutoModelForCausalLM 11 | 12 | from .phi2.modeling_phi import PhiConfig, PhiModel, PhiForCausalLM,PhiPreTrainedModel 13 | from transformers.modeling_outputs import CausalLMOutputWithPast 14 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 15 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 16 | 17 | class ImpConfig(PhiConfig): 18 | model_type = "imp" 19 | 20 | def __init__(self, **kwargs): 21 | super().__init__(**kwargs) 22 | self.image_token_index = getattr(self, "image_token_index", 50296) 23 | self.image_token = getattr(self, "image_token", "") 24 | 25 | 26 | class ImpModel(LlavaMetaModel, PhiModel): 27 | config_class = ImpConfig 28 | 29 | def __init__(self, config: ImpConfig): 30 | super(ImpModel, self).__init__(config) 31 | 32 | 33 | class ImpForCausalLM(PhiPreTrainedModel, LlavaMetaForCausalLM): 34 | """Imp for Causal Language Modeling.""" 35 | 36 | # _keys_to_ignore_on_load_missing = [""] 37 | # _keys_to_ignore_on_load_unexpected = [r"transformer\.h\.\d+\.mlp.(fc_in|fc_out)\.(weight|bias)"] 38 | config_class = ImpConfig 39 | 40 | def __init__(self, config: ImpConfig) -> None: 41 | super().__init__(config) 42 | 43 | self.model = ImpModel(config) 44 | self.vocab_size = config.vocab_size 45 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) 46 | 47 | self.post_init() 48 | 49 | def get_input_embeddings(self): 50 | return self.model.embed_tokens 51 | 52 | def set_input_embeddings(self, value): 53 | self.model.embed_tokens = value 54 | 55 | def get_output_embeddings(self) -> nn.Linear: 56 | return self.lm_head 57 | 58 | def set_output_embeddings(self, new_embeddings: nn.Linear) -> None: 59 | self.lm_head = new_embeddings 60 | 61 | def get_model(self): 62 | return self.model 63 | 64 | def get_decoder(self): 65 | return self.model 66 | 67 | def set_decoder(self, decoder): 68 | self.model = decoder 69 | 70 | def image_preprocess(self, images): 71 | return self.get_vision_tower().image_processor(images)['pixel_values'] 72 | 73 | def forward( 74 | self, 75 | input_ids: torch.LongTensor = None, 76 | attention_mask: Optional[torch.Tensor] = None, 77 | position_ids: Optional[torch.LongTensor] = None, 78 | past_key_values: Optional[List[torch.FloatTensor]] = None, 79 | inputs_embeds: Optional[torch.FloatTensor] = None, 80 | labels: Optional[torch.LongTensor] = None, 81 | use_cache: Optional[bool] = None, 82 | output_attentions: Optional[bool] = None, 83 | output_hidden_states: Optional[bool] = None, 84 | images: Optional[torch.FloatTensor] = None, 85 | return_dict: Optional[bool] = None, 86 | ) -> Union[Tuple, CausalLMOutputWithPast]: 87 | 88 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 89 | output_hidden_states = ( 90 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 91 | ) 92 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 93 | 94 | if inputs_embeds is None: 95 | ( 96 | input_ids, 97 | position_ids, 98 | attention_mask, 99 | past_key_values, 100 | inputs_embeds, 101 | labels 102 | ) = self.prepare_inputs_labels_for_multimodal( 103 | input_ids, 104 | position_ids, 105 | attention_mask, 106 | past_key_values, 107 | labels, 108 | images, 109 | 'phi2' 110 | ) 111 | 112 | outputs = self.model( 113 | input_ids=input_ids, 114 | 
past_key_values=past_key_values, 115 | attention_mask=attention_mask, 116 | position_ids=position_ids, 117 | inputs_embeds=inputs_embeds, 118 | use_cache=use_cache, 119 | output_attentions=output_attentions, 120 | output_hidden_states=output_hidden_states, 121 | return_dict=return_dict 122 | ) 123 | hidden_states = outputs[0] 124 | logits = self.lm_head(hidden_states) 125 | logits = logits.float() 126 | 127 | loss = None 128 | if labels is not None: 129 | # Shift so that tokens < n predict n 130 | shift_logits = logits[..., :-1, :].contiguous() 131 | shift_labels = labels[..., 1:].contiguous() 132 | # Flatten the tokens 133 | loss_fct = CrossEntropyLoss() 134 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 135 | shift_labels = shift_labels.view(-1) 136 | # Enable model parallelism 137 | shift_labels = shift_labels.to(shift_logits.device) 138 | loss = loss_fct(shift_logits, shift_labels) 139 | if not return_dict: 140 | loss = None 141 | output = (logits,) + outputs[1:] 142 | return (loss,) + output if loss is not None else output 143 | 144 | return CausalLMOutputWithPast( 145 | loss=loss, 146 | logits=logits, 147 | past_key_values=outputs.past_key_values, 148 | hidden_states=outputs.hidden_states, 149 | attentions=outputs.attentions, 150 | ) 151 | 152 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 153 | images = kwargs.pop("images", None) 154 | _inputs = super().prepare_inputs_for_generation( 155 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 156 | ) 157 | if images is not None: 158 | _inputs['images'] = images 159 | return _inputs 160 | 161 | AutoConfig.register("imp", ImpConfig) 162 | AutoModelForCausalLM.register(ImpConfig, ImpForCausalLM) 163 | -------------------------------------------------------------------------------- /imp_llava/model/language_model/imp_phi3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
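# ImpPhi3 applies the same LLaVA-style multimodal wrapping as imp.py, but on top of
# the Phi-3 backbone (phi3/modeling_phi3.py).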
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM 22 | 23 | from .phi3.modeling_phi3 import Phi3Config, Phi3Model, Phi3ForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 27 | 28 | class ImpPhi3Config(Phi3Config): 29 | model_type = "imp_phi3" 30 | 31 | def __init__(self, **kwargs): 32 | super().__init__(**kwargs) 33 | # self.image_token_index = getattr(self, "image_token_index", 50296) 34 | # self.image_token = getattr(self, "image_token", "") 35 | 36 | 37 | class ImpPhi3Model(LlavaMetaModel, Phi3Model): 38 | config_class = ImpPhi3Config 39 | 40 | def __init__(self, config: ImpPhi3Config): 41 | super(ImpPhi3Model, self).__init__(config) 42 | 43 | 44 | class ImpPhi3ForCausalLM(Phi3ForCausalLM, LlavaMetaForCausalLM): 45 | config_class = ImpPhi3Config 46 | 47 | def __init__(self, config): 48 | super(ImpPhi3ForCausalLM, self).__init__(config) 49 | self.model = ImpPhi3Model(config) 50 | self.vocab_size = config.vocab_size 51 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 52 | self.need_clear_cache = False 53 | # self.lm_head.weight = self.model.embed_tokens.weight 54 | 55 | # Initialize weights and apply final processing 56 | self.post_init() 57 | 58 | def get_model(self): 59 | return self.model 60 | 61 | def forward( 62 | self, 63 | input_ids: torch.LongTensor = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | position_ids: Optional[torch.LongTensor] = None, 66 | past_key_values: Optional[List[torch.FloatTensor]] = None, 67 | inputs_embeds: Optional[torch.FloatTensor] = None, 68 | labels: Optional[torch.LongTensor] = None, 69 | use_cache: Optional[bool] = None, 70 | output_attentions: Optional[bool] = None, 71 | output_hidden_states: Optional[bool] = None, 72 | images: Optional[torch.FloatTensor] = None, 73 | return_dict: Optional[bool] = None, 74 | ) -> Union[Tuple, CausalLMOutputWithPast]: 75 | 76 | if inputs_embeds is None: 77 | ( 78 | input_ids, 79 | position_ids, 80 | attention_mask, 81 | past_key_values, 82 | inputs_embeds, 83 | labels 84 | ) = self.prepare_inputs_labels_for_multimodal( 85 | input_ids, 86 | position_ids, 87 | attention_mask, 88 | past_key_values, 89 | labels, 90 | images, 91 | 'phi3' 92 | ) 93 | # inputs_embeds.requires_grad_(True) 94 | return super().forward( 95 | input_ids=input_ids, 96 | attention_mask=attention_mask, 97 | position_ids=position_ids, 98 | past_key_values=past_key_values, 99 | inputs_embeds=inputs_embeds, 100 | labels=labels, 101 | use_cache=use_cache, 102 | output_attentions=output_attentions, 103 | output_hidden_states=output_hidden_states, 104 | return_dict=return_dict 105 | ) 106 | 107 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 108 | images = kwargs.pop("images", None) 109 | _inputs = super().prepare_inputs_for_generation( 110 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 111 | ) 112 | if images is not None: 113 | _inputs['images'] = images 114 | return _inputs 115 | 116 | AutoConfig.register("imp_phi3", ImpPhi3Config) 117 | AutoModelForCausalLM.register(ImpPhi3Config, ImpPhi3ForCausalLM) 118 | -------------------------------------------------------------------------------- /imp_llava/model/language_model/imp_qwen1_5.py: 
-------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from transformers import AutoConfig, AutoModelForCausalLM 7 | 8 | from .qwen2.modeling_qwen2 import Qwen2Config, Qwen2Model, Qwen2ForCausalLM 9 | 10 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 11 | from transformers.modeling_outputs import CausalLMOutputWithPast 12 | 13 | 14 | class ImpQwen2Config(Qwen2Config): 15 | model_type = "imp_qwen2" 16 | 17 | def __init__(self, **kwargs): 18 | super().__init__(**kwargs) 19 | 20 | class ImpQwen2Model(LlavaMetaModel, Qwen2Model): 21 | config_class = ImpQwen2Config 22 | 23 | def __init__(self, config: ImpQwen2Config): 24 | super(ImpQwen2Model, self).__init__(config) 25 | 26 | 27 | class ImpQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM): 28 | config_class = ImpQwen2Config 29 | 30 | def __init__(self, config): 31 | super(ImpQwen2ForCausalLM, self).__init__(config) 32 | self.model = ImpQwen2Model(config) 33 | self.vocab_size = config.vocab_size 34 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 35 | self.need_clear_cache = False 36 | 37 | # Initialize weights and apply final processing 38 | self.post_init() 39 | 40 | def get_model(self): 41 | return self.model 42 | 43 | def forward( 44 | self, 45 | input_ids: torch.LongTensor = None, 46 | attention_mask: Optional[torch.Tensor] = None, 47 | position_ids: Optional[torch.LongTensor] = None, 48 | past_key_values: Optional[List[torch.FloatTensor]] = None, 49 | inputs_embeds: Optional[torch.FloatTensor] = None, 50 | labels: Optional[torch.LongTensor] = None, 51 | use_cache: Optional[bool] = None, 52 | output_attentions: Optional[bool] = None, 53 | output_hidden_states: Optional[bool] = None, 54 | images: Optional[torch.FloatTensor] = None, 55 | return_dict: Optional[bool] = None, 56 | ) -> Union[Tuple, CausalLMOutputWithPast]: 57 | 58 | if inputs_embeds is None: 59 | ( 60 | input_ids, 61 | position_ids, 62 | attention_mask, 63 | past_key_values, 64 | inputs_embeds, 65 | labels 66 | ) = self.prepare_inputs_labels_for_multimodal( 67 | input_ids, 68 | position_ids, 69 | attention_mask, 70 | past_key_values, 71 | labels, 72 | images, 73 | 'qwen1.5' 74 | ) 75 | 76 | return super().forward( 77 | input_ids=input_ids, 78 | attention_mask=attention_mask, 79 | position_ids=position_ids, 80 | past_key_values=past_key_values, 81 | inputs_embeds=inputs_embeds, 82 | labels=labels, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 90 | images = kwargs.pop("images", None) 91 | _inputs = super().prepare_inputs_for_generation( 92 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 93 | ) 94 | if images is not None: 95 | _inputs['images'] = images 96 | return _inputs 97 | 98 | AutoConfig.register("imp_qwen2", ImpQwen2Config) 99 | AutoModelForCausalLM.register(ImpQwen2Config, ImpQwen2ForCausalLM) -------------------------------------------------------------------------------- /imp_llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file 
except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 27 | 28 | 29 | class LlavaConfig(LlamaConfig): 30 | model_type = "llava" 31 | 32 | 33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 34 | config_class = LlavaConfig 35 | 36 | def __init__(self, config: LlamaConfig): 37 | super(LlavaLlamaModel, self).__init__(config) 38 | 39 | 40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaConfig 42 | 43 | def __init__(self, config): 44 | super(LlamaForCausalLM, self).__init__(config) 45 | self.model = LlavaLlamaModel(config) 46 | self.pretraining_tp = config.pretraining_tp 47 | self.vocab_size = config.vocab_size 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | return_dict: Optional[bool] = None, 69 | ) -> Union[Tuple, CausalLMOutputWithPast]: 70 | 71 | if inputs_embeds is None: 72 | ( 73 | input_ids, 74 | position_ids, 75 | attention_mask, 76 | past_key_values, 77 | inputs_embeds, 78 | labels 79 | ) = self.prepare_inputs_labels_for_multimodal( 80 | input_ids, 81 | position_ids, 82 | attention_mask, 83 | past_key_values, 84 | labels, 85 | images 86 | ) 87 | 88 | return super().forward( 89 | input_ids=input_ids, 90 | attention_mask=attention_mask, 91 | position_ids=position_ids, 92 | past_key_values=past_key_values, 93 | inputs_embeds=inputs_embeds, 94 | labels=labels, 95 | use_cache=use_cache, 96 | output_attentions=output_attentions, 97 | output_hidden_states=output_hidden_states, 98 | return_dict=return_dict 99 | ) 100 | 101 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 102 | images = kwargs.pop("images", None) 103 | _inputs = super().prepare_inputs_for_generation( 104 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 105 | ) 106 | if images is not None: 107 | _inputs['images'] = images 108 | return _inputs 109 | 110 | AutoConfig.register("llava", LlavaConfig) 111 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 112 | 
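All four language-model wrappers above (imp, imp_phi3, imp_qwen2, llava) follow the same recipe: subclass the backbone's config and model, mix in LlavaMetaModel / LlavaMetaForCausalLM, and register the new model_type with the transformers Auto classes. The sketch below only illustrates what that registration buys; the local checkpoint path is hypothetical, and the repository's own imp_llava.model.builder.load_pretrained_model remains the supported loading entry point.

from transformers import AutoConfig, AutoModelForCausalLM
import imp_llava.model  # importing the package runs the AutoConfig/AutoModelForCausalLM.register() calls above

# Any checkpoint whose config.json declares model_type "imp", "imp_phi3",
# "imp_qwen2" or "llava" can now be resolved by the generic Auto factories.
config = AutoConfig.from_pretrained("checkpoints/imp-v1-3b")           # hypothetical merged checkpoint
model = AutoModelForCausalLM.from_pretrained("checkpoints/imp-v1-3b")  # instantiates ImpForCausalLM
print(type(config).__name__, type(model).__name__)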
-------------------------------------------------------------------------------- /imp_llava/model/language_model/qwen2/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "Qwen2ForCausalLM" 4 | ], 5 | "auto_map": { 6 | "AutoConfig": "configuration_qwen2.Qwen2Config", 7 | "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM" 8 | }, 9 | "attention_dropout": 0.0, 10 | "bos_token_id": 151643, 11 | "eos_token_id": 151643, 12 | "hidden_act": "silu", 13 | "hidden_size": 2048, 14 | "initializer_range": 0.02, 15 | "intermediate_size": 5504, 16 | "max_position_embeddings": 32768, 17 | "max_window_layers": 21, 18 | "model_type": "qwen2", 19 | "num_attention_heads": 16, 20 | "num_hidden_layers": 24, 21 | "num_key_value_heads": 16, 22 | "rms_norm_eps": 1e-06, 23 | "rope_theta": 1000000.0, 24 | "sliding_window": 32768, 25 | "tie_word_embeddings": false, 26 | "torch_dtype": "bfloat16", 27 | "transformers_version": "4.37.0", 28 | "use_cache": true, 29 | "use_sliding_window": false, 30 | "vocab_size": 151936 31 | } 32 | -------------------------------------------------------------------------------- /imp_llava/model/language_model/qwen2/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "auto_map": { 4 | "AutoTokenizer": ["tokenization_qwen2.Qwen2Tokenizer"] 5 | }, 6 | "added_tokens_decoder": { 7 | "151643": { 8 | "content": "<|endoftext|>", 9 | "lstrip": false, 10 | "normalized": false, 11 | "rstrip": false, 12 | "single_word": false, 13 | "special": true 14 | }, 15 | "151644": { 16 | "content": "<|im_start|>", 17 | "lstrip": false, 18 | "normalized": false, 19 | "rstrip": false, 20 | "single_word": false, 21 | "special": true 22 | }, 23 | "151645": { 24 | "content": "<|im_end|>", 25 | "lstrip": false, 26 | "normalized": false, 27 | "rstrip": false, 28 | "single_word": false, 29 | "special": true 30 | } 31 | }, 32 | "additional_special_tokens": ["<|im_start|>", "<|im_end|>"], 33 | "bos_token": null, 34 | "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", 35 | "clean_up_tokenization_spaces": false, 36 | "eos_token": "<|endoftext|>", 37 | "errors": "replace", 38 | "model_max_length": 32768, 39 | "pad_token": "<|endoftext|>", 40 | "split_special_tokens": false, 41 | "tokenizer_class": "Qwen2Tokenizer", 42 | "unk_token": null 43 | } 44 | -------------------------------------------------------------------------------- /imp_llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or vision_tower.startswith("google"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /imp_llava/model/multimodal_encoder/clip_encoder.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2024 Zhenwei Shao and MILVLG team. 2 | # Licensed under the Apache License, Version 2.0. 3 | 4 | # Adopted from https://github.com/haotian-liu/LLaVA. 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from typing import Dict, Optional, Union 10 | import numpy as np 11 | 12 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 13 | 14 | from .siglip.image_processing_imp import ImpImageProcessor 15 | from .siglip.modeling_siglip import SiglipVisionModel 16 | from .siglip.configuration_siglip import SiglipVisionConfig 17 | 18 | 19 | class CLIPVisionTower(nn.Module): 20 | def __init__(self, vision_tower, args, delay_load=False): 21 | super().__init__() 22 | 23 | self.is_loaded = False 24 | 25 | self.vision_tower_name = vision_tower 26 | self.select_layer = args.mm_vision_select_layer 27 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 28 | 29 | if not delay_load: 30 | self.load_model() 31 | else: 32 | self.cfg_only = SiglipVisionConfig.from_pretrained(self.vision_tower_name) 33 | 34 | def load_model(self): 35 | if self.is_loaded: 36 | return 37 | 38 | # It's a hacky way to check if model is initialized under meta device 39 | # context, which will be enabled when loading trained model by huggingface 40 | # `from_pretrained` api. In the case that a full model with vision tower is 41 | # loaded, there will be a warning if vision tower is loaded to cpu here. So we 42 | # set `device_map` to `auto` in order to avoid the warning. 43 | # [Edited by zhenwei - 2024-02-02 13:03] 44 | is_meta = getattr(nn.Linear(1, 1, bias=False).weight, 'is_meta', False) 45 | if 'siglip' in self.vision_tower_name: 46 | # "google/siglip-so400m-patch14-384" 47 | self.image_processor = ImpImageProcessor() 48 | if is_meta: 49 | # cfg = SiglipVisionConfig.from_pretrained(self.vision_tower_name) 50 | # self.vision_tower = SiglipVisionModel(cfg) 51 | self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name, device_map='auto') 52 | else: 53 | self.vision_tower = SiglipVisionModel.from_pretrained(self.vision_tower_name) 54 | del self.vision_tower.vision_model.encoder.layers[(self.select_layer + 1):] 55 | self.vision_tower.vision_model.post_layernorm = nn.Identity() 56 | self.vision_tower.vision_model.head = nn.Identity() 57 | else: 58 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 59 | if is_meta: 60 | # cfg = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 61 | # self.vision_tower = CLIPVisionModel(cfg) 62 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map='auto') 63 | else: 64 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 65 | del self.vision_tower.vision_model.encoder.layers[(self.select_layer + 1):] 66 | 67 | self.vision_tower.requires_grad_(False) 68 | self.vision_tower.eval() 69 | 70 | self.is_loaded = True 71 | 72 | def feature_select(self, image_forward_outs): 73 | # image_features = image_forward_outs.hidden_states[self.select_layer] 74 | image_features = image_forward_outs.hidden_states[-1] 75 | if self.select_feature == 'patch': 76 | image_features = image_features[:, -self.num_patches:] 77 | assert image_features.shape[-2] == self.num_patches, f'select_feature=patch, image_features.shape[-2]={image_features.shape[-2]} != num_patches={self.num_patches}' 78 | elif self.select_feature == 'cls_patch': 79 | image_features = 
image_features 80 | assert image_features.shape[-2] == self.num_patches + 1, f'select_feature=cls_patch, image_features.shape[-2]={image_features.shape[-2]} != num_patches+1={self.num_patches+1}' 81 | else: 82 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 83 | return image_features 84 | 85 | @torch.no_grad() 86 | def forward(self, images): 87 | # assert self.num_patches == 729 88 | if type(images) is list: 89 | image_features = [] 90 | for image in images: 91 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 92 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 93 | # image_feature = image_forward_out.last_hidden_state.to(image.dtype) 94 | image_features.append(image_feature) 95 | else: 96 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 97 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 98 | # image_features = image_forward_outs.last_hidden_state.to(images.dtype) 99 | 100 | return image_features 101 | 102 | @property 103 | def dummy_feature(self): 104 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 105 | 106 | @property 107 | def dtype(self): 108 | for p in self.vision_tower.parameters(): 109 | return p.dtype 110 | 111 | @property 112 | def device(self): 113 | for p in self.vision_tower.parameters(): 114 | return p.device 115 | 116 | @property 117 | def is_meta(self): 118 | return self.device.type == 'meta' 119 | 120 | @property 121 | def config(self): 122 | if self.is_loaded: 123 | return self.vision_tower.config 124 | else: 125 | return self.cfg_only 126 | 127 | @property 128 | def hidden_size(self): 129 | return self.config.hidden_size 130 | 131 | @property 132 | def num_patches(self): 133 | return (self.config.image_size // self.config.patch_size) ** 2 134 | -------------------------------------------------------------------------------- /imp_llava/model/multimodal_encoder/siglip/configuration_siglip.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 Zhenwei Shao and MILVLG team. 3 | # Licensed under the Apache License, Version 2.0. 4 | 5 | # Adopted from https://github.com/huggingface/transformers/tree/main/src/transformers/models/siglip. 6 | # Below is the original copyright: 7 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
20 | """ Siglip model configuration""" 21 | 22 | import os 23 | from typing import Union 24 | 25 | from transformers.configuration_utils import PretrainedConfig 26 | 27 | from transformers.utils import logging 28 | 29 | logger = logging.get_logger(__name__) 30 | 31 | 32 | class SiglipVisionConfig(PretrainedConfig): 33 | 34 | model_type = "siglip_vision_model" 35 | 36 | def __init__( 37 | self, 38 | hidden_size=768, 39 | intermediate_size=3072, 40 | num_hidden_layers=12, 41 | num_attention_heads=12, 42 | num_channels=3, 43 | image_size=224, 44 | patch_size=16, 45 | hidden_act="gelu_pytorch_tanh", 46 | layer_norm_eps=1e-6, 47 | attention_dropout=0.0, 48 | **kwargs, 49 | ): 50 | super().__init__(**kwargs) 51 | 52 | self.hidden_size = hidden_size 53 | self.intermediate_size = intermediate_size 54 | self.num_hidden_layers = num_hidden_layers 55 | self.num_attention_heads = num_attention_heads 56 | self.num_channels = num_channels 57 | self.patch_size = patch_size 58 | self.image_size = image_size 59 | self.attention_dropout = attention_dropout 60 | self.layer_norm_eps = layer_norm_eps 61 | self.hidden_act = hidden_act 62 | 63 | @classmethod 64 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": 65 | cls._set_token_in_kwargs(kwargs) 66 | 67 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 68 | 69 | # get the vision config dict if we are loading from SiglipConfig 70 | if config_dict.get("model_type") == "siglip": 71 | config_dict = config_dict["vision_config"] 72 | 73 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 74 | logger.warning( 75 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 76 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 
77 | ) 78 | 79 | return cls.from_dict(config_dict, **kwargs) 80 | -------------------------------------------------------------------------------- /imp_llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /imp_llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /imp_llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /imp_llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | import os, sys 5 | sys.path.append('./') 6 | 7 | # Need to call this before importing transformers. 8 | # from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 9 | 10 | # replace_llama_attn_with_flash_attn() 11 | 12 | from imp_llava.train.train import train 13 | 14 | if __name__ == "__main__": 15 | train() 16 | -------------------------------------------------------------------------------- /imp_llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 
4 | from imp_llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from imp_llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /imp_llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from imp_llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 
96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /logs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/logs/.gitkeep -------------------------------------------------------------------------------- /playground/data/eval/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/playground/data/eval/.gitkeep -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/000_caps.txt: -------------------------------------------------------------------------------- 1 | A man wearing multiple neck ties making a goofy face. 2 | A man in a white shirt wearing very many ties. 3 | a man with ties on poses for a picture 4 | A man wearing multiple ties on his neck. 5 | A young man smiles while wearing several ties. 6 | 7 | tie: [0.574, 0.298, 0.752, 0.704] 8 | tie: [0.464, 0.339, 0.639, 0.789] 9 | tie: [0.349, 0.363, 0.563, 0.732] 10 | tie: [0.259, 0.255, 0.668, 0.805] 11 | person: [0.019, 0.065, 0.962, 0.988] 12 | person: [0.0, 0.24, 0.214, 1.0] 13 | tie: [0.316, 0.778, 0.443, 0.867] 14 | tie: [0.386, 0.707, 0.496, 0.801] 15 | tie: [0.251, 0.354, 0.402, 0.747] 16 | tie: [0.44, 0.318, 0.63, 0.397] 17 | tie: [0.566, 0.33, 0.655, 0.543] 18 | tie: [0.25, 0.409, 0.359, 0.556] -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/000_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What is unusual about this photo? 3 | === 4 | Answer: 5 | In the photo, the man is wearing a total of ten ties around his neck. This is considered unusual because it is not common for someone to wear multiple ties at once. In a typical setting, a person would wear only one tie. The man's goofy facial expression also adds to the unconventional and humorous nature of the image. -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/001_caps.txt: -------------------------------------------------------------------------------- 1 | A group of people standing outside of a black vehicle with various luggage. 
2 | Luggage surrounds a vehicle in an underground parking area 3 | People try to fit all of their luggage in an SUV. 4 | The sport utility vehicle is parked in the public garage, being packed for a trip 5 | Some people with luggage near a van that is transporting it. 6 | 7 | person: [0.681, 0.242, 0.774, 0.694] 8 | person: [0.63, 0.222, 0.686, 0.516] 9 | person: [0.444, 0.233, 0.487, 0.34] 10 | backpack: [0.384, 0.696, 0.485, 0.914] 11 | backpack: [0.755, 0.413, 0.846, 0.692] 12 | suitcase: [0.758, 0.413, 0.845, 0.69] 13 | suitcase: [0.1, 0.497, 0.173, 0.579] 14 | bicycle: [0.282, 0.363, 0.327, 0.442] 15 | car: [0.786, 0.25, 0.848, 0.322] 16 | car: [0.783, 0.27, 0.827, 0.335] 17 | car: [0.86, 0.254, 0.891, 0.3] 18 | car: [0.261, 0.101, 0.787, 0.626] -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/001_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What challenges do these people face? 3 | === 4 | Answer: 5 | In the image, a group of people is standing outside a black SUV in a parking area, surrounded by various pieces of luggage, including suitcases and backpacks. They are facing the challenge of fitting all their luggage into the black SUV. There are multiple suitcases and backpacks to be packed, which suggests that the group has a significant amount of belongings to accommodate. They might have to strategize and arrange the luggage efficiently to ensure that everything fits properly into the vehicle. Additionally, they need to consider the comfort of the passengers and visibility while driving, so the placement of the luggage must not obstruct the driver's view or make the passengers uncomfortable during the trip. -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/002_caps.txt: -------------------------------------------------------------------------------- 1 | There is a movie theater that displays the show times above the doors. 2 | A red fire hydrant is deep in the snow. 3 | The fire hydrant is in the snow near a recently plowed sidewalk. 4 | This city has had a very hard winter with snow. 5 | A hotel for dogs in the snow in winter. 6 | 7 | fire hydrant: [0.326, 0.612, 0.426, 0.72] -------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/002_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What challenges might this city face? 3 | === 4 | Answer: 5 | The city faces challenges due to the harsh winter conditions and heavy snowfall. In the image, a red fire hydrant is almost buried deep in the snow, which indicates the significant amount of snow the city has experienced. This can lead to various challenges such as difficulties in transportation, increased risk of accidents, and disruptions to daily life. For example, the recently plowed sidewalk near the fire hydrant shows that the city has to constantly clear snow from roads and sidewalks to maintain access and safety for pedestrians and vehicles. Moreover, emergency services, like firefighters, might face challenges accessing crucial equipment, such as fire hydrants, during emergencies due to the snow accumulation. This highlights the importance of effective snow management strategies and preparedness in such cities to minimize the impact of harsh winter conditions on residents and essential services. 
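The prompt files in this playground/data/prompts tree share a simple plain-text layout: each *_caps.txt file lists one caption per line followed by `label: [x1, y1, x2, y2]` lines whose corner coordinates are normalized to [0, 1] (the convention spelled out in the system_message.txt that follows), and the complex_reasoning and conversation *_conv.txt files hold `Question:` / `Answer:` blocks separated by `===` lines (the detail_description *_conv.txt files further below are plain description paragraphs). A small parsing sketch under those assumptions; the helper names are illustrative, not from the repository.

import re
from typing import Dict, List, Tuple

BOX_RE = re.compile(r"^(?P<label>[\w ]+):\s*\[(?P<coords>[^\]]+)\]\s*$")

def parse_caps(path: str) -> Tuple[List[str], List[Tuple[str, List[float]]]]:
    """Split a *_caps.txt file into caption lines and (label, [x1, y1, x2, y2]) boxes."""
    captions, boxes = [], []
    for line in open(path, encoding="utf-8"):
        line = line.strip()
        if not line:
            continue
        m = BOX_RE.match(line)
        if m:
            coords = [float(v) for v in m.group("coords").split(",")]
            boxes.append((m.group("label"), coords))  # normalized (x1, y1, x2, y2) corners
        else:
            captions.append(line)
    return captions, boxes

def parse_conv(path: str) -> List[Dict[str, str]]:
    """Split a Question/Answer style *_conv.txt file into a list of QA turns."""
    blocks = [b.strip() for b in open(path, encoding="utf-8").read().split("===")]
    turns, current = [], {}
    for block in blocks:
        if block.startswith("Question:"):
            current = {"question": block[len("Question:"):].strip()}
        elif block.startswith("Answer:"):
            current["answer"] = block[len("Answer:"):].strip()
            turns.append(current)
    return turns

For the first complex_reasoning example above, parse_caps would return the five captions and twelve labelled boxes (ten ties and two people), and parse_conv a single question/answer turn.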
-------------------------------------------------------------------------------- /playground/data/prompts/complex_reasoning/system_message.txt: -------------------------------------------------------------------------------- 1 | You are an AI visual assistant that can analyze a single image. You receive five sentences, each describing the same image you are observing. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. 2 | 3 | The task is to use the provided caption and bounding box information, create a plausible question about the image, and provide the answer in detail. 4 | 5 | Create complex questions beyond describing the scene. 6 | To answer such questions, one should require first understanding the visual content, then based on the background knowledge or reasoning, either explain why the things are happening that way, or provide guides and help to user's request. Make the question challenging by not including the visual content details in the question so that the user needs to reason about that first. 7 | 8 | Instead of directly mentioning the bounding box coordinates, utilize this data to explain the scene using natural language. Include details like object counts, position of the objects, relative position between the objects. 9 | 10 | When using the information from the caption and coordinates, directly explain the scene, and do not mention that the information source is the caption or the bounding box. Always answer as if you are directly looking at the image. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/000_caps.txt: -------------------------------------------------------------------------------- 1 | There is a movie theater that displays the show times above the doors. 2 | A red fire hydrant is deep in the snow. 3 | The fire hydrant is in the snow near a recently plowed sidewalk. 4 | This city has had a very hard winter with snow. 5 | A hotel for dogs in the snow in winter. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/000_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What color is the fire hydrant in the image? 3 | === 4 | Answer: 5 | The fire hydrant in the image is red. 6 | === 7 | Question: 8 | Is the fire hydrant fully covered by snow or partially visible? 9 | === 10 | Answer: 11 | The fire hydrant is partially visible, with a significant portion of it covered by the deep snow. 12 | === 13 | Question: 14 | What is the condition of the sidewalk near the fire hydrant? 15 | === 16 | Answer: 17 | The sidewalk near the fire hydrant has been recently plowed, so it is relatively clear of snow compared to the surrounding area. 18 | === 19 | Question: 20 | What does the image suggest about the winter season in the city? 21 | === 22 | Answer: 23 | The image suggests that the winter season in the city has been quite harsh and challenging. The significant snow accumulation seen in the image indicates that the city has experienced heavy snowfall, which can lead to various difficulties and disruptions for its residents. 
24 | 25 | For instance, the deep snow covering the fire hydrant could potentially hinder its accessibility in case of emergencies, which poses safety concerns. The recently plowed sidewalk implies that the city's maintenance crews have been actively working to keep the walkways clear and safe for pedestrians, but the sheer amount of snowfall might make it difficult to keep up with the continuous snow removal efforts. 26 | 27 | Furthermore, such extreme winter conditions can impact transportation, leading to traffic issues, delays in public transit, and increased risks of accidents due to slippery roads. It can also cause problems for businesses and schools, as people might face challenges commuting to work or attending classes. Additionally, the heavy snow can put extra strain on infrastructure, such as roofs and power lines, increasing the likelihood of structural damage or power outages. 28 | 29 | In conclusion, the image of the red fire hydrant deep in the snow and the recently plowed sidewalk suggest that the city has faced a particularly severe winter season, with substantial snowfall that has likely caused various challenges and disruptions for its residents and infrastructure. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/001_caps.txt: -------------------------------------------------------------------------------- 1 | A man is skiing in the open snow covered hills 2 | A skier is making his way into the snow. 3 | A skier on the foothills of a huge mountain range. 4 | A skier looks at mountains while standing near a trail sign. 5 | a single person skiing by an area with a lot of bushes -------------------------------------------------------------------------------- /playground/data/prompts/conversation/001_conv.txt: -------------------------------------------------------------------------------- 1 | Question: 2 | What is the skier doing? 3 | === 4 | Answer: 5 | The skier is skiing in the snow-covered hills, making his way through the snow and enjoying the beautiful mountain scenery. They are also standing near a trail sign, which indicates their location on the mountain. 6 | === 7 | Question: 8 | Are there any other people or skiers visible in the image? 9 | === 10 | Answer: 11 | No, there is only a single person skiing in the image, and they are surrounded by an area with a lot of bushes. 12 | === 13 | Question: 14 | What kind of skiing is the person doing? Cross-country or downhill? 15 | === 16 | Answer: 17 | Based on what I can see in the photo, it is not possible to definitively determine whether the skier is doing cross-country or downhill skiing. However, given the presence of a trail sign and the fact that they are skiing in the foothills, it is more likely that the skier is engaged in cross-country skiing. Downhill skiing typically takes place on more steep slopes and involves ski lifts, whereas cross-country skiing often occurs on relatively flatter terrain with marked trails. 18 | === 19 | Question: 20 | Can you explain cross-country and downhill skiing a bit more? 21 | === 22 | Answer: 23 | Cross-country skiing and downhill skiing are two distinct types of skiing, each with its own techniques, equipment, and terrain preferences. 24 | 25 | 1. Cross-Country Skiing: 26 | Also known as Nordic skiing, cross-country skiing involves traversing relatively flat terrain or gentle slopes using skis and poles. 
It is a popular recreational and competitive sport that focuses on endurance, balance, and aerobic fitness. Cross-country skiers use their own locomotion to move across snow-covered terrain, rather than relying on ski lifts or other forms of assistance. 27 | There are two main techniques in cross-country skiing: 28 | - Classic: This technique involves skiing in parallel tracks, using a forward-and-backward gliding motion while keeping the skis parallel to each other. The skier uses their poles to propel themselves forward. 29 | - Skate skiing: This technique is similar to ice skating, where the skier pushes off one ski at an angle and glides on the other. It requires a smooth, groomed surface and is typically faster than the classic technique. 30 | Cross-country ski equipment consists of lightweight skis, boots, bindings, and poles. The boots are more flexible than those used in downhill skiing, allowing for better ankle movement and control. 31 | 32 | 2. Downhill Skiing: 33 | Also known as alpine skiing, downhill skiing involves descending slopes at high speeds using skis and poles for balance and control. This sport is more focused on speed, technique, and navigating challenging terrain, including steep slopes, moguls, and even jumps. 34 | Downhill skiing can be further categorized into several disciplines, such as slalom, giant slalom, super-G, and downhill racing. Each discipline has its own set of rules, courses, and ski equipment. 35 | Downhill ski equipment includes heavier and stiffer skis, boots, bindings, and poles than those used in cross-country skiing. The boots are more rigid to provide better support and control during high-speed descents and sharp turns. 36 | 37 | In summary, cross-country skiing is an endurance-based sport that involves traveling across flat or gently sloping terrain, while downhill skiing is focused on speed and technique as skiers navigate steeper slopes and challenging terrain. Both sports require specialized equipment and techniques, but they offer different experiences and challenges to participants. -------------------------------------------------------------------------------- /playground/data/prompts/conversation/system_message.txt: -------------------------------------------------------------------------------- 1 | You are an AI visual assistant, and you are seeing a single image. What you see are provided with five sentences, describing the same image you are looking at. Answer all questions as you are seeing the image. 2 | 3 | Design a conversation between you and a person asking about this photo. The answers should be in a tone that a visual AI assistant is seeing the image and answering the question. 4 | Ask diverse questions and give corresponding answers. 5 | 6 | Include questions asking about the visual content of the image, including the object types, counting the objects, object actions, object locations, relative positions between objects, etc. Only include questions that have definite answers: 7 | (1) one can see the content in the image that the question asks about and can answer confidently; 8 | (2) one can determine confidently from the image that it is not in the image. 9 | Do not ask any question that cannot be answered confidently. 10 | 11 | Also include complex questions that are relevant to the content in the image, for example, asking about background knowledge of the objects in the image, asking to discuss about events happening in the image, etc. Again, do not ask about uncertain details. 
12 | Provide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. You can include multiple paragraphs if necessary. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/000_caps.txt: -------------------------------------------------------------------------------- 1 | A harbor filled with lots of boats next to a building. 2 | A bicycle parked in front of several boats at a dock. 3 | A red bicycle in front of a line of docked white yachts 4 | A bike sits before boats which sit before a long building. 5 | A bicycle is a convenient means of land transportation when you live on a boat. 6 | 7 | bicycle: [0.287, 0.641, 0.507, 0.874] 8 | bicycle: [0.566, 0.667, 0.63, 0.731] 9 | boat: [0.318, 0.579, 0.575, 0.724] 10 | boat: [0.704, 0.607, 0.818, 0.727] 11 | boat: [0.818, 0.601, 0.942, 0.744] 12 | boat: [0.002, 0.53, 0.243, 0.71] 13 | boat: [0.541, 0.611, 0.668, 0.731] 14 | person: [0.778, 0.527, 0.797, 0.57] 15 | cup: [0.708, 0.733, 0.724, 0.758] 16 | boat: [0.236, 0.532, 0.404, 0.64] 17 | boat: [0.81, 0.632, 0.836, 0.676] 18 | boat: [0.957, 0.526, 1.0, 0.752] -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/000_conv.txt: -------------------------------------------------------------------------------- 1 | It is a harbor filled with numerous boats of various sizes docked next to a long building. Among the boats, there are a few white yachts lined up, standing out from the rest. There is a red bicycle prominently parked in front of the line of docked boats, serving as a convenient means of land transportation for those living on the boats. Another bicycle can be seen further back in the scene, near the middle of the harbor. 2 | 3 | A person is visible near the right side of the harbor, possibly enjoying the view or attending to their boat. Additionally, there is a cup placed on a surface near the middle of the scene. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/001_caps.txt: -------------------------------------------------------------------------------- 1 | A group of people standing outside of a black vehicle with various luggage. 2 | Luggage surrounds a vehicle in an underground parking area 3 | People try to fit all of their luggage in an SUV. 4 | The sport utility vehicle is parked in the public garage, being packed for a trip 5 | Some people with luggage near a van that is transporting it. 6 | 7 | person: [0.681, 0.242, 0.774, 0.694] 8 | person: [0.63, 0.222, 0.686, 0.516] 9 | person: [0.444, 0.233, 0.487, 0.34] 10 | backpack: [0.384, 0.696, 0.485, 0.914] 11 | backpack: [0.755, 0.413, 0.846, 0.692] 12 | suitcase: [0.758, 0.413, 0.845, 0.69] 13 | suitcase: [0.1, 0.497, 0.173, 0.579] 14 | bicycle: [0.282, 0.363, 0.327, 0.442] 15 | car: [0.786, 0.25, 0.848, 0.322] 16 | car: [0.783, 0.27, 0.827, 0.335] 17 | car: [0.86, 0.254, 0.891, 0.3] 18 | car: [0.261, 0.101, 0.787, 0.626] -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/001_conv.txt: -------------------------------------------------------------------------------- 1 | The image is an underground parking area with a black sport utility vehicle (SUV) parked. 
There are three people in the scene, with one person standing closer to the left side of the vehicle, another person in the middle, and the third person on the right side. They are all working together to pack their luggage into the SUV for a trip. 2 | 3 | In the parking area, various luggage items are scattered around the vehicle. There are two backpacks, one located near the left rear wheel and the other closer to the right side of the vehicle. Additionally, there are two suitcases, one on the right side of the car and another further away near the center of the parking area. A bicycle can also be seen on the left side of the vehicle. 4 | 5 | Other cars are parked around the main SUV, with one car positioned behind it and slightly to the left, another behind and slightly to the right, and the third car further behind on the right side. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/002_caps.txt: -------------------------------------------------------------------------------- 1 | A man holds a Wii-mote above his head while another looks on. 2 | A guy and his friend are playing Nintendo Wii. 3 | A young man is holding a video game remote over his head. 4 | two men standing in a room while one plays with a wii mote 5 | Some guys standing and playing a video game. 6 | 7 | couch: [0.697, 0.759, 0.995, 1.0] 8 | dining table: [0.426, 0.755, 1.0, 0.987] 9 | person: [0.082, 0.252, 0.342, 1.0] 10 | person: [0.399, 0.085, 0.742, 0.982] 11 | remote: [0.477, 0.135, 0.516, 0.187] 12 | sink: [0.016, 0.501, 0.063, 0.52] 13 | potted plant: [0.798, 0.384, 0.888, 0.645] 14 | refrigerator: [0.305, 0.389, 0.414, 0.547] 15 | chair: [0.72, 0.509, 0.858, 0.725] -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/002_conv.txt: -------------------------------------------------------------------------------- 1 | The image shows two men standing in a room, engaged in playing a video game on a Nintendo Wii console. One of the men is holding a Wii remote above his head with enthusiasm, while the other man looks on, likely enjoying the friendly competition. 2 | 3 | The room appears to be a living space with a couch located in the background and a dining table nearby. A potted plant can be seen placed close to the couch, and a chair is situated in the middle of the room. The room also features a kitchen area with a sink and a refrigerator visible in the background. -------------------------------------------------------------------------------- /playground/data/prompts/detail_description/system_message.txt: -------------------------------------------------------------------------------- 1 | You are an AI visual assistant that can analyze a single image. You receive five sentences, each describing the same image you are observing. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. 2 | 3 | Using the provided caption and bounding box information, describe the scene in a detailed manner. 4 | 5 | Instead of directly mentioning the bounding box coordinates, utilize this data to explain the scene using natural language. 
Include details like object counts, position of the objects, relative position between the objects. 6 | 7 | When using the information from the caption and coordinates, directly explain the scene, and do not mention that the information source is the caption or the bounding box. Always answer as if you are directly looking at the image. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | transformers==4.36.0 3 | sentencepiece==0.1.99 4 | accelerate==0.21.0 5 | peft==0.4.0 6 | bitsandbytes==0.41.0 7 | scikit-learn==1.2.2 8 | einops==0.6.1 9 | deepspeed==0.9.5 10 | pillow 11 | shortuuid 12 | -------------------------------------------------------------------------------- /scripts/convert_gqa_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | all_answers = [] 11 | for line_idx, line in enumerate(open(args.src)): 12 | res = json.loads(line) 13 | question_id = res['question_id'] 14 | text = res['text'].rstrip('.').lower() 15 | all_answers.append({"questionId": question_id, "prediction": text}) 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(all_answers, f) 19 | -------------------------------------------------------------------------------- /scripts/convert_mmbench_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--annotation-file", type=str, required=True) 9 | parser.add_argument("--result-dir", type=str, required=True) 10 | parser.add_argument("--upload-dir", type=str, required=True) 11 | parser.add_argument("--experiment", type=str, required=True) 12 | 13 | return parser.parse_args() 14 | 15 | if __name__ == "__main__": 16 | args = get_args() 17 | 18 | df = pd.read_table(args.annotation_file) 19 | 20 | cur_df = df.copy() 21 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category']) 22 | cur_df.insert(6, 'prediction', None) 23 | for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")): 24 | pred = json.loads(pred) 25 | cur_df.loc[df['index'] == pred['question_id'], 'prediction'] = pred['text'] 26 | 27 | cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl') 28 | -------------------------------------------------------------------------------- /scripts/convert_mmvet_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | cur_result = {} 11 | 12 | for line in open(args.src): 13 | data = json.loads(line) 14 | qid = data['question_id'] 15 | cur_result[f'v1_{qid}'] = data['text'] 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(cur_result, f, indent=2) 19 | -------------------------------------------------------------------------------- /scripts/convert_seed_for_submission.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--annotation-file", type=str) 9 | parser.add_argument("--result-file", type=str) 10 | parser.add_argument("--result-upload-file", type=str) 11 | return parser.parse_args() 12 | 13 | 14 | def eval_single(result_file, eval_only_type=None): 15 | results = {} 16 | for line in open(result_file): 17 | row = json.loads(line) 18 | results[row['question_id']] = row 19 | 20 | type_counts = {} 21 | correct_counts = {} 22 | for question_data in data['questions']: 23 | if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue 24 | data_type = question_data['question_type_id'] 25 | type_counts[data_type] = type_counts.get(data_type, 0) + 1 26 | try: 27 | question_id = int(question_data['question_id']) 28 | except: 29 | question_id = question_data['question_id'] 30 | if question_id not in results: 31 | correct_counts[data_type] = correct_counts.get(data_type, 0) 32 | continue 33 | row = results[question_id] 34 | if row['text'] == question_data['answer']: 35 | correct_counts[data_type] = correct_counts.get(data_type, 0) + 1 36 | 37 | total_count = 0 38 | total_correct = 0 39 | for data_type in sorted(type_counts.keys()): 40 | accuracy = correct_counts[data_type] / type_counts[data_type] * 100 41 | if eval_only_type is None: 42 | print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%") 43 | 44 | total_count += type_counts[data_type] 45 | total_correct += correct_counts[data_type] 46 | 47 | total_accuracy = total_correct / total_count * 100 48 | if eval_only_type is None: 49 | print(f"Total accuracy: {total_accuracy:.2f}%") 50 | else: 51 | print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%") 52 | 53 | return results 54 | 55 | if __name__ == "__main__": 56 | args = get_args() 57 | data = json.load(open(args.annotation_file)) 58 | ques_type_id_to_name = {id:n for n,id in data['question_type'].items()} 59 | 60 | results = eval_single(args.result_file) 61 | eval_single(args.result_file, eval_only_type='image') 62 | eval_single(args.result_file, eval_only_type='video') 63 | 64 | with open(args.result_upload_file, 'w') as fp: 65 | for question in data['questions']: 66 | qid = question['question_id'] 67 | if qid in results: 68 | result = results[qid] 69 | else: 70 | result = results[int(qid)] 71 | fp.write(json.dumps({ 72 | 'question_id': qid, 73 | 'prediction': result['text'] 74 | }) + '\n') 75 | -------------------------------------------------------------------------------- /scripts/convert_sqa_to_llava.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import fire 4 | import re 5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot 6 | 7 | 8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): 9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 10 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 11 | 12 | split_problems = build_prompt_chatbot( 13 | problems, split_indices, prompt_format, 14 | use_caption=False, is_test=False) 15 | 16 | target_format = [] 17 | for prob_id, (input, output) in split_problems.items(): 18 | if input.startswith('Question: '): 19 | input = input.replace('Question: ', '') 20 | if output.startswith('Answer: '): 21 | output = output.replace('Answer: ', '') 22 | 23 | raw_prob_data 
= problems[prob_id] 24 | if raw_prob_data['image'] is None: 25 | target_format.append({ 26 | "id": prob_id, 27 | "conversations": [ 28 | {'from': 'human', 'value': f"{input}"}, 29 | {'from': 'gpt', 'value': f"{output}"}, 30 | ], 31 | }) 32 | 33 | else: 34 | target_format.append({ 35 | "id": prob_id, 36 | "image": os.path.join(prob_id, raw_prob_data['image']), 37 | "conversations": [ 38 | {'from': 'human', 'value': f"{input}\n"}, 39 | {'from': 'gpt', 'value': f"{output}"}, 40 | ], 41 | }) 42 | 43 | print(f'Number of samples: {len(target_format)}') 44 | 45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: 46 | json.dump(target_format, f, indent=2) 47 | 48 | 49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): 50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 51 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 52 | 53 | split_problems = build_prompt_chatbot( 54 | problems, split_indices, prompt_format, 55 | use_caption=False, is_test=False) 56 | 57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") 58 | for prob_id, (input, output) in split_problems.items(): 59 | if input.startswith('Question: '): 60 | input = input.replace('Question: ', '') 61 | if output.startswith('Answer: '): 62 | output = output.replace('Answer: ', '') 63 | 64 | raw_prob_data = problems[prob_id] 65 | if raw_prob_data['image'] is None: 66 | data = { 67 | "id": prob_id, 68 | "instruction": f"{input}", 69 | "output": f"{output}", 70 | } 71 | 72 | else: 73 | data = { 74 | "id": prob_id, 75 | "image": os.path.join(prob_id, raw_prob_data['image']), 76 | "instruction": f"{input}\n", 77 | "output": f"{output}", 78 | } 79 | writer.write(json.dumps(data) + '\n') 80 | writer.close() 81 | 82 | 83 | def main(task, **kwargs): 84 | globals()[task](**kwargs) 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(main) 89 | -------------------------------------------------------------------------------- /scripts/convert_vizwiz_for_submission.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('./') 3 | import argparse 4 | import json 5 | 6 | from imp_llava.eval.m4c_evaluator import EvalAIAnswerProcessor 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str, required=True) 12 | parser.add_argument('--result-file', type=str, required=True) 13 | parser.add_argument('--result-upload-file', type=str, required=True) 14 | return parser.parse_args() 15 | 16 | 17 | if __name__ == '__main__': 18 | 19 | args = parse_args() 20 | 21 | os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True) 22 | 23 | results = [] 24 | error_line = 0 25 | for line_idx, line in enumerate(open(args.result_file)): 26 | try: 27 | results.append(json.loads(line)) 28 | except: 29 | error_line += 1 30 | results = {x['question_id']: x['text'] for x in results} 31 | test_split = [json.loads(line) for line in open(args.annotation_file)] 32 | split_ids = set([x['question_id'] for x in test_split]) 33 | 34 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') 35 | 36 | all_answers = [] 37 | 38 | answer_processor = EvalAIAnswerProcessor() 39 | 40 | for x in test_split: 41 | assert x['question_id'] in results 42 | all_answers.append({ 43 | 'image': x['image'], 44 | 'answer': answer_processor(results[x['question_id']]) 45 | }) 46 | 47 | 
with open(args.result_upload_file, 'w') as f: 48 | json.dump(all_answers, f) 49 | -------------------------------------------------------------------------------- /scripts/convert_vqav2_for_submission.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | sys.path.append('./') 3 | import argparse 4 | import json 5 | 6 | from imp_llava.eval.m4c_evaluator import EvalAIAnswerProcessor 7 | 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2") 12 | parser.add_argument('--ckpt', type=str, required=True) 13 | parser.add_argument('--split', type=str, required=True) 14 | return parser.parse_args() 15 | 16 | 17 | if __name__ == '__main__': 18 | 19 | args = parse_args() 20 | 21 | src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl') 22 | test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl') 23 | dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json') 24 | os.makedirs(os.path.dirname(dst), exist_ok=True) 25 | 26 | results = [] 27 | error_line = 0 28 | for line_idx, line in enumerate(open(src)): 29 | try: 30 | results.append(json.loads(line)) 31 | except: 32 | error_line += 1 33 | 34 | results = {x['question_id']: x['text'] for x in results} 35 | test_split = [json.loads(line) for line in open(test_split)] 36 | split_ids = set([x['question_id'] for x in test_split]) 37 | 38 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') 39 | 40 | all_answers = [] 41 | 42 | answer_processor = EvalAIAnswerProcessor() 43 | 44 | for x in test_split: 45 | if x['question_id'] not in results: 46 | all_answers.append({ 47 | 'question_id': x['question_id'], 48 | 'answer': '' 49 | }) 50 | else: 51 | all_answers.append({ 52 | 'question_id': x['question_id'], 53 | 'answer': answer_processor(results[x['question_id']]) 54 | }) 55 | 56 | with open(dst, 'w') as f: 57 | json.dump(all_answers, f) 58 | -------------------------------------------------------------------------------- /scripts/download_models.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["https_proxy"] = "http://xxx.xxx.xxx.xxx:xx" # in case you need proxy to access Huggingface Hub 3 | from huggingface_hub import snapshot_download 4 | 5 | snapshot_download( 6 | repo_id="microsoft/phi-2", 7 | revision="d3186761bf5c4409f7679359284066c25ab668ee", 8 | local_dir='/home/ouyangxc/ckpts/base/phi-2', 9 | local_dir_use_symlinks=False 10 | ) 11 | 12 | snapshot_download( 13 | repo_id="google/siglip-so400m-patch14-384", 14 | revision="7067f6db2baa594bab7c6d965fe488c7ac62f1c8", 15 | local_dir='/home/ouyangxc/ckpts/base/siglip-so400m-patch14-384', 16 | local_dir_use_symlinks=False 17 | ) -------------------------------------------------------------------------------- /scripts/eval/gqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | GQADIR="./playground/data/eval/gqa/data" 13 | 14 | 15 | SPLIT="llava_gqa_testdev_balanced" 16 | 17 | # # merge eval 18 | 
MODEL_CKPT="milvlg/imp-v1-3b" 19 | # # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 20 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 21 | # # MODEL_PATH=$MODEL_CKPT 22 | # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 23 | 24 | # for IDX in $(seq 0 $((CHUNKS-1))); do 25 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 26 | # --model-path $MODEL_PATH \ 27 | # --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \ 28 | # --image-folder /path/to/images \ 29 | # --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 30 | # --num-chunks $CHUNKS \ 31 | # --chunk-idx $IDX \ 32 | # --temperature 0 \ 33 | # --conv-mode phi2 & 34 | # done 35 | 36 | # wait 37 | 38 | # lora eval 39 | MODEL_CKPT="imp-v1-3b-phi2-stage2_lora" 40 | # MODEL_CKPT="llava-phi2-lora-0427-1005_withocr" 41 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 42 | MODEL_BASE=/data/llm_common/phi-2 43 | 44 | for IDX in $(seq 0 $((CHUNKS-1))); do 45 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 46 | --model-path ./checkpoints/$MODEL_CKPT \ 47 | --model-base $MODEL_BASE \ 48 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \ 49 | --image-folder /path/to/images \ 50 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 51 | --num-chunks $CHUNKS \ 52 | --chunk-idx $IDX \ 53 | --temperature 0 \ 54 | --conv-mode phi2 & 55 | done 56 | 57 | wait 58 | 59 | output_file=./playground/data/eval/gqa/answers/$SPLIT/$EVAL_CKPT/merge.jsonl 60 | 61 | # Clear out the output file if it exists. 62 | > "$output_file" 63 | 64 | # Loop through the indices and concatenate each file. 65 | for IDX in $(seq 0 $((CHUNKS-1))); do 66 | cat ./playground/data/eval/gqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 67 | done 68 | 69 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json 70 | 71 | cd $GQADIR 72 | python eval/eval.py --tier testdev_balanced 73 | -------------------------------------------------------------------------------- /scripts/eval/mmbench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | SPLIT="mmbench_dev" 14 | 15 | # # merge eval 16 | # MODEL_CKPT="milvlg/imp-v1-3b" 17 | # # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 18 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 19 | # MODEL_PATH=$MODEL_CKPT 20 | # # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 21 | 22 | # for IDX in $(seq 0 $((CHUNKS-1))); do 23 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_mmbench \ 24 | # --model-path $MODEL_PATH \ 25 | # --question-file ./playground/data/eval/mmbench/mmbench_dev_20230712.tsv \ 26 | # --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 27 | # --num-chunks $CHUNKS \ 28 | # --chunk-idx $IDX \ 29 | # --temperature 0 \ 30 | # --conv-mode phi2 & 31 | # done 32 | 33 | # wait 34 | 35 | 36 | # lora eval 37 | MODEL_CKPT="imp-v1-3b-stage2-lora" 38 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 39 | 
MODEL_BASE=checkpoints/base/phi-2 40 | 41 | for IDX in $(seq 0 $((CHUNKS-1))); do 42 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_mmbench \ 43 | --model-path ./checkpoints/$MODEL_CKPT \ 44 | --model-base $MODEL_BASE \ 45 | --question-file ./playground/data/eval/mmbench/mmbench_dev_20230712.tsv \ 46 | --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 47 | --num-chunks $CHUNKS \ 48 | --chunk-idx $IDX \ 49 | --temperature 0 \ 50 | --conv-mode phi2 & 51 | done 52 | 53 | wait 54 | 55 | 56 | output_file=./playground/data/eval/mmbench/answers/$SPLIT/$EVAL_CKPT/merge.jsonl 57 | 58 | # Clear out the output file if it exists. 59 | > "$output_file" 60 | 61 | # Loop through the indices and concatenate each file. 62 | for IDX in $(seq 0 $((CHUNKS-1))); do 63 | cat ./playground/data/eval/mmbench/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 64 | done 65 | 66 | mkdir -p ./playground/data/eval/mmbench/answers_upload/$SPLIT/$EVAL_CKPT 67 | 68 | python scripts/convert_mmbench_for_submission.py \ 69 | --annotation-file ./playground/data/eval/mmbench/mmbench_dev_20230712.tsv \ 70 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT/$EVAL_CKPT \ 71 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT/$EVAL_CKPT \ 72 | --experiment merge 73 | 74 | 75 | -------------------------------------------------------------------------------- /scripts/eval/mme.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | SPLIT="llava_MME" 14 | 15 | # # merge eval 16 | # MODEL_CKPT="milvlg/imp-v1-3b" 17 | # # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 18 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 19 | # MODEL_PATH=$MODEL_CKPT 20 | # # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 21 | 22 | # for IDX in $(seq 0 $((CHUNKS-1))); do 23 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 24 | # --model-path $MODEL_PATH \ 25 | # --question-file ./playground/data/eval/MME/llava_mme.jsonl \ 26 | # --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \ 27 | # --answers-file ./playground/data/eval/MME/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 28 | # --num-chunks $CHUNKS \ 29 | # --chunk-idx $IDX \ 30 | # --temperature 0 \ 31 | # --conv-mode phi2 & 32 | # done 33 | 34 | # wait 35 | 36 | 37 | # lora eval 38 | MODEL_CKPT="imp-v1-3b-stage2-lora" 39 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 40 | MODEL_BASE=checkpoints/base/phi-2 41 | 42 | for IDX in $(seq 0 $((CHUNKS-1))); do 43 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 44 | --model-path ./checkpoints/$MODEL_CKPT \ 45 | --model-base $MODEL_BASE \ 46 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \ 47 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \ 48 | --answers-file ./playground/data/eval/MME/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 49 | --num-chunks $CHUNKS \ 50 | --chunk-idx $IDX \ 51 | --temperature 0 \ 52 | --conv-mode phi2 & 53 | done 54 | 55 | wait 56 | 57 | 
output_file=./playground/data/eval/MME/answers/$SPLIT.jsonl 58 | 59 | # Clear out the output file if it exists. 60 | > "$output_file" 61 | 62 | # Loop through the indices and concatenate each file. 63 | for IDX in $(seq 0 $((CHUNKS-1))); do 64 | cat ./playground/data/eval/MME/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 65 | done 66 | 67 | cd ./playground/data/eval/MME 68 | 69 | python convert_answer_to_mme.py --experiment $SPLIT 70 | 71 | cd eval_tool 72 | 73 | python calculation.py --results_dir answers/$SPLIT 74 | -------------------------------------------------------------------------------- /scripts/eval/mmvet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | SPLIT="llava_mmvet" 14 | 15 | # # merge eval 16 | # MODEL_CKPT="milvlg/imp-v1-3b" 17 | # # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 18 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 19 | # MODEL_PATH=$MODEL_CKPT 20 | # # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 21 | 22 | # for IDX in $(seq 0 $((CHUNKS-1))); do 23 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa \ 24 | # --model-path $MODEL_PATH \ 25 | # --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \ 26 | # --image-folder ./playground/data/eval/mm-vet/images \ 27 | # --answers-file ./playground/data/eval/mm-vet/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 28 | # --num-chunks $CHUNKS \ 29 | # --chunk-idx $IDX \ 30 | # --temperature 0 \ 31 | # --conv-mode phi2 & 32 | 33 | # done 34 | 35 | # wait 36 | 37 | # lora eval 38 | MODEL_CKPT="imp-v1-3b-stage2-lora" 39 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 40 | MODEL_BASE=checkpoints/base/phi-2 41 | 42 | for IDX in $(seq 0 $((CHUNKS-1))); do 43 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 44 | --model-path ./checkpoints/$MODEL_CKPT \ 45 | --model-base $MODEL_BASE \ 46 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \ 47 | --image-folder ./playground/data/eval/mm-vet/images \ 48 | --answers-file ./playground/data/eval/mm-vet/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 49 | --num-chunks $CHUNKS \ 50 | --chunk-idx $IDX \ 51 | --temperature 0 \ 52 | --conv-mode phi2 & 53 | 54 | done 55 | 56 | wait 57 | 58 | output_file=./playground/data/eval/mm-vet/answers/$SPLIT.jsonl 59 | 60 | # Clear out the output file if it exists. 61 | > "$output_file" 62 | 63 | # Loop through the indices and concatenate each file. 
64 | for IDX in $(seq 0 $((CHUNKS-1))); do 65 | cat ./playground/data/eval/mm-vet/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 66 | done 67 | 68 | mkdir -p ./playground/data/eval/mm-vet/results 69 | 70 | python scripts/convert_mmvet_for_eval.py \ 71 | --src $output_file \ 72 | --dst ./playground/data/eval/mm-vet/results/llava-phi_$EVAL_CKPT.json 73 | 74 | -------------------------------------------------------------------------------- /scripts/eval/pope.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | SPLIT="llava_pope" 14 | 15 | # # merge eval 16 | # MODEL_CKPT="milvlg/imp-v1-3b" 17 | # # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 18 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 19 | # MODEL_PATH=$MODEL_CKPT 20 | # # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 21 | 22 | # for IDX in $(seq 0 $((CHUNKS-1))); do 23 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 24 | # --model-path $MODEL_PATH \ 25 | # --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 26 | # --image-folder ./playground/data/eval/pope/val2014 \ 27 | # --answers-file ./playground/data/eval/pope/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 28 | # --num-chunks $CHUNKS \ 29 | # --chunk-idx $IDX \ 30 | # --temperature 0 \ 31 | # --conv-mode phi2 #& 32 | # done 33 | 34 | # wait 35 | 36 | # lora eval 37 | MODEL_CKPT="imp-v1-3b-stage2-lora" 38 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 39 | MODEL_BASE=checkpoints/base/phi-2 40 | 41 | for IDX in $(seq 0 $((CHUNKS-1))); do 42 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 43 | --model-path ./checkpoints/$MODEL_CKPT \ 44 | --model-base $MODEL_BASE \ 45 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 46 | --image-folder ./playground/data/eval/pope/val2014 \ 47 | --answers-file ./playground/data/eval/pope/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 48 | --num-chunks $CHUNKS \ 49 | --chunk-idx $IDX \ 50 | --temperature 0 \ 51 | --conv-mode phi2 & 52 | done 53 | 54 | wait 55 | 56 | output_file=./playground/data/eval/pope/answers/$SPLIT/$EVAL_CKPT/merge.jsonl 57 | 58 | # Clear out the output file if it exists. 59 | > "$output_file" 60 | 61 | # Loop through the indices and concatenate each file.
62 | for IDX in $(seq 0 $((CHUNKS-1))); do 63 | cat ./playground/data/eval/pope/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 64 | done 65 | 66 | python imp_llava/eval/eval_pope.py \ 67 | --annotation-dir ./playground/data/eval/pope/coco \ 68 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 69 | --result-file $output_file 70 | -------------------------------------------------------------------------------- /scripts/eval/sqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | SPLIT="llava_scienceqa" 14 | 15 | # merge eval 16 | MODEL_CKPT="milvlg/imp-v1-3b" 17 | # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 18 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 19 | MODEL_PATH=$MODEL_CKPT 20 | # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 21 | 22 | for IDX in $(seq 0 $((CHUNKS-1))); do 23 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_science \ 24 | --model-path $MODEL_PATH \ 25 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \ 26 | --image-folder ./playground/data/eval/scienceqa/images/test \ 27 | --answers-file ./playground/data/eval/scienceqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 28 | --num-chunks $CHUNKS \ 29 | --chunk-idx $IDX \ 30 | --temperature 0 \ 31 | --conv-mode phi2 & 32 | done 33 | 34 | wait 35 | 36 | # # lora eval 37 | # MODEL_CKPT="imp-v1-3b-stage2-lora" 38 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 39 | # MODEL_BASE=checkpoints/base/phi-2 40 | 41 | # for IDX in $(seq 0 $((CHUNKS-1))); do 42 | # CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_science \ 43 | # --model-path ./checkpoints/$MODEL_CKPT \ 44 | # --model-base $MODEL_BASE \ 45 | # --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \ 46 | # --image-folder ./playground/data/eval/scienceqa/images/test \ 47 | # --answers-file ./playground/data/eval/scienceqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 48 | # --num-chunks $CHUNKS \ 49 | # --chunk-idx $IDX \ 50 | # --temperature 0 \ 51 | # --conv-mode phi2 & 52 | # done 53 | 54 | # wait 55 | 56 | 57 | output_file=./playground/data/eval/scienceqa/answers/$SPLIT/$EVAL_CKPT/merge.jsonl 58 | 59 | # Clear out the output file if it exists. 60 | > "$output_file" 61 | 62 | # Loop through the indices and concatenate each file. 
63 | for IDX in $(seq 0 $((CHUNKS-1))); do 64 | cat ./playground/data/eval/scienceqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 65 | done 66 | 67 | 68 | python imp_llava/eval/eval_science_qa.py \ 69 | --base-dir ./playground/data/eval/scienceqa \ 70 | --result-file $output_file \ 71 | --output-file ./playground/data/eval/scienceqa/answers/output.jsonl \ 72 | --output-result ./playground/data/eval/scienceqa/answers/result.json 73 | -------------------------------------------------------------------------------- /scripts/eval/textvqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | SPLIT="llava_textvqa_val" 14 | 15 | # merge eval 16 | # MODEL_CKPT="milvlg/imp-v1-3b" 17 | MODEL_CKPT="imp-cuustom" # eval your own checkpoint 18 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 19 | # MODEL_PATH=$MODEL_CKPT 20 | MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 21 | 22 | 23 | for IDX in $(seq 0 $((CHUNKS-1))); do 24 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 25 | --model-path /data/ouyangxc/github/imp/checkpoints/imp-qwen1.5-merged-cus1/ \ 26 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ 27 | --image-folder ./playground/data/eval/textvqa/train_images \ 28 | --answers-file ./playground/data/eval/textvqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 29 | --num-chunks $CHUNKS \ 30 | --chunk-idx $IDX \ 31 | --temperature 0 \ 32 | --conv-mode qwen2 & 33 | done 34 | 35 | wait 36 | 37 | # # lora eval 38 | # MODEL_CKPT="imp-v1-3b-stage2-lora" 39 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 40 | # MODEL_BASE=checkpoints/base/phi-2 41 | 42 | # for IDX in $(seq 0 $((CHUNKS-1))); do 43 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 44 | # --model-path ./checkpoints/$MODEL_CKPT \ 45 | # --model-base $MODEL_BASE \ 46 | # --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ 47 | # --image-folder ./playground/data/eval/textvqa/train_images \ 48 | # --answers-file ./playground/data/eval/textvqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 49 | # --num-chunks $CHUNKS \ 50 | # --chunk-idx $IDX \ 51 | # --temperature 0 \ 52 | # --conv-mode phi2 & 53 | # done 54 | 55 | # wait 56 | 57 | 58 | output_file=./playground/data/eval/textvqa/answers/$SPLIT/$EVAL_CKPT/merge.jsonl 59 | 60 | # Clear out the output file if it exists. 61 | > "$output_file" 62 | 63 | # Loop through the indices and concatenate each file. 
64 | for IDX in $(seq 0 $((CHUNKS-1))); do 65 | cat ./playground/data/eval/textvqa/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 66 | done 67 | 68 | 69 | 70 | python -m imp_llava.eval.eval_textvqa \ 71 | --annotation-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ 72 | --result-file $output_file 73 | -------------------------------------------------------------------------------- /scripts/eval/vizwiz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 9 | IFS=',' read -ra GPULIST <<< "$gpu_list" 10 | 11 | CHUNKS=${#GPULIST[@]} 12 | 13 | SPLIT="llava_vizwiz" 14 | 15 | # # merge eval 16 | # MODEL_CKPT="milvlg/imp-v1-3b" 17 | # # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 18 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 19 | # MODEL_PATH=$MODEL_CKPT 20 | # # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 21 | 22 | 23 | # for IDX in $(seq 0 $((CHUNKS-1))); do 24 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 25 | # --model-path $MODEL_PATH \ 26 | # --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 27 | # --image-folder ./playground/data/eval/vizwiz/test \ 28 | # --answers-file ./playground/data/eval/vizwiz/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 29 | # --num-chunks $CHUNKS \ 30 | # --chunk-idx $IDX \ 31 | # --temperature 0 \ 32 | # --conv-mode phi2 & 33 | # done 34 | 35 | # wait 36 | 37 | # lora eval 38 | MODEL_CKPT="imp-v1-3b-stage2-lora" 39 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 40 | MODEL_BASE=checkpoints/base/phi-2 41 | 42 | for IDX in $(seq 0 $((CHUNKS-1))); do 43 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 44 | --model-path ./checkpoints/$MODEL_CKPT \ 45 | --model-base $MODEL_BASE \ 46 | --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 47 | --image-folder ./playground/data/eval/vizwiz/test \ 48 | --answers-file ./playground/data/eval/vizwiz/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 49 | --num-chunks $CHUNKS \ 50 | --chunk-idx $IDX \ 51 | --temperature 0 \ 52 | --conv-mode phi2 & 53 | done 54 | 55 | wait 56 | 57 | 58 | output_file=./playground/data/eval/vizwiz/answers/$SPLIT/$EVAL_CKPT/merge.jsonl 59 | 60 | # Clear out the output file if it exists. 61 | > "$output_file" 62 | 63 | # Loop through the indices and concatenate each file. 
64 | for IDX in $(seq 0 $((CHUNKS-1))); do 65 | cat ./playground/data/eval/vizwiz/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 66 | done 67 | 68 | python scripts/convert_vizwiz_for_submission.py \ 69 | --annotation-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 70 | --result-file $output_file \ 71 | --result-upload-file ./playground/data/eval/vizwiz/answers_upload/result.json 72 | -------------------------------------------------------------------------------- /scripts/eval/vqav2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # uncomment the following lines to shutoff the internet access 3 | # export HF_HUB_OFFLINE=True 4 | # export HF_DATASETS_OFFLINE=1 5 | # export TRANSFORMERS_OFFLINE=1 6 | export IMP_SILIENT_OTHERS=true 7 | 8 | # uncomment the following lines to shutoff the internet access 9 | 10 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 11 | IFS=',' read -ra GPULIST <<< "$gpu_list" 12 | 13 | CHUNKS=${#GPULIST[@]} 14 | 15 | SPLIT="llava_vqav2_mscoco_test-dev2015" 16 | 17 | # # merge eval 18 | # MODEL_CKPT="milvlg/imp-v1-3b" 19 | # # MODEL_CKPT="imp-v1-3b" # eval your own checkpoint 20 | # EVAL_CKPT="${MODEL_CKPT//\//_}_1" 21 | # MODEL_PATH=$MODEL_CKPT 22 | # # MODEL_PATH="./checkpoints/$MODEL_CKPT" # eval your own checkpoint 23 | 24 | # for IDX in $(seq 0 $((CHUNKS-1))); do 25 | # LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 26 | # --model-path $MODEL_PATH \ 27 | # --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \ 28 | # --image-folder ./playground/data/eval/vqav2/test2015 \ 29 | # --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 30 | # --num-chunks $CHUNKS \ 31 | # --chunk-idx $IDX \ 32 | # --temperature 0 \ 33 | # --conv-mode phi2 & 34 | # done 35 | 36 | # wait 37 | 38 | # lora eval 39 | MODEL_CKPT="imp-v1-3b-stage2-lora" 40 | EVAL_CKPT="${MODEL_CKPT//\//_}_1" 41 | MODEL_BASE=checkpoints/base/phi-2 42 | 43 | for IDX in $(seq 0 $((CHUNKS-1))); do 44 | LOCAL_RANK=$IDX CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m imp_llava.eval.model_vqa_loader \ 45 | --model-path ./checkpoints/$MODEL_CKPT \ 46 | --model-base $MODEL_BASE \ 47 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \ 48 | --image-folder ./playground/data/eval/vqav2/test2015 \ 49 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl \ 50 | --num-chunks $CHUNKS \ 51 | --chunk-idx $IDX \ 52 | --temperature 0 \ 53 | --conv-mode phi2 & 54 | done 55 | 56 | wait 57 | 58 | 59 | output_file=./playground/data/eval/vqav2/answers/$SPLIT/$EVAL_CKPT/merge.jsonl 60 | 61 | # Clear out the output file if it exists. 62 | > "$output_file" 63 | 64 | # Loop through the indices and concatenate each file. 65 | for IDX in $(seq 0 $((CHUNKS-1))); do 66 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$EVAL_CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 67 | done 68 | 69 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $EVAL_CKPT 70 | 71 | -------------------------------------------------------------------------------- /scripts/extract_mm_projector.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is just a utility that I use to extract the projector for quantized models. 3 | It is NOT necessary at all to train, or run inference/serve demos. 4 | Use this script ONLY if you fully understand its implications. 
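A typical invocation, for reference; the paths below are illustrative rather than defaults shipped with the repository, so adjust them to your own checkpoint layout:

    python scripts/extract_mm_projector.py \
        --model-path ./checkpoints/imp-v1-3b-fft \
        --output ./tmp/mm_projector.bin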
5 | """ 6 | 7 | 8 | import os 9 | import argparse 10 | import torch 11 | import json 12 | from collections import defaultdict 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights') 17 | parser.add_argument('--model-path', type=str, help='model folder') 18 | parser.add_argument('--output', type=str, help='output file') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == '__main__': 24 | args = parse_args() 25 | 26 | keys_to_match = ['mm_projector'] 27 | ckpt_to_key = defaultdict(list) 28 | try: 29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))) 30 | for k, v in model_indices['weight_map'].items(): 31 | if any(key_match in k for key_match in keys_to_match): 32 | ckpt_to_key[v].append(k) 33 | except FileNotFoundError: 34 | # Smaller models or model checkpoints saved by DeepSpeed. 35 | v = 'pytorch_model.bin' 36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys(): 37 | if any(key_match in k for key_match in keys_to_match): 38 | ckpt_to_key[v].append(k) 39 | 40 | loaded_weights = {} 41 | 42 | for ckpt_name, weight_keys in ckpt_to_key.items(): 43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu') 44 | for k in weight_keys: 45 | loaded_weights[k] = ckpt[k] 46 | 47 | torch.save(loaded_weights, args.output) 48 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # uncomment the following lines to shutoff the internet access 4 | # export HF_HUB_OFFLINE=True 5 | # export HF_DATASETS_OFFLINE=1 6 | # export TRANSFORMERS_OFFLINE=1 7 | export IMP_SILIENT_OTHERS=true 8 | 9 | # if not use all GPUs 10 | # deepspeed --include localhost:0,1,2,3 --master_port 29600 11 | 12 | deepspeed imp_llava/train/train_mem.py \ 13 | --deepspeed ./scripts/zero3.json \ 14 | --model_name_or_path checkpoints/base/phi-2 \ 15 | --version phi2 \ 16 | --data_path datasets/llava_v1_5_mix665k.json \ 17 | --image_folder datasets/finetune_images \ 18 | --vision_tower checkpoints/base/siglip-so400m-patch14-384 \ 19 | --pretrain_mm_mlp_adapter ./checkpoints/imp-v1-3b-stage1/mm_projector.bin \ 20 | --mm_projector_type mlp2x_gelu \ 21 | --mm_vision_select_layer -2 \ 22 | --mm_use_im_start_end False \ 23 | --mm_use_im_patch_token False \ 24 | --image_aspect_ratio square \ 25 | --group_by_modality_length True \ 26 | --bf16 False \ 27 | --fp16 True \ 28 | --output_dir ./checkpoints/imp-v1-3b-fft \ 29 | --num_train_epochs 2 \ 30 | --per_device_train_batch_size 4 \ 31 | --per_device_eval_batch_size 4 \ 32 | --gradient_accumulation_steps 4 \ 33 | --evaluation_strategy "no" \ 34 | --save_strategy "steps" \ 35 | --save_steps 50000 \ 36 | --save_total_limit 1 \ 37 | --learning_rate 2e-5 \ 38 | --weight_decay 0. 
\ 39 | --warmup_ratio 0.03 \ 40 | --lr_scheduler_type "cosine" \ 41 | --logging_steps 1 \ 42 | --tf32 False \ 43 | --model_max_length 3072 \ 44 | --gradient_checkpointing True \ 45 | --dataloader_num_workers 4 \ 46 | --lazy_preprocess True \ 47 | --report_to none 48 | -------------------------------------------------------------------------------- /scripts/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # uncomment the following lines to shutoff the internet access 4 | # export HF_HUB_OFFLINE=True 5 | # export HF_DATASETS_OFFLINE=1 6 | # export TRANSFORMERS_OFFLINE=1 7 | export IMP_SILIENT_OTHERS=true 8 | 9 | # if not use all GPUs 10 | # deepspeed --include localhost:0,1,2,3 --master_port 29600 11 | 12 | deepspeed imp_llava/train/train_mem.py \ 13 | --lora_enable True --lora_r 256 --lora_alpha 256 --mm_projector_lr 2e-5 \ 14 | --deepspeed ./scripts/zero3.json \ 15 | --model_name_or_path checkpoints/base/phi-2 \ 16 | --version phi2 \ 17 | --data_path datasets/llava_v1_5_mix665k.json \ 18 | --image_folder datasets/finetune_images \ 19 | --vision_tower checkpoints/base/siglip-so400m-patch14-384 \ 20 | --pretrain_mm_mlp_adapter ./checkpoints/imp-v1-3b-stage1/mm_projector.bin \ 21 | --mm_projector_type mlp2x_gelu \ 22 | --mm_vision_select_layer -2 \ 23 | --mm_use_im_start_end False \ 24 | --mm_use_im_patch_token False \ 25 | --image_aspect_ratio square \ 26 | --group_by_modality_length True \ 27 | --bf16 False \ 28 | --fp16 True \ 29 | --output_dir ./checkpoints/imp-v1-3b-stage2-lora \ 30 | --num_train_epochs 2 \ 31 | --per_device_train_batch_size 4 \ 32 | --per_device_eval_batch_size 4 \ 33 | --gradient_accumulation_steps 4 \ 34 | --evaluation_strategy "no" \ 35 | --save_strategy "steps" \ 36 | --save_steps 50000 \ 37 | --save_total_limit 1 \ 38 | --learning_rate 2e-4 \ 39 | --weight_decay 0. 
\ 40 | --warmup_ratio 0.03 \ 41 | --lr_scheduler_type "cosine" \ 42 | --logging_steps 1 \ 43 | --tf32 False \ 44 | --model_max_length 3072 \ 45 | --gradient_checkpointing True \ 46 | --dataloader_num_workers 4 \ 47 | --lazy_preprocess True \ 48 | --report_to none 49 | -------------------------------------------------------------------------------- /scripts/finetune_lora_custom.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | while [[ "$#" -gt 0 ]]; do 4 | case $1 in 5 | -imp_model) IMP_MODEL="$2"; shift ;; 6 | -version) VERSION="$2"; shift ;; 7 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 8 | esac 9 | shift 10 | done 11 | # uncomment the following lines to shutoff the internet access 12 | # export HF_HUB_OFFLINE=True 13 | # export HF_DATASETS_OFFLINE=1 14 | # export TRANSFORMERS_OFFLINE=1 15 | export IMP_SILIENT_OTHERS=true 16 | 17 | # if not use all GPUs 18 | # deepspeed --include localhost:0,1,2,3 --master_port 29600 19 | 20 | deepspeed imp_llava/train/train_mem.py \ 21 | --lora_enable True --lora_r 256 --lora_alpha 256 --mm_projector_lr 2e-5 \ 22 | --deepspeed ./scripts/zero3.json \ 23 | --model_name_or_path $IMP_MODEL \ 24 | --version $VERSION \ 25 | --data_path /data/common_datasets/llava/llava_v1_5_mix665k.json \ 26 | --image_folder /data/common_datasets/llava/ft_datasets \ 27 | --vision_tower ./checkpoints/siglip-so400m-patch14-384 \ 28 | --mm_projector_type mlp2x_gelu \ 29 | --mm_vision_select_layer -2 \ 30 | --mm_use_im_start_end False \ 31 | --mm_use_im_patch_token False \ 32 | --image_aspect_ratio square \ 33 | --group_by_modality_length True \ 34 | --bf16 False \ 35 | --fp16 True \ 36 | --output_dir ./checkpoints/imp-${VERSION}-lora-custom \ 37 | --num_train_epochs 2 \ 38 | --per_device_train_batch_size 4 \ 39 | --per_device_eval_batch_size 4 \ 40 | --gradient_accumulation_steps 4 \ 41 | --gradient_checkpointing True\ 42 | --evaluation_strategy "no" \ 43 | --save_strategy "steps" \ 44 | --save_steps 50000 \ 45 | --save_total_limit 1 \ 46 | --learning_rate 2e-4 \ 47 | --weight_decay 0. 
\ 48 | --warmup_ratio 0.03 \ 49 | --lr_scheduler_type "cosine" \ 50 | --logging_steps 1 \ 51 | --tf32 False \ 52 | --model_max_length 3072 \ 53 | --gradient_checkpointing True \ 54 | --dataloader_num_workers 4 \ 55 | --lazy_preprocess True \ 56 | --report_to none 57 | -------------------------------------------------------------------------------- /scripts/merge.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | while [[ "$#" -gt 0 ]]; do 4 | case $1 in 5 | -imp_model) IMP_MODEL="$2"; shift ;; 6 | -version) VERSION="$2"; shift ;; 7 | -lora) MODEL_CKPT="$2"; shift ;; 8 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 9 | esac 10 | shift 11 | done 12 | 13 | python -m imp_llava.eval.model_merge \ 14 | --model-path $MODEL_CKPT \ 15 | --model-base $IMP_MODEL \ 16 | --save-name imp-${VERSION}-merged -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from imp_llava.model.builder import load_pretrained_model 3 | from imp_llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # uncomment the following lines to shutoff the internet access 4 | # export HF_HUB_OFFLINE=True 5 | # export HF_DATASETS_OFFLINE=1 6 | # export TRANSFORMERS_OFFLINE=1 7 | export IMP_SILIENT_OTHERS=true 8 | 9 | # if not use all GPUs 10 | # deepspeed --include localhost:0,1,2,3 --master_port 29600 11 | 12 | deepspeed imp_llava/train/train_mem.py \ 13 | --deepspeed ./scripts/zero2.json \ 14 | --model_name_or_path checkpoints/base/phi-2 \ 15 | --version plain \ 16 | --data_path datasets/blip_laion_cc_sbu_558k.json \ 17 | --image_folder datasets/pretrain_images/ \ 18 | --vision_tower checkpoints/base/siglip-so400m-patch14-384 \ 19 | --mm_projector_type mlp2x_gelu \ 20 | --tune_mm_mlp_adapter True \ 21 | --mm_vision_select_layer -2 \ 22 | --mm_use_im_start_end False \ 23 | --mm_use_im_patch_token False \ 24 | --bf16 False \ 25 | --fp16 True \ 26 | --output_dir ./checkpoints/imp-v1-3b-stage1 \ 27 | --num_train_epochs 1 \ 28 | --per_device_train_batch_size 32 \ 29 | --per_device_eval_batch_size 4 \ 30 | --gradient_accumulation_steps 1 \ 31 | --evaluation_strategy "no" \ 32 | --save_strategy "steps" \ 33 | --save_steps 24000 \ 34 | --save_total_limit 1 \ 35 | --learning_rate 1e-3 \ 36 | --weight_decay 0. 
\ 37 | --warmup_ratio 0.03 \ 38 | --lr_scheduler_type "cosine" \ 39 | --logging_steps 1 \ 40 | --tf32 False \ 41 | --model_max_length 3072 \ 42 | --gradient_checkpointing True \ 43 | --dataloader_num_workers 4 \ 44 | --lazy_preprocess True \ 45 | --report_to none 46 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- 
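The three JSON files above are the DeepSpeed configurations that the training scripts pass to --deepspeed: scripts/pretrain.sh uses zero2.json, while finetune.sh, finetune_lora.sh and finetune_lora_custom.sh use zero3.json; zero3_offload.json is a ZeRO stage-3 variant that additionally offloads optimizer and parameter state to CPU. A quick way to confirm what each config requests before pointing a script at it; the snippet below is a standalone illustration (not part of the repository) and relies only on the Python standard library:

for cfg in scripts/zero2.json scripts/zero3.json scripts/zero3_offload.json; do
    # Print the ZeRO stage and whether CPU offload of the optimizer is enabled.
    python3 -c "import json, sys; z = json.load(open(sys.argv[1]))['zero_optimization']; print(sys.argv[1], '-> ZeRO stage', z['stage'], '| CPU offload:', 'offload_optimizer' in z)" "$cfg"
done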
/tmp/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MILVLG/imp/363a9bb657f34ed7750f2bf2b88afcb45e3bd512/tmp/.gitkeep --------------------------------------------------------------------------------
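A closing note on how the LoRA-related pieces above fit together: scripts/finetune_lora.sh writes a LoRA checkpoint (./checkpoints/imp-v1-3b-stage2-lora by default), the "lora eval" blocks in the eval scripts load that checkpoint together with --model-base checkpoints/base/phi-2, and scripts/merge_lora_weights.py (or scripts/merge.sh, which wraps imp_llava.eval.model_merge) folds the adapter back into the base model so the result can be loaded without a separate --model-base, as in the commented "merge eval" blocks. A minimal sketch of the merge step, using the default checkpoint names from the scripts; the --save-model-path directory is an arbitrary name of your choosing:

python scripts/merge_lora_weights.py \
    --model-path ./checkpoints/imp-v1-3b-stage2-lora \
    --model-base checkpoints/base/phi-2 \
    --save-model-path ./checkpoints/imp-v1-3b-stage2-merged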