├── LICENSE ├── README.md ├── cog.yaml ├── eval_mmmu ├── answer_dict_val.json ├── configs │ └── llava1.5.yaml ├── mmmu_only_eval.py ├── mmmu_response.py └── utils │ ├── data_utils.py │ ├── eval_utils.py │ └── model_utils.py ├── evaluation_mathvista ├── build_query.py ├── calculate_score.py ├── ext_ans.py ├── extract_answer.py ├── mathvista_data │ ├── annot_testmini.json │ ├── query.json │ └── testmini.json ├── response.py └── utilities.py ├── finetune_task.sh ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── __pycache__ │ │ └── run_llava.cpython-310.pyc │ └── run_llava.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ └── train_xformers.py └── utils.py ├── pipeline.png ├── pyproject.toml ├── scripts ├── extract_mm_projector.py ├── merge_lora_weights.py ├── zero2.json ├── zero3.json └── zero3_offload.json └── train_samples_all_tuning.json.zip /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Math-LLaVA 2 | 3 | This repository contains the code, data and model for the paper titled "Math-LLaVA: Bootstrapping Mathematical Reasoning for Multimodal Large Language Models". 
4 | 5 | [Paper](http://arxiv.org/abs/2406.17294v2), [Image Dataset](https://huggingface.co/datasets/Zhiqiang007/MathV360K/tree/main), [Model](https://huggingface.co/Zhiqiang007/Math-LLaVA/tree/main) 6 | 7 | ![ex1](pipeline.png) 8 | 9 | ## Latest News 🔥 10 | * [2024-06-26] We released the [Math-LLaVA checkpoints](https://huggingface.co/Zhiqiang007/Math-LLaVA/tree/main). The Math-LLaVA-13B model achieves **46.6%** on MathVista testmini, **38.3%** on MMMU, and **15.69%** on MATH-V. 11 | * [2024-06-25] Released the [paper](http://arxiv.org/abs/2406.17294v2), [code](https://github.com/HZQ950419/Math-LLaVA), and [MathV360K dataset](https://huggingface.co/datasets/Zhiqiang007/MathV360K/tree/main). 12 | 13 | ## Install Packages 14 | ``` 15 | cd Math-LLaVA 16 | conda create -n math_llava python=3.10 -y 17 | conda activate math_llava 18 | pip install -e . 19 | ``` 20 | ## Enable DeepSpeed and FlashAttention 21 | ``` 22 | pip install -e ".[train]" 23 | pip install flash-attn --no-build-isolation 24 | ``` 25 | 26 | ## Data Preparation 27 | "train_samples_all_tuning.json" contains the question-answer pair annotations used for fine-tuning. 28 | Download the [image dataset](https://huggingface.co/datasets/Zhiqiang007/MathV360K/tree/main). 29 | 30 | Place the data in the repository root or another directory of your choice. 31 | Data structure: 32 | ``` 33 | ├── data_images/ 34 | │ ├── TabMWP/images/ 35 | │ ├── IconQA/images/ 36 | │ ├── ... 37 | ├── train_samples_all_tuning.json 38 | ``` 39 | 40 | ## Run Full Fine-tuning 41 | ``` 42 | sh finetune_task.sh 43 | ``` 44 | 45 | ## MathVista Evaluation 46 | Download and unzip the MathVista images with the following commands: 47 | ``` 48 | cd ./evaluation_mathvista/mathvista_data 49 | wget https://huggingface.co/datasets/AI4Math/MathVista/resolve/main/images.zip 50 | unzip images.zip 51 | ``` 52 | Generate responses on the testmini subset: 53 | ``` 54 | cd evaluation_mathvista 55 | python response.py --output_dir ./mathvista_outputs --output_file responses.json --model_path your/model/path --model_base None 56 | ``` 57 | Extract the short answer text for score calculation; this step uses ChatGPT and requires an [OpenAI API key](https://platform.openai.com/account/api-keys).
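The extraction step calls the OpenAI API, so your key must be available to the script before running it. As a hedged example (the exact mechanism depends on how `extract_answer.py` reads the key; it may instead expect a command-line flag or a config entry), exporting it as an environment variable is a common setup:
```
# Assumption: extract_answer.py picks up the key from the standard OPENAI_API_KEY variable;
# adjust if the script expects the key via an argument or a config file instead.
export OPENAI_API_KEY="your-api-key"
```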
58 | ``` 59 | python extract_answer.py --output_file responses.json 60 | ``` 61 | Calculate the final score: 62 | ``` 63 | python calculate_score.py --output_file responses.json --score_file responses_score.json 64 | ``` 65 | 66 | ## MMMU Evaluation 67 | Generate the responses: 68 | ``` 69 | cd eval_mmmu 70 | python mmmu_response.py --output_path mmmu_eval_output.json --model_path your/model/path 71 | ``` 72 | Calculate the score: 73 | ``` 74 | python mmmu_only_eval.py --output_path mmmu_eval_output.json --answer_path ./answer_dict_val.json 75 | ``` 76 | ## Results on MathVista 77 | Accuracy scores on the testmini subset: 78 | 79 | | Model | ALL | FQA | GPS | MWP | TQA | VQA | 80 | |-----------------------|--------|--------|--------|--------|--------|--------| 81 | | miniGPT4-7B |**23.1**|**18.6**|**26.0**|**13.4**|**30.4**|**30.2**| 82 | | InstructBLIP-7B |**25.3**|**23.1**|**20.7**|**18.3**|**32.3**|**35.2**| 83 | | LLaVA-13B |**26.1**|**26.8**|**29.3**|**16.1**|**32.3**|**26.3**| 84 | | SPHINX-V1-13B |**27.5**|**23.4**|**23.1**|**21.5**|**39.9**|**34.1**| 85 | | LLaVA-1.5-13B |**27.6**|**-**|**-**|**-**|**-**|**-**| 86 | | OmniLMM-12B |**34.9**|**45.0**|**17.8**|**26.9**|**44.9**|**39.1**| 87 | | Math-LLaVA-13B |**46.6**|**37.2**|**57.7**|**56.5**|**51.3**|**33.5**| 88 | 89 | 90 | 91 | ## Results on MMMU 92 | Accuracy scores on the validation set: 93 | 94 | | Model | ALL | 95 | |-----------------------|--------| 96 | | miniGPT4-7B |**26.8**| 97 | | mPLUG-Owl-7B |**32.7**| 98 | | InstructBLIP-7B |**32.9**| 99 | | SPHINX-13B |**32.9**| 100 | | LLaVA-1.5-13B |**36.4**| 101 | | Math-LLaVA-13B |**38.3**| 102 | 103 | ## Results on MATH-V 104 | We also evaluate on [MATH-V](https://github.com/mathvision-cuhk/MATH-V), a more challenging dataset: 105 | 106 | | Model | ALL | 107 | |-----------------------|--------| 108 | | Qwen-VL-Plus |**10.72**| 109 | | LLaVA-1.5-13B |**11.12**| 110 | | ShareGPT4V-13B |**11.88**| 111 | | InternLM-XComposer2-VL|**14.54**| 112 | | Math-LLaVA-13B |**15.69**| 113 | 114 | ## Acknowledgement 115 | The project is built on top of the amazing [LLaVA](https://github.com/haotian-liu/LLaVA), [MathVista](https://github.com/lupantech/MathVista), and [MMMU](https://github.com/MMMU-Benchmark/MMMU) repositories. Thanks for their contributions!
116 | 117 | 118 | If you find our code and dataset helpful to your research, please consider citing us with this BibTeX: 119 | ```bibtex 120 | @misc{shihu2024mathllava, 121 | title={Math-LLaVA: Bootstrapping Mathematical Reasoning for Multimodal Large Language Models}, 122 | author={Wenhao Shi and Zhiqiang Hu and Yi Bin and Junhua Liu and Yang Yang and See-Kiong Ng and Lidong Bing and Roy Ka-Wei Lee}, 123 | year={2024}, 124 | eprint={2406.17294}, 125 | archivePrefix={arXiv}, 126 | primaryClass={cs.CL} 127 | } 128 | ``` 129 | 130 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.11" 8 | 9 | python_packages: 10 | - "torch==2.0.1" 11 | - "accelerate==0.21.0" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.9.5" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.35.2" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.26.0" 21 | - "peft==0.4.0" 22 | - "scikit-learn==1.2.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.6.13" 26 | - "tokenizers==0.13.3" 27 | - "torch==2.0.1" 28 | - "torchvision==0.15.2" 29 | - "transformers==4.31.0" 30 | - "wandb==0.15.12" 31 | - "wavedrom==2.0.3.post3" 32 | - "Pygments==2.16.1" 33 | run: 34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 35 | 36 | # predict.py defines how predictions are run on your model 37 | predict: "predict.py:Predictor" 38 | -------------------------------------------------------------------------------- /eval_mmmu/configs/llava1.5.yaml: -------------------------------------------------------------------------------- 1 | #task_instructions: 2 | #- "" 3 | #multi_choice_example_format: 4 | #- "{} 5 | # 6 | #{} 7 | # 8 | #Answer with the option's letter from the given choices directly." 9 | # 10 | #short_ans_example_format: 11 | #- "{} 12 | # 13 | #Answer the question using a single word or phrase." 14 | #temperature: 15 | #- 0 16 | task_instructions: 17 | - "" 18 | multi_choice_example_format: 19 | - "Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end. 20 | 21 | {} 22 | 23 | {}" 24 | 25 | short_ans_example_format: 26 | - "Hint: Please answer the question and provide the final answer at the end. 
27 | 28 | {}" 29 | temperature: 30 | - 0 -------------------------------------------------------------------------------- /eval_mmmu/mmmu_only_eval.py: -------------------------------------------------------------------------------- 1 | """Parse and Evalate""" 2 | import os 3 | import json 4 | 5 | import pdb 6 | from argparse import ArgumentParser 7 | from utils.data_utils import save_json, CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT 8 | from utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response, calculate_ins_level_acc, eval_open 9 | 10 | if __name__ == '__main__': 11 | 12 | parser = ArgumentParser() 13 | parser.add_argument('--output_path', type=str, default="mmmu_eval_output.json", 14 | help="The path to model output file.") 15 | parser.add_argument('--answer_path', type=str, default="./answer_dict_val.json", help="Answer file path.") 16 | args = parser.parse_args() 17 | 18 | output_dict = json.load(open(args.output_path)) 19 | answer_dict = json.load(open(args.answer_path)) 20 | 21 | # group by category 22 | output_dict_w_cat = {} 23 | for data_id, parsed_pred in output_dict.items(): 24 | category = "_".join(data_id.split("_")[1:-1]) 25 | if category not in output_dict_w_cat: 26 | output_dict_w_cat.update({category: {}}) 27 | output_dict_w_cat[category].update({data_id: parsed_pred}) 28 | 29 | # group by category 30 | answer_dict_w_cat = {} 31 | for data_id, parsed_pred in answer_dict.items(): 32 | category = "_".join(data_id.split("_")[1:-1]) 33 | if category not in answer_dict_w_cat: 34 | answer_dict_w_cat.update({category: {}}) 35 | answer_dict_w_cat[category].update({data_id: parsed_pred}) 36 | 37 | evaluation_result = {} 38 | 39 | for category in CAT_SHORT2LONG.values(): 40 | print("Evaluating: {}".format(category)) 41 | # get cat_outputs and cat_answers 42 | try: 43 | cat_outputs = output_dict_w_cat[category] 44 | cat_answers = answer_dict_w_cat[category] 45 | except KeyError: 46 | print("Skipping {} for not found".format(category)) 47 | continue 48 | 49 | exampels_to_eval = [] 50 | for data_id, parsed_pred in cat_outputs.items(): 51 | question_type = cat_answers[data_id]['question_type'] 52 | if question_type != 'multiple-choice': 53 | parsed_pred = parse_open_response(parsed_pred) 54 | # print(parsed_pred) 55 | # print(cat_answers[data_id]['ground_truth']) 56 | correct = eval_open(cat_answers[data_id]['ground_truth'], parsed_pred) 57 | else: 58 | parsed_pred = parsed_pred 59 | 60 | 61 | exampels_to_eval.append({ 62 | "id": data_id, 63 | "question_type": question_type, 64 | "answer": cat_answers[data_id]['ground_truth'], 65 | "parsed_pred": parsed_pred 66 | }) 67 | 68 | judge_dict, metric_dict = evaluate(exampels_to_eval) 69 | metric_dict.update({"num_example": len(exampels_to_eval)}) 70 | 71 | evaluation_result[category] = metric_dict 72 | 73 | printable_results = {} 74 | # pdb.set_trace() 75 | # add domain Subject 76 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items(): 77 | in_domain_cat_results = {} 78 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT 79 | if cat_name in evaluation_result.keys(): 80 | in_domain_cat_results[cat_name] = evaluation_result[cat_name] 81 | else: 82 | pass 83 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results) 84 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()]) 85 | printable_results['Overall-' + domain] = {"num": int(in_domain_data_num), 86 | "acc": round(in_domain_ins_acc, 3) 87 | } 88 | # add sub category 89 | for cat_name, 
cat_results in in_domain_cat_results.items(): 90 | printable_results[cat_name] = {"num": int(cat_results['num_example']), 91 | "acc": round(cat_results['acc'], 3) 92 | } 93 | 94 | # table.append(["-----------------------------", "-----", "----"]) 95 | all_ins_acc = calculate_ins_level_acc(evaluation_result) 96 | printable_results['Overall'] = { 97 | "num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]), 98 | "acc": round(all_ins_acc, 3) 99 | } 100 | 101 | print(printable_results) 102 | -------------------------------------------------------------------------------- /eval_mmmu/mmmu_response.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import random 4 | 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from datasets import load_dataset, concatenate_datasets 9 | from llava.model.builder import load_pretrained_model 10 | from llava.mm_utils import get_model_name_from_path 11 | 12 | from argparse import ArgumentParser 13 | 14 | from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG 15 | from utils.model_utils import call_llava_engine_df, llava_image_processor 16 | from utils.eval_utils import parse_multi_choice_response, parse_open_response 17 | 18 | 19 | def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None): 20 | out_samples = dict() 21 | with torch.no_grad(): 22 | for sample in tqdm(samples): 23 | response = call_model_engine_fn(args, sample, model, tokenizer, processor) 24 | #print(response) 25 | if sample['question_type'] == 'multiple-choice': 26 | pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans']) 27 | else: # open question 28 | pred_ans = response 29 | out_samples[sample['id']] = pred_ans 30 | return out_samples 31 | 32 | def set_seed(seed_value): 33 | """ 34 | Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results. 35 | 36 | :param seed_value: An integer value to be used as the seed. 37 | """ 38 | torch.manual_seed(seed_value) 39 | if torch.cuda.is_available(): 40 | torch.cuda.manual_seed(seed_value) 41 | torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups 42 | random.seed(seed_value) 43 | np.random.seed(seed_value) 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | 47 | def main(): 48 | parser = ArgumentParser() 49 | parser.add_argument('--output_path', type=str, default='mmmu_eval_output.json', 50 | help='name of saved json') 51 | parser.add_argument('--config_path', type=str, default="configs/llava1.5.yaml") 52 | parser.add_argument('--data_path', type=str, default="MMMU/MMMU") # hf dataset path. 
53 | parser.add_argument('--model_path', type=str, default="") 54 | parser.add_argument('--split', type=str, default='validation') 55 | parser.add_argument('--seed', type=int, default=42) 56 | 57 | args = parser.parse_args() 58 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu" 59 | set_seed(args.seed) 60 | 61 | print('llava_initializing...') 62 | processor = None 63 | call_model_engine = call_llava_engine_df 64 | vis_process_func = llava_image_processor 65 | 66 | # load config and process to one value 67 | args.config = load_yaml(args.config_path) 68 | for key, value in args.config.items(): 69 | if key != 'eval_params' and type(value) == list: 70 | assert len(value) == 1, 'key {} has more than one value'.format(key) 71 | args.config[key] = value[0] 72 | 73 | # run for each subject 74 | sub_dataset_list = [] 75 | for subject in CAT_SHORT2LONG.values(): 76 | sub_dataset = load_dataset(args.data_path, subject, split=args.split) 77 | sub_dataset_list.append(sub_dataset) 78 | 79 | # merge all dataset 80 | dataset = concatenate_datasets(sub_dataset_list) 81 | 82 | 83 | # load model 84 | model_name = get_model_name_from_path(args.model_path) 85 | 86 | tokenizer, model, vis_processors, _ = load_pretrained_model(model_path=args.model_path, 87 | model_base=None, 88 | model_name=model_name) 89 | 90 | samples = [] 91 | for sample in dataset: 92 | sample = process_single_sample(sample) 93 | #print(sample) 94 | sample = construct_prompt(sample, args.config) 95 | if sample['image']: 96 | sample['image'] = vis_process_func(sample['image'], vis_processors).to(device) 97 | samples.append(sample) 98 | 99 | ## run ex 100 | out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor) 101 | 102 | save_json(args.output_path, out_samples) 103 | # metric_dict.update({"num_example": len(out_samples)}) 104 | # save_json(save_result_path, metric_dict) 105 | 106 | 107 | if __name__ == '__main__': 108 | main() 109 | -------------------------------------------------------------------------------- /eval_mmmu/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for data load, save, and process (e.g., prompt construction)""" 2 | 3 | import os 4 | import json 5 | import yaml 6 | import re 7 | 8 | 9 | DOMAIN_CAT2SUB_CAT = { 10 | 'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'], 11 | 'Business': ['Accounting', 'Economics', 'Finance', 'Manage','Marketing'], 12 | 'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics',], 13 | 'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine', 'Pharmacy', 'Public_Health'], 14 | 'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'], 15 | 'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics', 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'], 16 | } 17 | 18 | 19 | CAT_SHORT2LONG = { 20 | 'acc': 'Accounting', 21 | 'agri': 'Agriculture', 22 | 'arch': 'Architecture_and_Engineering', 23 | 'art': 'Art', 24 | 'art_theory': 'Art_Theory', 25 | 'bas_med': 'Basic_Medical_Science', 26 | 'bio': 'Biology', 27 | 'chem': 'Chemistry', 28 | 'cli_med': 'Clinical_Medicine', 29 | 'cs': 'Computer_Science', 30 | 'design': 'Design', 31 | 'diag_med': 'Diagnostics_and_Laboratory_Medicine', 32 | 'econ': 'Economics', 33 | 'elec': 'Electronics', 34 | 'ep': 'Energy_and_Power', 35 | 'fin': 'Finance', 36 | 'geo': 'Geography', 37 | 'his': 'History', 38 
| 'liter': 'Literature', 39 | 'manage': 'Manage', 40 | 'mark': 'Marketing', 41 | 'mate': 'Materials', 42 | 'math': 'Math', 43 | 'mech': 'Mechanical_Engineering', 44 | 'music': 'Music', 45 | 'phar': 'Pharmacy', 46 | 'phys': 'Physics', 47 | 'psy': 'Psychology', 48 | 'pub_health': 'Public_Health', 49 | 'socio': 'Sociology' 50 | } 51 | 52 | # DATA SAVING 53 | def save_json(filename, ds): 54 | with open(filename, 'w') as f: 55 | json.dump(ds, f, indent=4) 56 | 57 | 58 | def get_multi_choice_info(options): 59 | """ 60 | Given the list of options for multiple choice question 61 | Return the index2ans and all_choices 62 | """ 63 | 64 | start_chr = 'A' 65 | all_choices = [] 66 | index2ans = {} 67 | for i, option in enumerate(options): 68 | index2ans[chr(ord(start_chr) + i)] = option 69 | all_choices.append(chr(ord(start_chr) + i)) 70 | 71 | return index2ans, all_choices 72 | 73 | def load_yaml(file_path): 74 | with open(file_path, 'r') as stream: 75 | try: 76 | yaml_dict = yaml.safe_load(stream) 77 | except yaml.YAMLError as exc: 78 | print(exc) 79 | 80 | return yaml_dict 81 | 82 | 83 | def parse_img_path(text): 84 | matches = re.findall("", text) 85 | return matches 86 | 87 | def process_single_sample(data): 88 | question = data['question'] 89 | o_imgs_paths = [] 90 | for option in data['options']: 91 | current_o_imgs_paths = parse_img_path(option) 92 | for img_path in current_o_imgs_paths: 93 | o_imgs_paths.append(img_path) 94 | 95 | if len(o_imgs_paths) > 1: # multiple images in options, used for random selection 96 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'], 97 | 'image': None, 'question_type': data['question_type']} 98 | else: 99 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'], 100 | 'image': data['image_1'], 'question_type': data['question_type']} 101 | 102 | 103 | # DATA SAVING 104 | def save_json(filename, ds): 105 | with open(filename, 'w') as f: 106 | json.dump(ds, f, indent=4) 107 | 108 | def save_jsonl(filename, data): 109 | """ 110 | Save a dictionary of data to a JSON Lines file with the filename as key and caption as value. 111 | 112 | Args: 113 | filename (str): The path to the file where the data should be saved. 114 | data (dict): The dictionary containing the data to save where key is the image path and value is the caption. 
115 | """ 116 | with open(filename, 'w', encoding='utf-8') as f: 117 | for img_path, caption in data.items(): 118 | # Extract the base filename without the extension 119 | base_filename = os.path.basename(img_path) 120 | # Create a JSON object with the filename as the key and caption as the value 121 | json_record = json.dumps({base_filename: caption}, ensure_ascii=False) 122 | # Write the JSON object to the file, one per line 123 | f.write(json_record + '\n') 124 | 125 | def save_args(args, path_dir): 126 | argsDict = args.__dict__ 127 | with open(path_dir + 'setting.txt', 'w') as f: 128 | f.writelines('------------------ start ------------------' + '\n') 129 | for eachArg, value in argsDict.items(): 130 | f.writelines(eachArg + ' : ' + str(value) + '\n') 131 | f.writelines('------------------- end -------------------') 132 | 133 | 134 | 135 | # DATA PROCESSING 136 | def construct_prompt(sample, config): 137 | question = sample['question'] 138 | options = eval(sample['options']) 139 | example = "" 140 | if sample['question_type'] == 'multiple-choice': 141 | start_chr = 'A' 142 | prediction_range = [] 143 | index2ans = {} 144 | for option in options: 145 | prediction_range.append(start_chr) 146 | example += f"({start_chr}) {option}\n" 147 | index2ans[start_chr] = option 148 | start_chr = chr(ord(start_chr) + 1) 149 | empty_prompt_sample_structure = config['multi_choice_example_format'] 150 | empty_prompt = empty_prompt_sample_structure.format(question, example) 151 | res_dict = {} 152 | res_dict['index2ans'] = index2ans 153 | res_dict['correct_choice'] = sample['answer'] 154 | res_dict['all_choices'] = prediction_range 155 | res_dict['empty_prompt'] = empty_prompt 156 | if config['task_instructions']: 157 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt 158 | else: 159 | res_dict['final_input_prompt'] = empty_prompt 160 | 161 | res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')] 162 | else: 163 | empty_prompt_sample_structure = config['short_ans_example_format'] 164 | empty_prompt = empty_prompt_sample_structure.format(question) 165 | res_dict = {} 166 | res_dict['empty_prompt'] = empty_prompt 167 | if config['task_instructions']: 168 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt 169 | else: 170 | res_dict['final_input_prompt'] = empty_prompt 171 | res_dict['gt_content'] = sample['answer'] 172 | 173 | res_dict.update(sample) 174 | return res_dict -------------------------------------------------------------------------------- /eval_mmmu/utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | """Response Parsing and Evaluation for various models""" 2 | from typing import Dict 3 | 4 | import re 5 | import random 6 | random.seed(42) 7 | import numpy as np 8 | 9 | # ----------- Process Multi-choice ------------- 10 | def parse_multi_choice_response(response, all_choices, index2ans): 11 | """ 12 | Parse the prediction from the generated response. 13 | Return the predicted index e.g., A, B, C, D. 
14 | """ 15 | for char in [',', '.', '!', '?', ';', ':', "'"]: 16 | response = response.strip(char) 17 | response = " " + response + " " # add space to avoid partial match 18 | 19 | index_ans = True 20 | ans_with_brack = False 21 | candidates = [] 22 | for choice in all_choices: # e.g., (A) (B) (C) (D) 23 | if f'({choice})' in response: 24 | candidates.append(choice) 25 | ans_with_brack = True 26 | 27 | if len(candidates) == 0: 28 | for choice in all_choices: # e.g., A B C D 29 | if f' {choice} ' in response: 30 | candidates.append(choice) 31 | 32 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example 33 | if len(candidates) == 0 and len(response.split()) > 5: 34 | for index, ans in index2ans.items(): 35 | if ans.lower() in response.lower(): 36 | candidates.append(index) 37 | index_ans = False # it's content ans. 38 | 39 | if len(candidates) == 0: # still not get answer, randomly choose one. 40 | pred_index = random.choice(all_choices) 41 | elif len(candidates) > 1: 42 | start_indexes = [] 43 | if index_ans: 44 | if ans_with_brack: 45 | for can in candidates: 46 | index = response.rfind(f'({can})') 47 | start_indexes.append(index) # -1 will be ignored anyway 48 | # start_indexes = [generated_response.index(f'({can})') for can in candidates] 49 | else: 50 | for can in candidates: 51 | index = response.rfind(f" {can} ") 52 | start_indexes.append(index) 53 | else: 54 | for can in candidates: 55 | index = response.lower().rfind(index2ans[can].lower()) 56 | start_indexes.append(index) 57 | # get the last one 58 | pred_index = candidates[np.argmax(start_indexes)] 59 | else: # if only one candidate, use it. 60 | pred_index = candidates[0] 61 | 62 | return pred_index 63 | 64 | # ----------- Process Open ------------- 65 | def check_is_number(string): 66 | """ 67 | Check if the given string a number. 68 | """ 69 | try: 70 | float(string.replace(',', '')) 71 | return True 72 | except ValueError: 73 | # check if there's comma inside 74 | return False 75 | 76 | def normalize_str(string): 77 | """ 78 | Normalize the str to lower case and make them float numbers if possible. 79 | """ 80 | # check if characters in the string 81 | 82 | # if number, numerize it. 83 | string = string.strip() 84 | 85 | is_number = check_is_number(string) 86 | 87 | if is_number: 88 | string = string.replace(',', '') 89 | string = float(string) 90 | # leave 2 decimal 91 | string = round(string, 2) 92 | return [string] 93 | else: # it's likely to be a string 94 | # lower it 95 | string = string.lower() 96 | if len(string) == 1: 97 | return [" " + string, string + " "] # avoid trivial matches 98 | return [string] 99 | 100 | def extract_numbers(string): 101 | """ 102 | Exact all forms of numbers from a string with regex. 
103 | """ 104 | # Pattern for numbers with commas 105 | pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b' 106 | # Pattern for scientific notation 107 | pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+' 108 | # Pattern for simple numbers without commas 109 | pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])' 110 | 111 | # Extract numbers with commas 112 | numbers_with_commas = re.findall(pattern_commas, string) 113 | # Extract numbers in scientific notation 114 | numbers_scientific = re.findall(pattern_scientific, string) 115 | # Extract simple numbers without commas 116 | numbers_simple = re.findall(pattern_simple, string) 117 | 118 | # Combine all extracted numbers 119 | all_numbers = numbers_with_commas + numbers_scientific + numbers_simple 120 | return all_numbers 121 | 122 | def parse_open_response(response): 123 | """ 124 | Parse the prediction from the generated response. 125 | Return a list of predicted strings or numbers. 126 | """ 127 | # content = content.strip("\n").strip(".").strip(" ") 128 | def get_key_subresponses(response): 129 | key_responses = [] 130 | response = response.strip().strip(".").lower() 131 | sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response) 132 | indicators_of_keys = ['could be ', 'so ', 'is ', 133 | 'thus ', 'therefore ', 'final ', 'answer ', 'result '] 134 | key_responses = [] 135 | for index, resp in enumerate(sub_responses): 136 | # if last one, accept it's an equation (the entire response can be just one sentence with equation) 137 | if index == len(sub_responses) - 1: 138 | indicators_of_keys.extend(['=']) 139 | shortest_key_response = None # the shortest response that may contain the answer (tail part of the response) 140 | for indicator in indicators_of_keys: 141 | if indicator in resp: 142 | if not shortest_key_response: 143 | shortest_key_response = resp.split(indicator)[-1].strip() 144 | else: 145 | if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response): 146 | shortest_key_response = resp.split(indicator)[-1].strip() 147 | # key_responses.append(resp.split(indicator)[1].strip()) 148 | 149 | if shortest_key_response: 150 | # and it's not trivial 151 | if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]: 152 | key_responses.append(shortest_key_response) 153 | if len(key_responses) == 0: # did not found any 154 | return [response] 155 | return key_responses 156 | # pdb.set_trace() 157 | key_responses = get_key_subresponses(response) 158 | 159 | pred_list = key_responses.copy() # keep the original string response 160 | for resp in key_responses: 161 | pred_list.extend(extract_numbers(resp)) 162 | 163 | tmp_pred_list = [] 164 | for i in range(len(pred_list)): 165 | tmp_pred_list.extend(normalize_str(pred_list[i])) 166 | pred_list = tmp_pred_list 167 | 168 | # remove duplicates 169 | pred_list = list(set(pred_list)) 170 | 171 | return pred_list 172 | 173 | # ----------- Evaluation ------------- 174 | 175 | def eval_multi_choice(gold_i, pred_i): 176 | """ 177 | Evaluate a multiple choice instance. 
178 | """ 179 | correct = False 180 | # only they are exactly the same, we consider it as correct 181 | if isinstance(gold_i, list): 182 | for answer in gold_i: 183 | if answer == pred_i: 184 | correct = True 185 | break 186 | else: # gold_i is a string 187 | if gold_i == pred_i: 188 | correct = True 189 | return correct 190 | 191 | def eval_open(gold_i, pred_i): 192 | """ 193 | Evaluate an open question instance 194 | """ 195 | correct = False 196 | if isinstance(gold_i, list): 197 | # use float to avoid trivial matches 198 | norm_answers = [] 199 | for answer in gold_i: 200 | norm_answers.extend(normalize_str(answer)) 201 | else: 202 | norm_answers = normalize_str(gold_i) 203 | for pred in pred_i: # pred is already normalized in parse response phase 204 | if isinstance(pred, str): # if it's a string, then find if ans in the pred_i 205 | for norm_ans in norm_answers: 206 | # only see if the string answer in the string pred 207 | if isinstance(norm_ans, str) and norm_ans in pred: 208 | if not correct: 209 | correct = True 210 | break 211 | else: # it's a float number 212 | if pred in norm_answers: 213 | if not correct: 214 | correct = True 215 | break 216 | return correct 217 | 218 | # ----------- Batch Evaluation ------------- 219 | def evaluate(samples): 220 | """ 221 | Batch evaluation for multiple choice and open questions. 222 | """ 223 | pred_correct = 0 224 | judge_dict = dict() 225 | for sample in samples: 226 | gold_i = sample['answer'] 227 | pred_i = sample['parsed_pred'] 228 | if sample['question_type'] == 'multiple-choice': 229 | correct = eval_multi_choice(gold_i, pred_i) 230 | else: # open question 231 | correct = eval_open(gold_i, pred_i) 232 | 233 | if correct: 234 | judge_dict[sample['id']] = 'Correct' 235 | pred_correct += 1 236 | else: 237 | judge_dict[sample['id']] = 'Wrong' 238 | 239 | if len(samples) == 0: 240 | return {'acc': 0} 241 | return judge_dict, {'acc': pred_correct / len(samples)} 242 | 243 | 244 | 245 | # ----------- Calculate Accuracy ------------- 246 | def calculate_ins_level_acc(results: Dict): 247 | """Calculate the instruction level accuracy for given Subject results""" 248 | acc = 0 249 | ins_num = 0 250 | for cat_results in results.values(): 251 | acc += cat_results['acc'] * cat_results['num_example'] 252 | ins_num += cat_results['num_example'] 253 | if ins_num == 0: 254 | return 0 255 | return acc / ins_num 256 | 257 | -------------------------------------------------------------------------------- /eval_mmmu/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from random import random 2 | import torch 3 | 4 | def call_llava_engine_df(args, sample, model, tokenizer=None, processor=None): 5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 6 | from llava.conversation import conv_templates, SeparatorStyle 7 | 8 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 9 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 10 | 11 | def insert_separator(X, sep): 12 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] 13 | 14 | input_ids = [] 15 | offset = 0 16 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 17 | offset = 1 18 | input_ids.append(prompt_chunks[0][0]) 19 | 20 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 21 | 
input_ids.extend(x[offset:]) 22 | 23 | if return_tensors is not None: 24 | if return_tensors == 'pt': 25 | return torch.tensor(input_ids, dtype=torch.long) 26 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 27 | return input_ids 28 | 29 | def deal_with_prompt(input_text, mm_use_im_start_end): 30 | qs = input_text 31 | if mm_use_im_start_end: 32 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 33 | else: 34 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 35 | return qs 36 | 37 | prompt = sample['final_input_prompt'] 38 | prompt = deal_with_prompt(prompt, model.config.mm_use_im_start_end) 39 | conv = conv_templates['vicuna_v1'].copy() 40 | conv.append_message(conv.roles[0], prompt) 41 | conv.append_message(conv.roles[1], None) 42 | prompt = conv.get_prompt() 43 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 44 | image = sample['image'] 45 | if image is not None: 46 | output_ids = model.generate( 47 | input_ids, 48 | images=image.unsqueeze(0).half().cuda(), 49 | do_sample=True, 50 | temperature=1, 51 | top_p=None, 52 | num_beams=5, 53 | max_new_tokens=128, 54 | use_cache=True) 55 | 56 | input_token_len = input_ids.shape[1] 57 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 58 | if n_diff_input_output > 0: 59 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 60 | response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 61 | else: # multiple images actually 62 | if sample['question_type'] == 'multiple-choice': 63 | all_choices = sample['all_choices'] 64 | response = random.choice(all_choices) 65 | else: 66 | response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS' 67 | 68 | return response 69 | 70 | 71 | def llava_image_processor(raw_image, vis_processors=None): 72 | image_tensor = vis_processors.preprocess(raw_image, return_tensors='pt')['pixel_values'][0] 73 | return image_tensor 74 | -------------------------------------------------------------------------------- /evaluation_mathvista/build_query.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # pids: 799, 681, 615 4 | shot_examples = [ 5 | { 6 | "question": "How much money does Ruth need to buy a baking dish, a casserole dish, and an ice cream scoop? (Unit: $)", 7 | "caption": "The image shows a table with a variety of items on it, including a baking dish, ice cream scoop, casserole dish, and rolling pin. The text in the image says:\n\n```\nbaking dish\n$4.00\nice cream scoop\n$6.00\ncasserole dish\n$3.00\nrolling pin\n$4.00\n```", 8 | "ocr": "[([5, 3], 'baking dish'), ([177, 5], '$4.00'), ([7, 41], 'ice cream scoop'), ([177, 37], '$6.00'), ([9, 69], 'casserole dish'), ([177, 69], '$3.00'), ([5, 98], 'rolling pin'), ([177, 101], '$4.00')]", 9 | "solution": """ 10 | Find the total cost of a baking dish, a casserole dish, and an ice cream scoop.\n\n$4.00 + $3.00 + $6.00 = $13.00\n\nRuth needs $13.00. 
11 | """, 12 | "code": """ 13 | baking_dish_price = 4.00 14 | casserole_dish_price = 3.00 15 | ice_cream_scoop_price = 6.00 16 | 17 | ans = baking_dish_price + casserole_dish_price + ice_cream_scoop_price 18 | print(ans) 19 | """ 20 | }, 21 | 22 | { 23 | "question": "What is the largest city in the nation where this plane is headquartered?", 24 | "choices": ['hong kong', 'osaka', 'shanghai', 'tokyo'], 25 | "caption": "The image shows a large passenger jet parked on a tarmac at an airport. The jet is white with red trim and has a red tail. It is sitting on top of a tarmac next to a building. The jet is being loaded with passengers and cargo. The text on the image says \"Japan. Endless Discovery\".", 26 | "solution": """ 27 | The caption mentions that the text on the image says "Japan. Endless Discovery". This indicates that the plane is headquartered in Japan. 28 | 29 | Among the Japanese cities, Tokyo is the largest city. 30 | 31 | Thus, the answer is D (tokyo). 32 | """, 33 | "code": """ 34 | def largest_city(caption, choices): 35 | countries_largest_cities = { 36 | 'Japan': 'tokyo', 37 | 'China': 'shanghai' 38 | } 39 | 40 | if "Japan" in caption: 41 | country = 'Japan' 42 | elif "China" in caption: 43 | country = 'China' 44 | 45 | for choice in choices: 46 | if choice == countries_largest_cities[country]: 47 | return choice 48 | return "" 49 | 50 | choices = ['hong kong', 'osaka', 'shanghai', 'tokyo'] 51 | caption = "The image shows a large passenger jet parked on a tarmac at an airport. The jet is white with red trim and has a red tail. It is sitting on top of a tarmac next to a building. The jet is being loaded with passengers and cargo. The text on the image says 'Japan. Endless Discovery'." 52 | 53 | print(largest_city(caption, choices)) 54 | """ 55 | }, 56 | 57 | { 58 | "question": "If two sides of a triangle measure 12 and 7, which of the following cannot be the perimeter of the triangle?", 59 | "choices": ['29', '34', '37', '38'], 60 | "caption": "The image shows a triangle with two sides labeled 7 and 12. The triangle is drawn on a white background. There is no text other than the labels.", 61 | "ocr": "[([70, 74], '7'), ([324, 74], '12')]", 62 | "solution": """ 63 | To determine which of the given perimeters cannot be possible for the triangle, we apply the triangle inequality theorem. The sum of any two sides of a triangle must be greater than the third side. 64 | 65 | For the maximum possible value of the third side: 66 | 12 + 7 = 19 67 | 68 | The minimum possible value for the third side: 69 | 12 - 7 = 5 70 | 71 | The third side for each option: 72 | (A) 29 - 12 - 7 = 10 (valid) 73 | (B) 34 - 12 - 7 = 15 (valid) 74 | (C) 37 - 12 - 7 = 18 (valid) 75 | (D) 38 - 12 - 7 = 19 (invalid because it should be less than 19) 76 | 77 | Thus, the answer is D. 78 | """, 79 | "code": """ 80 | def is_valid_triangle(a, b, perimeter): 81 | # Given a and b, find the third side 82 | third_side = perimeter - a - b 83 | 84 | # Check triangle inequality 85 | if (a + b > third_side) and (a + third_side > b) and (b + third_side > a): 86 | return True 87 | return False 88 | 89 | # Given sides 90 | a = 12 91 | b = 7 92 | 93 | # Given perimeters 94 | perimeters = [29, 34, 37, 38] 95 | 96 | # Check which perimeter is not valid 97 | for p in perimeters: 98 | if not is_valid_triangle(a, b, p): 99 | print(p) 100 | """, 101 | 102 | } 103 | ] 104 | 105 | 106 | def refine_caption(caption): 107 | if isinstance(caption, str): 108 | nonsense = ["Sure. 
", 109 | "Sure, I can do that.", 110 | "Sorry, I can't help with images of people yet.", 111 | "I can't process this file.", 112 | "I'm unable to help you with that, as I'm only a language model and don't have the necessary information or abilities.", 113 | "I'm not programmed to assist with that.", 114 | "Please let me know if you have any other questions.", 115 | "I hope this is helpful!", 116 | "I hope this helps!"] 117 | for non in nonsense: 118 | caption = caption.replace(non, "").strip() 119 | caption = caption.replace(" ", " ").strip() 120 | else: 121 | caption = "" 122 | return caption 123 | 124 | 125 | def refine_ocr(ocr): 126 | """ 127 | [ ( 128 | [161, 39], [766, 39], [766, 120], [161, 120]], 129 | 'The spring force does', 130 | 0.912845069753024 131 | ), 132 | ] 133 | --> 134 | [ ( 135 | [161, 39], 136 | 'The spring force does', 137 | ), 138 | ] 139 | """ 140 | try: 141 | ocr = eval(ocr) 142 | if len(ocr) > 0: 143 | ocr = [([int(e[0][0][0]), int(e[0][0][1])], e[1]) for e in ocr] 144 | ocr = str(ocr) 145 | else: 146 | ocr = "" 147 | except: 148 | ocr = "" 149 | return ocr 150 | 151 | 152 | def create_one_query(problem, examples, shot_num, shot_type, use_caption, use_ocr): 153 | 154 | 155 | ### [1] Demo prompt 156 | if shot_num == 0: 157 | demo_prompt = "" 158 | else: 159 | demos = [] 160 | shot_num = min(shot_num, len(examples)) 161 | for example in examples[:shot_num]: 162 | prompt = "" 163 | 164 | # question 165 | prompt += f"Question: {example['question']}" 166 | 167 | # choices 168 | if "choices" in example: 169 | texts = ["Choices:"] 170 | for i, choice in enumerate(example['choices']): 171 | texts.append(f"({chr(ord('A')+i)}) {choice}") 172 | prompt += "\n" + "\n".join(texts) 173 | 174 | # caption 175 | if use_caption: 176 | caption = example['caption'] if 'caption' in example else "" 177 | if caption != "": 178 | prompt += "\n" + f"Image description: {caption}" 179 | 180 | # ocr 181 | if use_ocr: 182 | ocr = example['ocr'] if 'ocr' in example else "" 183 | if ocr != "": 184 | prompt += "\n" + f"Image detected text: {ocr}" 185 | 186 | # solution 187 | if shot_type == 'solution': 188 | solution = example['solution'].strip() 189 | prompt += "\n" + f"Solution: {solution}" 190 | 191 | # code 192 | if shot_type == 'code': 193 | code = example['code'].strip() 194 | prompt += "\n" + f"Python code: {code}" 195 | 196 | demos.append(prompt) 197 | 198 | demo_prompt = "\n\n".join(demos) 199 | 200 | ### [2] Test query 201 | # problem info 202 | question = problem['question'] 203 | unit = problem['unit'] 204 | choices = problem['choices'] 205 | caption = problem['caption'] 206 | ocr = problem['ocr'] 207 | precision = problem['precision'] 208 | question_type = problem['question_type'] 209 | answer_type = problem['answer_type'] 210 | 211 | # hint 212 | if shot_type == 'solution': 213 | if question_type == "multi_choice": 214 | assert answer_type == "text" 215 | hint_text = f"Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end." 216 | else: 217 | assert answer_type in ["integer", "float", "list"] 218 | if answer_type == "integer": 219 | hint_text = f"Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end." 220 | 221 | elif answer_type == "float" and precision == 1: 222 | hint_text = f"Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end." 
223 | 224 | elif answer_type == "float" and precision == 2: 225 | hint_text = f"Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end." 226 | 227 | elif answer_type == "list": 228 | hint_text = f"Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end." 229 | else: 230 | assert shot_type == 'code' 231 | hint_text = "Hint: Please generate a python code to solve the problem" 232 | 233 | # question 234 | question_text = f"Question: {question}" 235 | if unit: 236 | question_text += f" (Unit: {unit})" 237 | 238 | # choices 239 | if choices: 240 | # choices: (A) 1.2 (B) 1.3 (C) 1.4 (D) 1.5 241 | texts = ["Choices:"] 242 | for i, choice in enumerate(choices): 243 | texts.append(f"({chr(ord('A')+i)}) {choice}") 244 | choices_text = "\n".join(texts) 245 | else: 246 | choices_text = "" 247 | 248 | # caption 249 | caption_text = "" 250 | if use_caption and caption != "": 251 | caption_text = f"Image description: {caption}" 252 | 253 | # ocr 254 | ocr_text = "" 255 | if use_ocr and ocr != "": 256 | ocr_text = f"Image detected text: {ocr}" 257 | 258 | # prompt 259 | if shot_type == 'solution': 260 | prompt = "Solution: " 261 | else: 262 | assert shot_type == 'code' 263 | prompt = "Python code: " 264 | 265 | elements = [question_text, choices_text, caption_text, ocr_text, hint_text, prompt] 266 | test_query = "\n".join([e for e in elements if e != ""]) 267 | 268 | ### [3] Final query 269 | query = demo_prompt + "\n\n" + test_query 270 | query = query.strip() 271 | return query 272 | #2': 'Question: what is the total volume of the measuring cup? (Unit: g)\n 273 | # Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.\n 274 | # Solution:', 275 | # 3': 'Question: △ABC的两内角平分线OB、OC相交于点O,若∠A=110°,则∠BOC=()\nChoices:\n(A) 135°\n(B) 140°\n(C) 145°\n(D) 150°\n 276 | # Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n 277 | # Solution:', 278 | 279 | def create_query_data(data, caption_data, ocr_data, args): 280 | query_data = {} 281 | 282 | for pid, problem in data.items(): 283 | if pid in caption_data: 284 | caption = caption_data[pid] 285 | caption = refine_caption(caption) 286 | problem['caption'] = caption 287 | else: 288 | problem['caption'] = "" 289 | 290 | if pid in ocr_data: 291 | ocr = ocr_data[pid] 292 | ocr = refine_ocr(ocr) 293 | problem['ocr'] = ocr 294 | else: 295 | problem['ocr'] = [] 296 | 297 | query = create_one_query( 298 | problem = problem, 299 | examples = shot_examples, 300 | shot_num = args.shot_num, 301 | shot_type = args.shot_type, 302 | use_caption = args.use_caption, 303 | use_ocr = args.use_ocr 304 | ) 305 | query_data[pid] = query 306 | 307 | return query_data 308 | 309 | 310 | if __name__ == "__main__": 311 | for example in shot_examples: 312 | print("----------------------------------------") 313 | print("\nQuestion:", example['question']) 314 | if "choices" in example: 315 | print("\nChoices:", example['choices']) 316 | print("\nCaption:", example['caption']) 317 | if "ocr" in example: 318 | print("\nOCR:", example['ocr']) 319 | print("\nSolution:", example['solution']) 320 | print("\nCode:", example['code']) 321 | 322 | # execute code 323 | exec(example['code']) 324 | 325 | 326 | -------------------------------------------------------------------------------- 
/evaluation_mathvista/calculate_score.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import argparse 4 | import pandas as pd 5 | 6 | # !pip install python-Levenshtein 7 | from Levenshtein import distance 8 | 9 | import sys 10 | 11 | sys.path.append('../') 12 | from utilities import * 13 | 14 | 15 | def get_most_similar(prediction, choices): 16 | """ 17 | Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction 18 | """ 19 | distances = [distance(prediction, choice) for choice in choices] 20 | ind = distances.index(min(distances)) 21 | return choices[ind] 22 | # return min(choices, key=lambda choice: distance(prediction, choice)) 23 | 24 | 25 | def normalize_extracted_answer(extraction, choices, question_type, answer_type, precision): 26 | """ 27 | Normalize the extracted answer to match the answer type 28 | """ 29 | if question_type == 'multi_choice': 30 | # make sure the extraction is a string 31 | if isinstance(extraction, str): 32 | extraction = extraction.strip() 33 | else: 34 | try: 35 | extraction = str(extraction) 36 | except: 37 | extraction = "" 38 | 39 | # extract "A" from "(A) text" 40 | letter = re.findall(r'\(([a-zA-Z])\)', extraction) 41 | if len(letter) > 0: 42 | extraction = letter[0].upper() 43 | 44 | options = [chr(ord('A') + i) for i in range(len(choices))] 45 | 46 | if extraction in options: 47 | # convert option letter to text, e.g. "A" -> "text" 48 | ind = options.index(extraction) 49 | extraction = choices[ind] 50 | else: 51 | # select the most similar option 52 | extraction = get_most_similar(extraction, choices) 53 | assert extraction in choices 54 | 55 | elif answer_type == 'integer': 56 | try: 57 | extraction = str(int(float(extraction))) 58 | except: 59 | extraction = None 60 | 61 | elif answer_type == 'float': 62 | try: 63 | extraction = str(round(float(extraction), precision)) 64 | except: 65 | extraction = None 66 | 67 | elif answer_type == 'list': 68 | try: 69 | extraction = str(extraction) 70 | except: 71 | extraction = None 72 | 73 | return extraction 74 | 75 | 76 | def safe_equal(prediction, answer): 77 | """ 78 | Check if the prediction is equal to the answer, even if they are of different types 79 | """ 80 | try: 81 | if prediction == answer: 82 | return True 83 | return False 84 | except Exception as e: 85 | print(e) 86 | return False 87 | 88 | 89 | def get_acc_with_contion(res_pd, key, value): 90 | if key == 'skills': 91 | # if value in res_pd[key]: 92 | total_pd = res_pd[res_pd[key].apply(lambda x: value in x)] 93 | else: 94 | total_pd = res_pd[res_pd[key] == value] 95 | 96 | correct_pd = total_pd[total_pd['true_false'] == True] 97 | acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100) 98 | return len(correct_pd), len(total_pd), acc 99 | 100 | 101 | if __name__ == '__main__': 102 | 103 | 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument('--output_dir', type=str, default='./mathvista_outputs') 106 | parser.add_argument('--output_file', type=str, default='responses.json') 107 | parser.add_argument('--score_file', type=str, default='responses_score.json') 108 | parser.add_argument('--gt_file', type=str, default='./mathvista_data/testmini.json', 109 | help='ground truth file') 110 | parser.add_argument('--number', type=int, default=-1, help='number of problems to run') 111 | parser.add_argument('--rerun', default=True, help='rerun the evaluation') 112 | parser.add_argument('--caculate_gain', 
default=False, help='caculate the socre gains over random guess') 113 | parser.add_argument('--random_file', type=str, default='score_random_guess.json') 114 | args = parser.parse_args() 115 | 116 | # args 117 | output_file = os.path.join(args.output_dir, args.output_file) 118 | 119 | 120 | 121 | # read json 122 | print(f"Reading {output_file}...") 123 | results = read_json(output_file) 124 | 125 | # read ground truth 126 | print(f"Reading {args.gt_file}...") 127 | gts = read_json(args.gt_file) 128 | 129 | # full pids 130 | full_pids = list(results.keys()) 131 | if args.number > 0: 132 | full_pids = full_pids[:min(args.number, len(full_pids))] 133 | print("Number of testing problems:", len(full_pids)) 134 | 135 | ## [1] Evaluate if the prediction is true or false 136 | print("\nEvaluating the predictions...") 137 | update_json_flag = False 138 | for pid in full_pids: 139 | problem = results[pid] 140 | # print(problem) 141 | 142 | if args.rerun: 143 | if 'prediction' in problem: 144 | del problem['prediction'] 145 | if 'true_false' in problem: 146 | del problem['true_false'] 147 | 148 | choices = problem['choices'] 149 | question_type = problem['question_type'] 150 | answer_type = problem['answer_type'] 151 | precision = problem['precision'] 152 | extraction = problem['extraction'] 153 | 154 | if 'answer' in problem: 155 | answer = problem['answer'] 156 | else: 157 | answer = gts[pid]['answer'] 158 | problem['answer'] = answer 159 | 160 | # normalize the extracted answer to match the answer type 161 | prediction = normalize_extracted_answer(extraction, choices, question_type, answer_type, precision) 162 | 163 | # verify the prediction is true or false 164 | true_false = safe_equal(prediction, answer) 165 | 166 | # update the problem 167 | if "true_false" not in problem: 168 | update_json_flag = True 169 | 170 | elif true_false != problem['true_false']: 171 | update_json_flag = True 172 | 173 | if "prediction" not in problem: 174 | update_json_flag = True 175 | 176 | elif prediction != problem['prediction']: 177 | update_json_flag = True 178 | 179 | problem['prediction'] = prediction 180 | problem['true_false'] = true_false 181 | 182 | # save the updated json 183 | if update_json_flag: 184 | print("\n!!!Some problems are updated.!!!") 185 | print(f"\nSaving {output_file}...") 186 | save_json(results, output_file) 187 | 188 | ## [2] Calculate the average accuracy 189 | total = len(full_pids) 190 | correct = 0 191 | for pid in full_pids: 192 | if results[pid]['true_false']: 193 | correct += 1 194 | accuracy = str(round(correct / total * 100, 2)) 195 | print(f"\nCorrect: {correct}, Total: {total}, Accuracy: {accuracy}%") 196 | 197 | scores = {"average": {"accuracy": accuracy, "correct": correct, "total": total}} 198 | 199 | ## [3] Calculate the fine-grained accuracy scores 200 | 201 | # merge the 'metadata' attribute into the data 202 | for pid in results: 203 | results[pid].update(results[pid].pop('metadata')) 204 | 205 | # convert the data to a pandas DataFrame 206 | df = pd.DataFrame(results).T 207 | 208 | print(len(df)) 209 | print("Number of test problems:", len(df)) 210 | # assert len(df) == 1000 # Important!!! 
211 | 212 | # asign the target keys for evaluation 213 | target_keys = ['question_type', 'answer_type', 'language', 'source', 'category', 'task', 'context', 'grade', 214 | 'skills'] 215 | 216 | for key in target_keys: 217 | print(f"\nType: [{key}]") 218 | # get the unique values of the key 219 | if key == 'skills': 220 | # the value is a list 221 | values = [] 222 | for i in range(len(df)): 223 | values += df[key][i] 224 | values = list(set(values)) 225 | else: 226 | values = df[key].unique() 227 | # print(values) 228 | 229 | # calculate the accuracy for each value 230 | scores[key] = {} 231 | for value in values: 232 | correct, total, acc = get_acc_with_contion(df, key, value) 233 | if total > 0: 234 | print(f"[{value}]: {acc}% ({correct}/{total})") 235 | scores[key][value] = {"accuracy": acc, "correct": correct, "total": total} 236 | 237 | # sort the scores by accuracy 238 | scores[key] = dict(sorted(scores[key].items(), key=lambda item: float(item[1]['accuracy']), reverse=True)) 239 | 240 | # save the scores 241 | scores_file = os.path.join(args.output_dir, args.score_file) 242 | print(f"\nSaving {scores_file}...") 243 | save_json(scores, scores_file) 244 | print("\nDone!") 245 | 246 | # [4] Calculate the score gains over random guess 247 | if args.caculate_gain: 248 | random_file = os.path.join(args.output_dir, args.random_file) 249 | random_scores = json.load(open(random_file)) 250 | 251 | print("\nCalculating the score gains...") 252 | for key in scores: 253 | if key == 'average': 254 | gain = round(float(scores[key]['accuracy']) - float(random_scores[key]['accuracy']), 2) 255 | scores[key]['acc_gain'] = gain 256 | else: 257 | for sub_key in scores[key]: 258 | gain = round( 259 | float(scores[key][sub_key]['accuracy']) - float(random_scores[key][sub_key]['accuracy']), 2) 260 | scores[key][sub_key]['acc_gain'] = str(gain) 261 | 262 | # save the score gains 263 | print(f"\nSaving {scores_file}...") 264 | save_json(scores, scores_file) 265 | print("\nDone!") -------------------------------------------------------------------------------- /evaluation_mathvista/ext_ans.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # pids = 852, 104, 824, 506, 540 4 | 5 | demo_prompt = """ 6 | Please read the following example. Then extract the answer from the model response and type it at the end of the prompt. 7 | 8 | Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end. 9 | Question: Which number is missing? 10 | 11 | Model response: The number missing in the sequence is 14. 12 | 13 | Extracted answer: 14 14 | 15 | Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end. 16 | Question: What is the fraction of females facing the camera? 17 | 18 | Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera. 19 | 20 | Extracted answer: 0.6 21 | 22 | Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end. 23 | Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $) 24 | 25 | Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy. 
26 | 27 | Extracted answer: 1.45 28 | 29 | Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end. 30 | Question: Between which two years does the line graph saw its maximum peak? 31 | 32 | Model response: The line graph saw its maximum peak between 2007 and 2008. 33 | 34 | Extracted answer: [2007, 2008] 35 | 36 | Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end. 37 | Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5 38 | 39 | Model response: The correct answer is (B) 8/11. 40 | 41 | Extracted answer: B 42 | """ -------------------------------------------------------------------------------- /evaluation_mathvista/extract_answer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | 8 | import sys 9 | 10 | sys.path.append('../') 11 | from utilities import * 12 | 13 | # OpenAI 14 | import openai 15 | 16 | 17 | api_key = "" # set your OpenAI API key here 18 | headers = { 19 | "Content-Type": "application/json", 20 | "Authorization": f"Bearer {api_key}" 21 | } 22 | # load demo prompt 23 | from ext_ans import demo_prompt 24 | 25 | 26 | def verify_extraction(extraction): 27 | extraction = extraction.strip() if isinstance(extraction, str) else "" 28 | if extraction == "": 29 | return False 30 | return True 31 | 32 | 33 | def create_test_prompt(demo_prompt, query, response): # few-shot prompt for answer extraction 34 | demo_prompt = demo_prompt.strip() 35 | test_prompt = f"{query}\n\n{response}" 36 | full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: " 37 | return full_prompt 38 | 39 | 40 | def extract_answer(response, problem, quick_extract=False, llm_engine="gpt-3.5-turbo"): 41 | question_type = problem['question_type'] 42 | answer_type = problem['answer_type'] 43 | choices = problem['choices'] 44 | query = problem['query'] 45 | 46 | if response == "": 47 | return "" 48 | 49 | if question_type == 'multi_choice' and response in choices: 50 | return response 51 | 52 | if answer_type == "integer": 53 | try: 54 | extraction = int(response) 55 | return str(extraction) 56 | except: 57 | pass 58 | 59 | if answer_type == "float": 60 | try: 61 | extraction = str(float(response)) 62 | return extraction 63 | except: 64 | pass 65 | 66 | # quick extraction 67 | if quick_extract: 68 | print("Quickly extracting answer...") 69 | try: 70 | result = re.search(r'The answer is "(.*)"\.', response) 71 | if result: 72 | extraction = result.group(1) 73 | return extraction 74 | except: 75 | pass 76 | 77 | # general extraction 78 | try: 79 | full_prompt = create_test_prompt(demo_prompt, query, response) 80 | extraction = get_chat_response_new(full_prompt, headers, model=llm_engine) 81 | return extraction 82 | except Exception as e: 83 | print(e) 84 | print(f"Error in extracting answer for problem {problem.get('pid', '')}") 85 | 86 | return response 87 | 88 | 89 | def extract_answer_quick(response, problem, quick_extract=False): 90 | question_type = problem['question_type'] 91 | answer_type = problem['answer_type'] 92 | choices = problem['choices'] 93 | query = problem['query'] 94 | 95 | if response == "": 96 | return "" 97 | 98 | if question_type == 'multi_choice' and response in choices: 99 | return response 100 | 101 | if answer_type == "integer": 102 | try: 103 | extraction = int(response) 104 | return str(extraction) 105 | except: 106 | pass 107 | 108 | if answer_type == "float": 109 | try: 110 | extraction = 
str(float(response)) 111 | return extraction 112 | except: 113 | pass 114 | 115 | # quick extraction 116 | if quick_extract: 117 | print("Quickly extracting answer...") 118 | try: 119 | result = response.split('The answer is ') 120 | if result: 121 | #extraction = result.group(1) 122 | extraction = result[1] 123 | return extraction 124 | except: 125 | pass 126 | 127 | 128 | return response 129 | 130 | 131 | if __name__ == '__main__': 132 | 133 | parser = argparse.ArgumentParser() 134 | # input 135 | parser.add_argument('--output_dir', type=str, default='./mathvista_outputs') 136 | parser.add_argument('--output_file', type=str, default='responses.json') 137 | parser.add_argument('--response_label', type=str, default='response', 138 | help='response label for the input file') 139 | # model 140 | parser.add_argument('--llm_engine', type=str, default='gpt-3.5-turbo', help='llm engine', 141 | choices=['gpt-3.5-turbo', 'gpt-3.5', 'gpt-4', 'gpt-4-0314', 'gpt-4-0613']) 142 | parser.add_argument('--number', type=int, default=-1, help='number of problems to run') 143 | parser.add_argument('--quick_extract', default=False, help='use rules to extract answer for some problems') 144 | parser.add_argument('--rerun', default=False, help='rerun the answer extraction') 145 | # output 146 | parser.add_argument('--save_every', type=int, default=10, help='save every n problems') 147 | parser.add_argument('--output_label', type=str, default='', help='label for the output file') 148 | args = parser.parse_args() 149 | 150 | # args 151 | label = args.response_label 152 | result_file = os.path.join(args.output_dir, args.output_file) 153 | 154 | if args.output_label != '': 155 | output_file = result_file.replace('.json', f'_{args.output_label}.json') 156 | else: 157 | output_file = result_file 158 | 159 | # read results 160 | print(f"Reading {result_file}...") 161 | results = read_json(result_file) 162 | 163 | # full pids 164 | full_pids = list(results.keys()) 165 | if args.number > 0: 166 | full_pids = full_pids[:min(args.number, len(full_pids))] 167 | print("Number of testing problems:", len(full_pids)) 168 | 169 | # test pids 170 | if args.rerun: 171 | test_pids = full_pids 172 | else: 173 | test_pids = [] 174 | for pid in full_pids: 175 | # print(pid) 176 | if 'extraction' not in results[pid] or not verify_extraction(results[pid]['extraction']): 177 | test_pids.append(pid) 178 | 179 | test_num = len(test_pids) 180 | print("Number of problems to run:", test_num) 181 | # print(test_pids) 182 | 183 | # tqdm, enumerate results 184 | for i, pid in enumerate(tqdm(test_pids)): 185 | problem = results[pid] 186 | 187 | assert label in problem 188 | response = problem[label] 189 | 190 | extraction = extract_answer(response, problem, args.quick_extract, args.llm_engine) 191 | results[pid]['extraction'] = extraction 192 | 193 | if i % args.save_every == 0 or i == test_num - 1: 194 | print(f"Saving results to {output_file}...") 195 | save_json(results, output_file) 196 | print(f"Results saved.") 197 | -------------------------------------------------------------------------------- /evaluation_mathvista/response.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | import io 5 | import time 6 | import argparse 7 | 8 | from tqdm import tqdm 9 | 10 | import sys 11 | 12 | sys.path.append('../') 13 | from utilities import * 14 | 15 | from build_query import create_query_data 16 | from llava.model.builder import load_pretrained_model 17 | from llava.mm_utils import 
get_model_name_from_path 18 | from llava.eval.run_llava import eval_model, evalmodel 19 | from llava.utils import disable_torch_init 20 | 21 | 22 | def verify_response(response): 23 | if isinstance(response, str): 24 | response = response.strip() 25 | if response == "" or response == None: 26 | return False 27 | if "Response Error" in response: 28 | return False 29 | return True 30 | 31 | 32 | def evaluate_code(code_string): 33 | # execute_code_and_capture_output 34 | # Backup the original stdout 35 | old_stdout = sys.stdout 36 | 37 | # Redirect stdout to capture the output 38 | new_stdout = io.StringIO() 39 | sys.stdout = new_stdout 40 | 41 | # Try executing the code and capture any exception 42 | error = None 43 | try: 44 | exec(code_string) 45 | except Exception as e: 46 | error = e 47 | 48 | # Restore the original stdout 49 | sys.stdout = old_stdout 50 | 51 | # Get the captured output 52 | captured_output = new_stdout.getvalue() 53 | if isinstance(captured_output, str): 54 | captured_output = captured_output.strip() 55 | 56 | # Return the captured output or error 57 | return captured_output, error 58 | 59 | 60 | if __name__ == '__main__': 61 | parser = argparse.ArgumentParser() 62 | # input 63 | parser.add_argument('--data_dir', type=str, default='./mathvista_data') 64 | parser.add_argument('--input_file', type=str, default='testmini.json') 65 | # output 66 | parser.add_argument('--output_dir', type=str, default='./mathvista_outputs') 67 | parser.add_argument('--output_file', type=str, default='responses.json') 68 | # model 69 | parser.add_argument('--model_path', type=str, default='liuhaotian/llava-v1.5-13b', help='path of lora or full model') 70 | parser.add_argument('--model_base', type=str, default=None, help='liuhaotian/llava-v1.5-13b for lora, =None for full model') 71 | # query 72 | parser.add_argument('--query_file', type=str, default='query.json') 73 | parser.add_argument('--caption_file', type=str, default=None) 74 | parser.add_argument('--ocr_file', type=str, default=None) 75 | parser.add_argument('--shot_type', type=str, default='solution', help='shot type', 76 | choices=['solution', 'code']) 77 | parser.add_argument('--shot_num', type=int, default=0, help='number of shot examples') 78 | parser.add_argument('--use_caption', default=False, help='use caption data') 79 | parser.add_argument('--use_ocr', default=False, help='use ocr data') 80 | # other settings 81 | parser.add_argument('--rerun', default=False, help='rerun answer extraction for all problems') 82 | parser.add_argument('--debug', default=False, help='debug mode') 83 | args = parser.parse_args() 84 | 85 | # load data 86 | input_file = os.path.join(args.data_dir, args.input_file) 87 | print(f"Reading {input_file}...") 88 | data = read_json(input_file) 89 | 90 | # load or create query data 91 | if args.query_file: 92 | query_file = os.path.join(args.data_dir, args.query_file) 93 | if os.path.exists(query_file): 94 | print(f"Loading existing {query_file}...") 95 | query_data = read_json(query_file) 96 | else: 97 | print("\nCreating new query...") 98 | # load caption 99 | caption_data = {} 100 | if args.use_caption: 101 | caption_file = args.caption_file 102 | if os.path.exists(caption_file): 103 | print(f"Reading {caption_file}...") 104 | try: 105 | caption_data = read_json(caption_file)["texts"] 106 | print("Caption data loaded.") 107 | except: 108 | print("Caption data not found!! 
Please Check.") 109 | # load ocr 110 | ocr_data = {} 111 | if args.use_ocr: 112 | ocr_file = args.ocr_file 113 | if os.path.exists(ocr_file): 114 | print(f"Reading {ocr_file}...") 115 | try: 116 | ocr_data = read_json(ocr_file)["texts"] 117 | print("OCR data loaded.") 118 | except: 119 | print("OCR data not found!! Please Check.") 120 | # create query 121 | query_data = create_query_data(data, caption_data, ocr_data, args) 122 | 123 | #print(query_data) 124 | 125 | # output file 126 | os.makedirs(args.output_dir, exist_ok=True) 127 | output_file = os.path.join(args.output_dir, args.output_file) 128 | 129 | # load results 130 | if os.path.exists(output_file): 131 | print("\nResults already exist.") 132 | print(f"Reading {output_file}...") 133 | results = read_json(output_file) 134 | else: 135 | results = {} 136 | 137 | model_path = args.model_path 138 | model_base = args.model_base 139 | 140 | tokenizer, model, image_processor, context_len = load_pretrained_model( 141 | model_path=model_path, 142 | model_base=model_base, 143 | model_name=get_model_name_from_path(model_path) 144 | ) 145 | 146 | disable_torch_init() 147 | model_name = get_model_name_from_path(model_path) 148 | ## 149 | 150 | # build final test pid list 151 | test_pids = list(data.keys()) 152 | print("\nNumber of test problems in total:", len(test_pids)) 153 | 154 | skip_pids = [] 155 | if not args.rerun: 156 | print("\nRemoving problems with existing valid response...") 157 | for pid in test_pids: 158 | # print(f"Checking {pid}...") 159 | if pid in results and 'response' in results[pid]: 160 | response = results[pid]['response'] 161 | if verify_response(response): 162 | # print(f"Valid response found for {pid}.") 163 | skip_pids.append(pid) 164 | else: 165 | print("\nRerun answer extraction for all problems...") 166 | 167 | test_pids = [pid for pid in test_pids if pid not in skip_pids] 168 | print("Number of test problems to run:", len(test_pids)) 169 | # print(test_pids) 170 | 171 | # tqdm, enumerate results 172 | for _, pid in enumerate(tqdm(test_pids)): 173 | problem = data[pid] 174 | query = query_data[pid] 175 | image = problem['image'] 176 | image_path = os.path.join(args.data_dir, image) 177 | 178 | if args.debug: 179 | print("--------------------------------------------------------------") 180 | print(f"\nGenerating response for {pid}...") 181 | try: 182 | 183 | args_llava = type('Args', (), { 184 | "model_path": model_path, 185 | "model_base": None, 186 | "model_name": get_model_name_from_path(model_path), 187 | "query": query, 188 | "conv_mode": None, 189 | "image_file": image_path, 190 | "sep": ",", 191 | "temperature": 0.2, 192 | "top_p": None, 193 | "num_beams": 1, 194 | "max_new_tokens": 512 195 | })() 196 | response = evalmodel(args_llava, model_name, tokenizer, model, image_processor, context_len) 197 | results[pid] = problem 198 | results[pid]['query'] = query 199 | if args.shot_type == 'solution': 200 | results[pid]['response'] = response 201 | else: 202 | output, error = evaluate_code(response) 203 | results[pid]['response'] = response 204 | results[pid]['execution'] = output 205 | results[pid]['error'] = str(error) 206 | if args.debug: 207 | print(f"\n#Query: \n{query}") 208 | print(f"\n#Response: \n{response}") 209 | except Exception as e: 210 | print(e) 211 | print(f"Error in extracting answer for {pid}") 212 | results[pid]['error'] = e 213 | 214 | try: 215 | print(f"Saving results to {output_file}...") 216 | save_json(results, output_file) 217 | print(f"Results saved.") 218 | except Exception as e: 219 
| print(e) 220 | print(f"Error in saving {output_file}") 221 | -------------------------------------------------------------------------------- /evaluation_mathvista/utilities.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import time 5 | import pickle 6 | import openai 7 | import re 8 | 9 | import requests 10 | from word2number import w2n 11 | 12 | 13 | 14 | def create_dir(output_dir): 15 | if not os.path.exists(output_dir): 16 | os.makedirs(output_dir) 17 | 18 | 19 | def read_csv(file): 20 | data = [] 21 | with open(file, 'r') as f: 22 | for line in f: 23 | data.append(line.strip()) 24 | return data 25 | 26 | 27 | def read_pandas_csv(csv_path): 28 | # read a pandas csv sheet 29 | import pandas as pd 30 | df = pd.read_csv(csv_path) 31 | return df 32 | 33 | 34 | def read_json(path): 35 | with open(path, 'r', encoding='utf-8') as f: 36 | return json.load(f) 37 | 38 | 39 | def read_jsonl(file): 40 | with open(file, 'r') as f: 41 | data = [json.loads(line) for line in f] 42 | return data 43 | 44 | 45 | def read_pickle(path): 46 | with open(path, 'rb') as f: 47 | return pickle.load(f) 48 | 49 | 50 | def save_json(data, path): 51 | with open(path, 'w') as f: 52 | json.dump(data, f, indent=4) 53 | 54 | 55 | def save_array_img(path, image): 56 | cv2.imwrite(path, image) 57 | 58 | 59 | def contains_digit(text): 60 | # check if text contains a digit 61 | if any(char.isdigit() for char in text): 62 | return True 63 | return False 64 | 65 | 66 | def contains_number_word(text): 67 | # check if text contains a number word 68 | ignore_words = ["a", "an", "point"] 69 | words = re.findall(r'\b\w+\b', text) # This regex pattern matches any word in the text 70 | for word in words: 71 | if word in ignore_words: 72 | continue 73 | try: 74 | w2n.word_to_num(word) 75 | return True # If the word can be converted to a number, return True 76 | except ValueError: 77 | continue # If the word can't be converted to a number, continue with the next word 78 | 79 | # check if text contains a digit 80 | if any(char.isdigit() for char in text): 81 | return True 82 | 83 | return False # If none of the words could be converted to a number, return False 84 | 85 | 86 | def contains_quantity_word(text, special_keep_words=[]): 87 | # check if text contains a quantity word 88 | quantity_words = ["most", "least", "fewest" 89 | "more", "less", "fewer", 90 | "largest", "smallest", "greatest", 91 | "larger", "smaller", "greater", 92 | "highest", "lowest", "higher", "lower", 93 | "increase", "decrease", 94 | "minimum", "maximum", "max", "min", 95 | "mean", "average", "median", 96 | "total", "sum", "add", "subtract", 97 | "difference", "quotient", "gap", 98 | "half", "double", "twice", "triple", 99 | "square", "cube", "root", 100 | "approximate", "approximation", 101 | "triangle", "rectangle", "circle", "square", "cube", "sphere", "cylinder", "cone", "pyramid", 102 | "multiply", "divide", 103 | "percentage", "percent", "ratio", "proportion", "fraction", "rate", 104 | ] 105 | 106 | quantity_words += special_keep_words # dataset specific words 107 | 108 | words = re.findall(r'\b\w+\b', text) # This regex pattern matches any word in the text 109 | if any(word in quantity_words for word in words): 110 | return True 111 | 112 | return False # If none of the words could be converted to a number, return False 113 | 114 | 115 | def is_bool_word(text): 116 | if text in ["Yes", "No", "True", "False", 117 | "yes", "no", "true", "false", 118 | "YES", "NO", 
"TRUE", "FALSE"]: 119 | return True 120 | return False 121 | 122 | 123 | def is_digit_string(text): 124 | # remove ".0000" 125 | text = text.strip() 126 | text = re.sub(r'\.0+$', '', text) 127 | try: 128 | int(text) 129 | return True 130 | except ValueError: 131 | return False 132 | 133 | 134 | def is_float_string(text): 135 | # text is a float string if it contains a "." and can be converted to a float 136 | if "." in text: 137 | try: 138 | float(text) 139 | return True 140 | except ValueError: 141 | return False 142 | return False 143 | 144 | 145 | def copy_image(image_path, output_image_path): 146 | from shutil import copyfile 147 | copyfile(image_path, output_image_path) 148 | 149 | 150 | def copy_dir(src_dir, dst_dir): 151 | from shutil import copytree 152 | # copy the source directory to the target directory 153 | copytree(src_dir, dst_dir) 154 | 155 | 156 | import PIL.Image as Image 157 | 158 | 159 | def get_image_size(img_path): 160 | img = Image.open(img_path) 161 | width, height = img.size 162 | return width, height 163 | 164 | 165 | 166 | def get_chat_response(promot, api_key, model="gpt-3.5-turbo", temperature=0, max_tokens=256, n=1, patience=10000000, 167 | sleep_time=0): 168 | print(api_key) 169 | messages = [ 170 | {"role": "user", "content": promot}, 171 | ] 172 | # print("I am here") 173 | while patience > 0: 174 | patience -= 1 175 | try: 176 | response = openai.ChatCompletion.create(model=model, 177 | messages=messages, 178 | api_key=api_key, 179 | temperature=temperature, 180 | max_tokens=max_tokens, 181 | n=n) 182 | if n == 1: 183 | prediction = response['choices'][0]['message']['content'].strip() 184 | if prediction != "" and prediction != None: 185 | return prediction 186 | else: 187 | prediction = [choice['message']['content'].strip() for choice in response['choices']] 188 | if prediction[0] != "" and prediction[0] != None: 189 | return prediction 190 | 191 | except Exception as e: 192 | if "Rate limit" not in str(e): 193 | print(e) 194 | 195 | if "Please reduce the length of the messages" in str(e): 196 | print("!!Reduce promot size") 197 | # reduce input prompt and keep the tail 198 | new_size = int(len(promot) * 0.9) 199 | new_start = len(promot) - new_size 200 | promot = promot[new_start:] 201 | messages = [ 202 | {"role": "user", "content": promot}, 203 | ] 204 | 205 | if sleep_time > 0: 206 | time.sleep(sleep_time) 207 | return "" 208 | 209 | 210 | 211 | def get_chat_response_new(promot, headers, model="gpt-3.5-turbo", temperature=0, max_tokens=256, n=1, patience=10000000, 212 | sleep_time=0): 213 | messages = [ 214 | {"role": "user", "content": promot}, 215 | ] 216 | 217 | while patience > 0: 218 | patience -= 1 219 | try: 220 | 221 | data = {"model": model, "messages": messages, "temperature":0.0} 222 | 223 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data) 224 | response = response.json() 225 | 226 | if n == 1: 227 | prediction = response['choices'][0]['message']['content'].strip() 228 | if prediction != "" and prediction != None: 229 | return prediction 230 | else: 231 | prediction = [choice['message']['content'].strip() for choice in response['choices']] 232 | if prediction[0] != "" and prediction[0] != None: 233 | return prediction 234 | 235 | except Exception as e: 236 | if "Rate limit" not in str(e): 237 | print(e) 238 | 239 | if "Please reduce the length of the messages" in str(e): 240 | print("!!Reduce promot size") 241 | # reduce input prompt and keep the tail 242 | new_size = int(len(promot) * 0.9) 
243 | new_start = len(promot) - new_size 244 | promot = promot[new_start:] 245 | messages = [ 246 | {"role": "user", "content": promot}, 247 | ] 248 | 249 | if sleep_time > 0: 250 | time.sleep(sleep_time) 251 | return "" 252 | -------------------------------------------------------------------------------- /finetune_task.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | deepspeed llava/train/train_mem.py \ 3 | --deepspeed ./scripts/zero3.json \ 4 | --model_name_or_path liuhaotian/llava-v1.5-13b \ 5 | --version v1 \ 6 | --data_path ./train_samples_all_tuning.json \ 7 | --image_folder ./data_images \ 8 | --vision_tower openai/clip-vit-large-patch14-336 \ 9 | --mm_projector_type mlp2x_gelu \ 10 | --mm_vision_select_layer -2 \ 11 | --mm_use_im_start_end False \ 12 | --mm_use_im_patch_token False \ 13 | --image_aspect_ratio pad \ 14 | --group_by_modality_length True \ 15 | --bf16 True \ 16 | --output_dir ./checkpoints/llava-v1.5-13b-full-finetune \ 17 | --num_train_epochs 2 \ 18 | --per_device_train_batch_size 16 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 50000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 2e-5 \ 26 | --weight_decay 0. \ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb 36 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /llava/eval/__pycache__/run_llava.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HZQ950419/Math-LLaVA/3480ba7d2305de2c6f76941b091c2925641fc751/llava/eval/__pycache__/run_llava.cpython-310.pyc -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import ( 5 | IMAGE_TOKEN_INDEX, 6 | DEFAULT_IMAGE_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | IMAGE_PLACEHOLDER, 10 | ) 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import ( 15 | process_images, 16 | tokenizer_image_token, 17 | get_model_name_from_path, 18 | KeywordsStoppingCriteria, 19 | ) 20 | 21 | from PIL import Image 22 | 23 | import requests 24 | from PIL import Image 25 | from io import BytesIO 26 | import re 27 | 28 | 29 | def image_parser(args): 30 | out = args.image_file.split(args.sep) 31 | return out 32 | 33 | 34 | def load_image(image_file): 35 | if image_file.startswith("http") or image_file.startswith("https"): 36 | response = requests.get(image_file) 37 | image = Image.open(BytesIO(response.content)).convert("RGB") 38 | else: 39 | image = Image.open(image_file).convert("RGB") 40 | return image 41 | 42 | 43 | def load_images(image_files): 44 | out = [] 45 | for image_file in image_files: 46 | image = load_image(image_file) 47 | out.append(image) 48 | return out 49 | 50 | 51 | def eval_model(args): 52 | # Model 53 | disable_torch_init() 54 | 55 | model_name = get_model_name_from_path(args.model_path) 56 | tokenizer, model, image_processor, context_len = load_pretrained_model( 57 | args.model_path, args.model_base, model_name 58 | ) 59 | 60 | qs = args.query 61 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 62 | if IMAGE_PLACEHOLDER in qs: 63 | if model.config.mm_use_im_start_end: 64 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 65 | else: 66 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 67 | else: 68 | if model.config.mm_use_im_start_end: 69 | qs = image_token_se + "\n" + qs 70 | else: 71 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 72 | 73 | if "llama-2" in model_name.lower(): 74 | conv_mode = "llava_llama_2" 75 | elif "v1" in model_name.lower(): 76 | conv_mode = "llava_v1" 77 | elif "mpt" in model_name.lower(): 78 | conv_mode = "mpt" 79 | else: 80 | conv_mode = "llava_v0" 81 | 82 | if args.conv_mode is not None and conv_mode != args.conv_mode: 83 | print( 84 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 85 | conv_mode, args.conv_mode, args.conv_mode 86 | ) 87 | ) 88 | else: 89 | args.conv_mode = conv_mode 90 | 91 | conv = conv_templates[args.conv_mode].copy() 92 | conv.append_message(conv.roles[0], qs) 93 | conv.append_message(conv.roles[1], None) 94 | prompt = conv.get_prompt() 95 | 96 | image_files = image_parser(args) 97 | images = load_images(image_files) 
98 | images_tensor = process_images( 99 | images, 100 | image_processor, 101 | model.config 102 | ).to(model.device, dtype=torch.float16) 103 | 104 | input_ids = ( 105 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 106 | .unsqueeze(0) 107 | .cuda() 108 | ) 109 | 110 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 111 | keywords = [stop_str] 112 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 113 | 114 | with torch.inference_mode(): 115 | output_ids = model.generate( 116 | input_ids, 117 | images=images_tensor, 118 | do_sample=True if args.temperature > 0 else False, 119 | temperature=args.temperature, 120 | top_p=args.top_p, 121 | num_beams=args.num_beams, 122 | max_new_tokens=args.max_new_tokens, 123 | use_cache=True, 124 | stopping_criteria=[stopping_criteria], 125 | ) 126 | 127 | input_token_len = input_ids.shape[1] 128 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 129 | if n_diff_input_output > 0: 130 | print( 131 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" 132 | ) 133 | outputs = tokenizer.batch_decode( 134 | output_ids[:, input_token_len:], skip_special_tokens=True 135 | )[0] 136 | outputs = outputs.strip() 137 | if outputs.endswith(stop_str): 138 | outputs = outputs[: -len(stop_str)] 139 | outputs = outputs.strip() 140 | print(outputs) 141 | 142 | def evalmodel(args, model_name, tokenizer, model, image_processor, context_len): 143 | # Model 144 | #disable_torch_init() 145 | 146 | # model_name = get_model_name_from_path(args.model_path) 147 | # tokenizer, model, image_processor, context_len = load_pretrained_model( 148 | # args.model_path, args.model_base, model_name 149 | # ) 150 | 151 | qs = args.query 152 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 153 | if IMAGE_PLACEHOLDER in qs: 154 | if model.config.mm_use_im_start_end: 155 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 156 | else: 157 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 158 | else: 159 | if model.config.mm_use_im_start_end: 160 | qs = image_token_se + "\n" + qs 161 | else: 162 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 163 | 164 | if "llama-2" in model_name.lower(): 165 | conv_mode = "llava_llama_2" 166 | elif "v1" in model_name.lower(): 167 | conv_mode = "llava_v1" 168 | elif "mpt" in model_name.lower(): 169 | conv_mode = "mpt" 170 | else: 171 | conv_mode = "llava_v0" 172 | 173 | if args.conv_mode is not None and conv_mode != args.conv_mode: 174 | print( 175 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 176 | conv_mode, args.conv_mode, args.conv_mode 177 | ) 178 | ) 179 | else: 180 | args.conv_mode = conv_mode 181 | 182 | conv = conv_templates[args.conv_mode].copy() 183 | conv.append_message(conv.roles[0], qs) 184 | conv.append_message(conv.roles[1], None) 185 | prompt = conv.get_prompt() 186 | 187 | image_files = image_parser(args) 188 | images = load_images(image_files) 189 | images_tensor = process_images( 190 | images, 191 | image_processor, 192 | model.config 193 | ).to(model.device, dtype=torch.float16) 194 | 195 | input_ids = ( 196 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 197 | .unsqueeze(0) 198 | .cuda() 199 | ) 200 | 201 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 202 | keywords = [stop_str] 203 | stopping_criteria = KeywordsStoppingCriteria(keywords, 
tokenizer, input_ids) 204 | 205 | with torch.inference_mode(): 206 | output_ids = model.generate( 207 | input_ids, 208 | images=images_tensor, 209 | do_sample=True if args.temperature > 0 else False, 210 | temperature=args.temperature, 211 | top_p=args.top_p, 212 | num_beams=args.num_beams, 213 | max_new_tokens=args.max_new_tokens, 214 | use_cache=True, 215 | stopping_criteria=[stopping_criteria], 216 | ) 217 | 218 | input_token_len = input_ids.shape[1] 219 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 220 | if n_diff_input_output > 0: 221 | print( 222 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids" 223 | ) 224 | outputs = tokenizer.batch_decode( 225 | output_ids[:, input_token_len:], skip_special_tokens=True 226 | )[0] 227 | outputs = outputs.strip() 228 | if outputs.endswith(stop_str): 229 | outputs = outputs[: -len(stop_str)] 230 | outputs = outputs.strip() 231 | 232 | #print(outputs) 233 | return outputs 234 | 235 | 236 | if __name__ == "__main__": 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 239 | parser.add_argument("--model-base", type=str, default=None) 240 | parser.add_argument("--image-file", type=str, required=True) 241 | parser.add_argument("--query", type=str, required=True) 242 | parser.add_argument("--conv-mode", type=str, default=None) 243 | parser.add_argument("--sep", type=str, default=",") 244 | parser.add_argument("--temperature", type=float, default=0.2) 245 | parser.add_argument("--top_p", type=float, default=None) 246 | parser.add_argument("--num_beams", type=int, default=1) 247 | parser.add_argument("--max_new_tokens", type=int, default=512) 248 | args = parser.parse_args() 249 | 250 | eval_model(args) 251 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 35 | new_images.append(image) 36 | else: 37 | return image_processor(images, return_tensors='pt')['pixel_values'] 38 | if all(x.shape == new_images[0].shape for x in new_images): 39 | new_images = torch.stack(new_images, dim=0) 40 | return new_images 41 | 42 | 43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 44 | prompt_chunks = 
[tokenizer(chunk).input_ids for chunk in prompt.split('')] 45 | 46 | def insert_separator(X, sep): 47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 48 | 49 | input_ids = [] 50 | offset = 0 51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 52 | offset = 1 53 | input_ids.append(prompt_chunks[0][0]) 54 | 55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 56 | input_ids.extend(x[offset:]) 57 | 58 | if return_tensors is not None: 59 | if return_tensors == 'pt': 60 | return torch.tensor(input_ids, dtype=torch.long) 61 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 62 | return input_ids 63 | 64 | 65 | def get_model_name_from_path(model_path): 66 | model_path = model_path.strip("/") 67 | model_paths = model_path.split("/") 68 | if model_paths[-1].startswith('checkpoint-'): 69 | return model_paths[-2] + "_" + model_paths[-1] 70 | else: 71 | return model_paths[-1] 72 | 73 | class KeywordsStoppingCriteria(StoppingCriteria): 74 | def __init__(self, keywords, tokenizer, input_ids): 75 | self.keywords = keywords 76 | self.keyword_ids = [] 77 | self.max_keyword_len = 0 78 | for keyword in keywords: 79 | cur_keyword_ids = tokenizer(keyword).input_ids 80 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 81 | cur_keyword_ids = cur_keyword_ids[1:] 82 | if len(cur_keyword_ids) > self.max_keyword_len: 83 | self.max_keyword_len = len(cur_keyword_ids) 84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 85 | self.tokenizer = tokenizer 86 | self.start_len = input_ids.shape[1] 87 | 88 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 89 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 90 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 91 | for keyword_id in self.keyword_ids: 92 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 93 | return True 94 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 95 | for keyword in self.keywords: 96 | if keyword in outputs: 97 | return True 98 | return False 99 | 100 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 101 | outputs = [] 102 | for i in range(output_ids.shape[0]): 103 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 104 | return all(outputs) 105 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = 
AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if 'llava' in model_name.lower(): 46 | # Load LLaVA model 47 | if 'lora' in model_name.lower() and model_base is None: 48 | warnings.warn('There is `lora` in model name but no `model_base` is provided. 
If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 49 | if 'lora' in model_name.lower() and model_base is not None: 50 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 51 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 52 | print('Loading LLaVA from base model...') 53 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 54 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 55 | if model.lm_head.weight.shape[0] != token_num: 56 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 57 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 58 | 59 | print('Loading additional LLaVA weights...') 60 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 61 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 62 | else: 63 | # this is probably from HF Hub 64 | from huggingface_hub import hf_hub_download 65 | def load_from_hf(repo_id, filename, subfolder=None): 66 | cache_file = hf_hub_download( 67 | repo_id=repo_id, 68 | filename=filename, 69 | subfolder=subfolder) 70 | return torch.load(cache_file, map_location='cpu') 71 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 72 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 73 | if any(k.startswith('model.model.') for k in non_lora_trainables): 74 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 75 | model.load_state_dict(non_lora_trainables, strict=False) 76 | 77 | from peft import PeftModel 78 | print('Loading LoRA weights...') 79 | model = PeftModel.from_pretrained(model, model_path) 80 | print('Merging LoRA weights...') 81 | model = model.merge_and_unload() 82 | print('Model is loaded...') 83 | elif model_base is not None: 84 | # this may be mm projector only 85 | print('Loading LLaVA from base model...') 86 | if 'mpt' in model_name.lower(): 87 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 88 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 89 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 90 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 91 | model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 92 | else: 93 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 94 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 95 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 96 | 97 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 98 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 99 | model.load_state_dict(mm_projector_weights, strict=False) 100 | else: 101 | if 'mpt' in model_name.lower(): 102 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 103 | model = LlavaMPTForCausalLM.from_pretrained(model_path, 
low_cpu_mem_usage=True, **kwargs) 104 | else: 105 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 106 | model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 107 | else: 108 | # Load language model 109 | if model_base is not None: 110 | # PEFT model 111 | from peft import PeftModel 112 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 113 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 114 | print(f"Loading LoRA weights from {model_path}") 115 | model = PeftModel.from_pretrained(model, model_path) 116 | print(f"Merging weights") 117 | model = model.merge_and_unload() 118 | print('Convert to FP16...') 119 | model.to(torch.float16) 120 | else: 121 | use_fast = False 122 | if 'mpt' in model_name.lower(): 123 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 124 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 125 | else: 126 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 127 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 128 | 129 | image_processor = None 130 | 131 | if 'llava' in model_name.lower(): 132 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 133 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 134 | if mm_use_im_patch_token: 135 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 136 | if mm_use_im_start_end: 137 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 138 | model.resize_token_embeddings(len(tokenizer)) 139 | 140 | vision_tower = model.get_vision_tower() 141 | if not vision_tower.is_loaded: 142 | vision_tower.load_model() 143 | vision_tower.to(device=device, dtype=torch.float16) 144 | image_processor = vision_tower.image_processor 145 | 146 | if hasattr(model.config, "max_sequence_length"): 147 | context_len = model.config.max_sequence_length 148 | else: 149 | context_len = 2048 150 | 151 | return tokenizer, model, image_processor, context_len 152 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 27 | 28 | 29 | class LlavaConfig(LlamaConfig): 30 | model_type = "llava" 31 | 32 | 33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 34 | config_class = LlavaConfig 35 | 36 | def __init__(self, config: LlamaConfig): 37 | super(LlavaLlamaModel, self).__init__(config) 38 | 39 | 40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaConfig 42 | 43 | def __init__(self, config): 44 | super(LlamaForCausalLM, self).__init__(config) 45 | self.model = LlavaLlamaModel(config) 46 | self.pretraining_tp = config.pretraining_tp 47 | self.vocab_size = config.vocab_size 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | position_ids: Optional[torch.LongTensor] = None, 61 | past_key_values: Optional[List[torch.FloatTensor]] = None, 62 | inputs_embeds: Optional[torch.FloatTensor] = None, 63 | labels: Optional[torch.LongTensor] = None, 64 | use_cache: Optional[bool] = None, 65 | output_attentions: Optional[bool] = None, 66 | output_hidden_states: Optional[bool] = None, 67 | images: Optional[torch.FloatTensor] = None, 68 | return_dict: Optional[bool] = None, 69 | ) -> Union[Tuple, CausalLMOutputWithPast]: 70 | 71 | if inputs_embeds is None: 72 | ( 73 | input_ids, 74 | position_ids, 75 | attention_mask, 76 | past_key_values, 77 | inputs_embeds, 78 | labels 79 | ) = self.prepare_inputs_labels_for_multimodal( 80 | input_ids, 81 | position_ids, 82 | attention_mask, 83 | past_key_values, 84 | labels, 85 | images 86 | ) 87 | 88 | return super().forward( 89 | input_ids=input_ids, 90 | attention_mask=attention_mask, 91 | position_ids=position_ids, 92 | past_key_values=past_key_values, 93 | inputs_embeds=inputs_embeds, 94 | labels=labels, 95 | use_cache=use_cache, 96 | output_attentions=output_attentions, 97 | output_hidden_states=output_hidden_states, 98 | return_dict=return_dict 99 | ) 100 | 101 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 102 | images = kwargs.pop("images", None) 103 | _inputs = super().prepare_inputs_for_generation( 104 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 105 | ) 106 | if images is not 
None: 107 | _inputs['images'] = images 108 | return _inputs 109 | 110 | AutoConfig.register("llava", LlavaConfig) 111 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 112 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, _, attention_mask, 
past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 
14 | """ 15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('<pad>', special_tokens=True) 19 | tokenizer.pad_token = '<pad>' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a =
self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} 5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0} 6 | 7 | class MPTConfig(PretrainedConfig): 8 | model_type = 'mpt' 9 | 10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): 11 | """The MPT configuration class. 12 | 13 | Args: 14 | d_model (int): The size of the embedding dimension of the model. 15 | n_heads (int): The number of attention heads. 16 | n_layers (int): The number of layers in the model. 17 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 18 | max_seq_len (int): The maximum sequence length of the model. 19 | vocab_size (int): The size of the vocabulary. 20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 21 | emb_pdrop (float): The dropout probability for the embedding layer. 22 | learned_pos_emb (bool): Whether to use learned positional embeddings 23 | attn_config (Dict): A dictionary used to configure the model's attention module: 24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention 25 | attn_pdrop (float): The dropout probability for the attention layers. 26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 29 | this value. 30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 31 | use the default scale of ``1/sqrt(d_keys)``. 32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 
35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 37 | which sub-sequence each token belongs to. 38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 39 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 40 | alibi_bias_max (int): The maximum value of the alibi bias. 41 | init_device (str): The device to use for parameter initialization. 42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 43 | no_bias (bool): Whether to use bias in all layers. 44 | verbose (int): The verbosity level. 0 is silent. 45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 46 | norm_type (str): choose type of norm to use 47 | multiquery_attention (bool): Whether to use multiquery attention implementation. 48 | use_cache (bool): Whether or not the model should return the last key/values attentions 49 | init_config (Dict): A dictionary used to configure the model initialization: 50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 57 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 58 | if using the baseline_ parameter initialization scheme. 59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
62 | --- 63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 64 | """ 65 | self.d_model = d_model 66 | self.n_heads = n_heads 67 | self.n_layers = n_layers 68 | self.expansion_ratio = expansion_ratio 69 | self.max_seq_len = max_seq_len 70 | self.vocab_size = vocab_size 71 | self.resid_pdrop = resid_pdrop 72 | self.emb_pdrop = emb_pdrop 73 | self.learned_pos_emb = learned_pos_emb 74 | self.attn_config = attn_config 75 | self.init_device = init_device 76 | self.logit_scale = logit_scale 77 | self.no_bias = no_bias 78 | self.verbose = verbose 79 | self.embedding_fraction = embedding_fraction 80 | self.norm_type = norm_type 81 | self.use_cache = use_cache 82 | self.init_config = init_config 83 | if 'name' in kwargs: 84 | del kwargs['name'] 85 | if 'loss_fn' in kwargs: 86 | del kwargs['loss_fn'] 87 | super().__init__(**kwargs) 88 | self._validate_config() 89 | 90 | def _set_config_defaults(self, config, config_defaults): 91 | for (k, v) in config_defaults.items(): 92 | if k not in config: 93 | config[k] = v 94 | return config 95 | 96 | def _validate_config(self): 97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) 98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) 99 | if self.d_model % self.n_heads != 0: 100 | raise ValueError('d_model must be divisible by n_heads') 101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): 102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") 103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: 104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') 107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 108 | raise NotImplementedError('alibi only implemented with torch and triton attention.') 109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') 111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') 113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': 114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 115 | if self.init_config.get('name', None) is None: 116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") 117 | if not self.learned_pos_emb and (not self.attn_config['alibi']): 118 | raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: 
bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Linear(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True,
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/language_model/mpt/param_init_fns.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from collections.abc import Sequence 4 | from functools import partial 5 | from typing import Optional, Tuple, Union 6 | import torch 7 | from torch import nn 8 | from .norm import NORM_CLASS_REGISTRY 9 | 10 | def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs): 11 | del kwargs 12 | if verbose > 1: 13 | warnings.warn(f"Initializing network using module's reset_parameters attribute") 14 | if hasattr(module, 'reset_parameters'): 15 | module.reset_parameters() 16 | 17 | def fused_init_helper_(module: nn.Module, init_fn_): 18 | _fused = getattr(module, '_fused', None) 19 | if _fused is None: 20 | raise RuntimeError(f'Internal logic error') 21 | (dim, splits) = _fused 22 | splits = (0, *splits, module.weight.size(dim)) 23 | for (s, e) in zip(splits[:-1], splits[1:]): 24 | slice_indices = [slice(None)] * module.weight.ndim 25 | slice_indices[dim] = slice(s, e) 26 | init_fn_(module.weight[slice_indices]) 27 | 28 | def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 29 | del kwargs 30 | if verbose > 1: 31 | warnings.warn(f'If model has bias parameters they are initialized to 0.') 32 | init_div_is_residual = init_div_is_residual 33 | if init_div_is_residual is False: 34 | div_is_residual = 1.0 35 | elif init_div_is_residual is True: 36 | div_is_residual = math.sqrt(2 * n_layers) 37 | elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int): 38 | div_is_residual = init_div_is_residual 39 | elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): 40 | div_is_residual = float(init_div_is_residual) 41 | else: 42 | div_is_residual = 1.0 43 | raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') 44 | if init_div_is_residual is not False: 45 | if verbose > 1: 46 | warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. 
' + f'Set `init_div_is_residual: false` in init config to disable this.') 47 | if isinstance(module, nn.Linear): 48 | if hasattr(module, '_fused'): 49 | fused_init_helper_(module, init_fn_) 50 | else: 51 | init_fn_(module.weight) 52 | if module.bias is not None: 53 | torch.nn.init.zeros_(module.bias) 54 | if init_div_is_residual is not False and getattr(module, '_is_residual', False): 55 | with torch.no_grad(): 56 | module.weight.div_(div_is_residual) 57 | elif isinstance(module, nn.Embedding): 58 | if emb_init_std is not None: 59 | std = emb_init_std 60 | if std == 0: 61 | warnings.warn(f'Embedding layer initialized to 0.') 62 | emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) 63 | if verbose > 1: 64 | warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.') 65 | elif emb_init_uniform_lim is not None: 66 | lim = emb_init_uniform_lim 67 | if isinstance(lim, Sequence): 68 | if len(lim) > 2: 69 | raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.') 70 | if lim[0] == lim[1]: 71 | warnings.warn(f'Embedding layer initialized to {lim[0]}.') 72 | else: 73 | if lim == 0: 74 | warnings.warn(f'Embedding layer initialized to 0.') 75 | lim = [-lim, lim] 76 | (a, b) = lim 77 | emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) 78 | if verbose > 1: 79 | warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.') 80 | else: 81 | emb_init_fn_ = init_fn_ 82 | emb_init_fn_(module.weight) 83 | elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): 84 | if verbose > 1: 85 | warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.') 86 | if hasattr(module, 'weight') and module.weight is not None: 87 | torch.nn.init.ones_(module.weight) 88 | if hasattr(module, 'bias') and module.bias is not None: 89 | torch.nn.init.zeros_(module.bias) 90 | elif isinstance(module, nn.MultiheadAttention): 91 | if module._qkv_same_embed_dim: 92 | assert module.in_proj_weight is not None 93 | assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) 94 | assert d_model is not None 95 | _d = d_model 96 | splits = (0, _d, 2 * _d, 3 * _d) 97 | for (s, e) in zip(splits[:-1], splits[1:]): 98 | init_fn_(module.in_proj_weight[s:e]) 99 | else: 100 | assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) 101 | assert module.in_proj_weight is None 102 | init_fn_(module.q_proj_weight) 103 | init_fn_(module.k_proj_weight) 104 | init_fn_(module.v_proj_weight) 105 | if module.in_proj_bias is not None: 106 | torch.nn.init.zeros_(module.in_proj_bias) 107 | if module.bias_k is not None: 108 | torch.nn.init.zeros_(module.bias_k) 109 | if module.bias_v is not None: 110 | torch.nn.init.zeros_(module.bias_v) 111 | init_fn_(module.out_proj.weight) 112 | if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False): 113 | with torch.no_grad(): 114 | module.out_proj.weight.div_(div_is_residual) 115 | if module.out_proj.bias is not None: 116 | torch.nn.init.zeros_(module.out_proj.bias) 117 | else: 118 | for _ in module.parameters(recurse=False): 119 | raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.') 120 | 121 | def _normal_init_(std, mean=0.0): 122 | return partial(torch.nn.init.normal_, mean=mean, std=std) 123 | 124 | def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, 
d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 125 | del kwargs 126 | init_fn_ = _normal_init_(std=std) 127 | if verbose > 1: 128 | warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') 129 | generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 130 | 131 | def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 132 | del kwargs 133 | if init_std is None: 134 | raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.") 135 | _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 136 | 137 | def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 138 | del kwargs 139 | std = math.sqrt(2 / (5 * d_model)) 140 | _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 141 | 142 | def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 143 | """From section 2.3.1 of GPT-NeoX-20B: 144 | 145 | An Open-Source AutoregressiveLanguage Model — Black et. al. 
(2022) 146 | see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151 147 | and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py 148 | """ 149 | del kwargs 150 | residual_div = n_layers / math.sqrt(10) 151 | if verbose > 1: 152 | warnings.warn(f'setting init_div_is_residual to {residual_div}') 153 | small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 154 | 155 | def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 156 | del kwargs 157 | if verbose > 1: 158 | warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') 159 | kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 160 | generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 161 | 162 | def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 163 | del kwargs 164 | if verbose > 1: 165 | warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') 166 | kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 167 | generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 168 | 169 | def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): 170 | del kwargs 171 | xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) 172 | if verbose > 1: 173 | warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}') 174 | generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 175 | 176 | def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, 
**kwargs): 177 | xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) 178 | if verbose > 1: 179 | warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}') 180 | generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 181 | MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_} -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return (self.config.image_size // self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 
as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = 
output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! 
Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch.utils.data import Sampler 5 | 6 | from transformers import Trainer 7 | from transformers.trainer import ( 8 | is_sagemaker_mp_enabled, 9 | get_parameter_names, 10 | has_length, 11 | ALL_LAYERNORM_LAYERS, 12 | ShardedDDPOption, 13 | logger, 14 | ) 15 | from typing import List, Optional 16 | 17 | 18 | def maybe_zero_3(param, ignore_status=False, name=None): 19 | from deepspeed import zero 20 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 21 | if hasattr(param, "ds_id"): 22 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 23 | if not ignore_status: 24 | print(name, 'no ignore status') 25 | with zero.GatheredParameters([param]): 26 | param = param.data.detach().cpu().clone() 27 | else: 28 | param = param.detach().cpu().clone() 29 | return param 30 | 31 | 32 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 33 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 34 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 35 | return to_return 36 | 37 | 38 | def split_to_even_chunks(indices, lengths, num_chunks): 39 | """ 40 | Split a list of indices into `chunks` chunks of roughly equal lengths. 
41 | """ 42 | 43 | if len(indices) % num_chunks != 0: 44 | return [indices[i::num_chunks] for i in range(num_chunks)] 45 | 46 | num_indices_per_chunk = len(indices) // num_chunks 47 | 48 | chunks = [[] for _ in range(num_chunks)] 49 | chunks_lengths = [0 for _ in range(num_chunks)] 50 | for index in indices: 51 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 52 | chunks[shortest_chunk].append(index) 53 | chunks_lengths[shortest_chunk] += lengths[index] 54 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 55 | chunks_lengths[shortest_chunk] = float("inf") 56 | 57 | return chunks 58 | 59 | 60 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 61 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 62 | assert all(l != 0 for l in lengths), "Should not have zero length." 63 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 64 | # all samples are in the same modality 65 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 66 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 67 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 68 | 69 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 70 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 71 | megabatch_size = world_size * batch_size 72 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 73 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 74 | 75 | last_mm = mm_megabatches[-1] 76 | last_lang = lang_megabatches[-1] 77 | additional_batch = last_mm + last_lang 78 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 79 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 80 | megabatches = [megabatches[i] for i in megabatch_indices] 81 | 82 | if len(additional_batch) > 0: 83 | megabatches.append(sorted(additional_batch)) 84 | 85 | return [i for megabatch in megabatches for i in megabatch] 86 | 87 | 88 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 89 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 90 | indices = torch.randperm(len(lengths), generator=generator) 91 | megabatch_size = world_size * batch_size 92 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 93 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 94 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 95 | 96 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 97 | 98 | 99 | class LengthGroupedSampler(Sampler): 100 | r""" 101 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 102 | keeping a bit of randomness. 
103 | """ 104 | 105 | def __init__( 106 | self, 107 | batch_size: int, 108 | world_size: int, 109 | lengths: Optional[List[int]] = None, 110 | generator=None, 111 | group_by_modality: bool = False, 112 | ): 113 | if lengths is None: 114 | raise ValueError("Lengths must be provided.") 115 | 116 | self.batch_size = batch_size 117 | self.world_size = world_size 118 | self.lengths = lengths 119 | self.generator = generator 120 | self.group_by_modality = group_by_modality 121 | 122 | def __len__(self): 123 | return len(self.lengths) 124 | 125 | def __iter__(self): 126 | if self.group_by_modality: 127 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 128 | else: 129 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 130 | return iter(indices) 131 | 132 | 133 | class LLaVATrainer(Trainer): 134 | 135 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 136 | if self.train_dataset is None or not has_length(self.train_dataset): 137 | return None 138 | 139 | if self.args.group_by_modality_length: 140 | lengths = self.train_dataset.modality_lengths 141 | return LengthGroupedSampler( 142 | self.args.train_batch_size, 143 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 144 | lengths=lengths, 145 | group_by_modality=True, 146 | ) 147 | else: 148 | return super()._get_train_sampler() 149 | 150 | def create_optimizer(self): 151 | """ 152 | Setup the optimizer. 153 | 154 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 155 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 156 | """ 157 | if is_sagemaker_mp_enabled(): 158 | return super().create_optimizer() 159 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 160 | return super().create_optimizer() 161 | 162 | opt_model = self.model 163 | 164 | if self.optimizer is None: 165 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 166 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 167 | if self.args.mm_projector_lr is not None: 168 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] 169 | optimizer_grouped_parameters = [ 170 | { 171 | "params": [ 172 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) 173 | ], 174 | "weight_decay": self.args.weight_decay, 175 | }, 176 | { 177 | "params": [ 178 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) 179 | ], 180 | "weight_decay": 0.0, 181 | }, 182 | { 183 | "params": [ 184 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 185 | ], 186 | "weight_decay": self.args.weight_decay, 187 | "lr": self.args.mm_projector_lr, 188 | }, 189 | { 190 | "params": [ 191 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 192 | ], 193 | "weight_decay": 0.0, 194 | "lr": self.args.mm_projector_lr, 195 | }, 196 | ] 197 | else: 198 | optimizer_grouped_parameters = [ 199 | { 200 | "params": [ 201 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 202 | ], 203 | "weight_decay": self.args.weight_decay, 204 | }, 205 | { 206 | 
"params": [ 207 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 208 | ], 209 | "weight_decay": 0.0, 210 | }, 211 | ] 212 | 213 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 214 | 215 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 216 | self.optimizer = OSS( 217 | params=optimizer_grouped_parameters, 218 | optim=optimizer_cls, 219 | **optimizer_kwargs, 220 | ) 221 | else: 222 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 223 | if optimizer_cls.__name__ == "Adam8bit": 224 | import bitsandbytes 225 | 226 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 227 | 228 | skipped = 0 229 | for module in opt_model.modules(): 230 | if isinstance(module, nn.Embedding): 231 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 232 | logger.info(f"skipped {module}: {skipped/2**20}M params") 233 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 234 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 235 | logger.info(f"skipped: {skipped/2**20}M params") 236 | 237 | return self.optimizer 238 | 239 | def _save_checkpoint(self, model, trial, metrics=None): 240 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 241 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 242 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 243 | 244 | run_dir = self._get_output_dir(trial=trial) 245 | output_dir = os.path.join(run_dir, checkpoint_folder) 246 | 247 | # Only save Adapter 248 | keys_to_match = ['mm_projector', 'vision_resampler'] 249 | if getattr(self.args, "use_im_start_end", False): 250 | keys_to_match.extend(['embed_tokens', 'embed_in']) 251 | 252 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 253 | 254 | if self.args.local_rank == 0 or self.args.local_rank == -1: 255 | self.model.config.save_pretrained(output_dir) 256 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 257 | else: 258 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 259 | 260 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 261 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 262 | pass 263 | else: 264 | super(LLaVATrainer, self)._save(output_dir, state_dict) 265 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 
4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 
96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HZQ950419/Math-LLaVA/3480ba7d2305de2c6f76941b091c2925641fc751/pipeline.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.1.3" 8 | description = "Towards GPT-4 like large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.0.1", "torchvision==0.15.2", 17 | "transformers==4.31.0", "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0", 19 | "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==3.35.2", "gradio_client==0.2.9", 21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 23 | ] 24 | 25 | [project.optional-dependencies] 26 | train = ["deepspeed==0.9.5", "ninja", "wandb"] 27 | 28 | [project.urls] 29 | "Homepage" = "https://llava-vl.github.io" 30 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" 31 | 32 | [tool.setuptools.packages.find] 33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 34 | 35 | [tool.wheel] 36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 37 | -------------------------------------------------------------------------------- /scripts/extract_mm_projector.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is just a utility that I use to extract the projector for quantized models. 3 | It is NOT necessary at all to train, or run inference/serve demos. 4 | Use this script ONLY if you fully understand its implications. 
5 | """ 6 | 7 | 8 | import os 9 | import argparse 10 | import torch 11 | import json 12 | from collections import defaultdict 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights') 17 | parser.add_argument('--model-path', type=str, help='model folder') 18 | parser.add_argument('--output', type=str, help='output file') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == '__main__': 24 | args = parse_args() 25 | 26 | keys_to_match = ['mm_projector'] 27 | ckpt_to_key = defaultdict(list) 28 | try: 29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))) 30 | for k, v in model_indices['weight_map'].items(): 31 | if any(key_match in k for key_match in keys_to_match): 32 | ckpt_to_key[v].append(k) 33 | except FileNotFoundError: 34 | # Smaller models or model checkpoints saved by DeepSpeed. 35 | v = 'pytorch_model.bin' 36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys(): 37 | if any(key_match in k for key_match in keys_to_match): 38 | ckpt_to_key[v].append(k) 39 | 40 | loaded_weights = {} 41 | 42 | for ckpt_name, weight_keys in ckpt_to_key.items(): 43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu') 44 | for k in weight_keys: 45 | loaded_weights[k] = ckpt[k] 46 | 47 | torch.save(loaded_weights, args.output) 48 | -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 
12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /train_samples_all_tuning.json.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HZQ950419/Math-LLaVA/3480ba7d2305de2c6f76941b091c2925641fc751/train_samples_all_tuning.json.zip --------------------------------------------------------------------------------
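A quick illustration of the modality-aware length grouping implemented in llava/train/llava_trainer.py above: positive lengths mark multimodal (image + text) samples, negative lengths mark text-only samples, and indices are arranged so that each megabatch of world_size * batch_size samples comes from a single modality, with at most one leftover megabatch mixing the stragglers. This is a minimal sketch rather than part of the repository; it assumes the package has been installed from the pyproject.toml above (for example with pip install -e .) so that llava.train.llava_trainer is importable, and the length values are invented for the example.

import torch
from llava.train.llava_trainer import get_modality_length_grouped_indices

# Hypothetical per-sample lengths: positive = multimodal (image + text),
# negative = text-only, matching the convention used by the sampler above.
lengths = [512, 430, -300, 620, -150, 480, -275, 390]

generator = torch.Generator()
generator.manual_seed(0)

indices = get_modality_length_grouped_indices(
    lengths, batch_size=2, world_size=2, generator=generator
)

# `indices` is a permutation of range(len(lengths)); consecutive groups of
# world_size * batch_size = 4 indices share a modality, except for the final
# leftover group, which gathers whatever samples remain from both modalities.
print(indices)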