├── LICENSE
├── README.md
├── cog.yaml
├── eval_mmmu
│   ├── answer_dict_val.json
│   ├── configs
│   │   └── llava1.5.yaml
│   ├── mmmu_only_eval.py
│   ├── mmmu_response.py
│   └── utils
│       ├── data_utils.py
│       ├── eval_utils.py
│       └── model_utils.py
├── evaluation_mathvista
│   ├── build_query.py
│   ├── calculate_score.py
│   ├── ext_ans.py
│   ├── extract_answer.py
│   ├── mathvista_data
│   │   ├── annot_testmini.json
│   │   ├── query.json
│   │   └── testmini.json
│   ├── response.py
│   └── utilities.py
├── finetune_task.sh
├── llava
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── eval
│   │   ├── __pycache__
│   │   │   └── run_llava.cpython-310.pyc
│   │   └── run_llava.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── builder.py
│   │   ├── consolidate.py
│   │   ├── language_model
│   │   │   ├── llava_llama.py
│   │   │   ├── llava_mpt.py
│   │   │   └── mpt
│   │   │       ├── adapt_tokenizer.py
│   │   │       ├── attention.py
│   │   │       ├── blocks.py
│   │   │       ├── configuration_mpt.py
│   │   │       ├── custom_embedding.py
│   │   │       ├── flash_attn_triton.py
│   │   │       ├── hf_prefixlm_converter.py
│   │   │       ├── meta_init_context.py
│   │   │       ├── modeling_mpt.py
│   │   │       ├── norm.py
│   │   │       └── param_init_fns.py
│   │   ├── llava_arch.py
│   │   ├── make_delta.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   └── clip_encoder.py
│   │   ├── multimodal_projector
│   │   │   └── builder.py
│   │   └── utils.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llama_xformers_attn_monkey_patch.py
│   │   ├── llava_trainer.py
│   │   ├── train.py
│   │   ├── train_mem.py
│   │   └── train_xformers.py
│   └── utils.py
├── pipeline.png
├── pyproject.toml
├── scripts
│   ├── extract_mm_projector.py
│   ├── merge_lora_weights.py
│   ├── zero2.json
│   ├── zero3.json
│   └── zero3_offload.json
└── train_samples_all_tuning.json.zip
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Math-LLaVA
2 |
3 | This repository contains the code, data and model for the paper titled "Math-LLaVA: Bootstrapping Mathematical Reasoning for Multimodal Large Language Models".
4 |
5 | [Paper](http://arxiv.org/abs/2406.17294v2), [Image Dataset](https://huggingface.co/datasets/Zhiqiang007/MathV360K/tree/main), [Model](https://huggingface.co/Zhiqiang007/Math-LLaVA/tree/main)
6 |
7 | ![pipeline](pipeline.png)
8 |
9 | ## Latest News 🔥
10 | * [2024-06-26] We released the [Math-LLaVA checkpoints](https://huggingface.co/Zhiqiang007/Math-LLaVA/tree/main). The Math-LLaVA-13B model achieves **46.6%** on MathVista testmini, **38.3%** on MMMU, and **15.69%** on MATH-V.
11 | * [2024-06-25] Released the [paper](http://arxiv.org/abs/2406.17294v2), [code](https://github.com/HZQ950419/Math-LLaVA), and [MathV360K dataset](https://huggingface.co/datasets/Zhiqiang007/MathV360K/tree/main).
12 |
13 | ## Install Packages
14 | ```
15 | cd Math-LLaVA
16 | conda create -n math_llava python=3.10 -y
17 | conda activate math_llava
18 | pip install -e .
19 | ```
20 | ## Enable DeepSpeed and FlashAttention
21 | ```
22 | pip install -e ".[train]"
23 | pip install flash-attn --no-build-isolation
24 | ```
25 |
26 | ## Data Preparation
27 | "train_samples_all_tuning.json" corresponds to the annotations of qa pairs for finetuning.
28 | Download [image dataset](https://huggingface.co/datasets/Zhiqiang007/MathV360K/tree/main).
29 |
30 | Place the data in the repository root (or any other directory of your choice).
31 | Expected data structure:
32 | ```
33 | ├── data_images/
34 | │ ├── TabMWP/images/
35 | │ ├── IconQA/images/
36 | │ ├── ...
37 | ├── train_samples_all_tuning.json
38 | ```
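
A quick way to verify the layout before launching fine-tuning is to check that every annotated sample can find its image. The snippet below is a minimal sketch (not part of the repository), assuming the annotations are a list of LLaVA-style records with an `image` path relative to `./data_images/`; adjust the field names if your copy differs.

```
import json, os

# Minimal sanity check for the data layout described above.
# Assumption: train_samples_all_tuning.json is a list of records, each with an
# "image" path relative to ./data_images (LLaVA-style format); adjust if needed.
with open("train_samples_all_tuning.json") as f:
    samples = json.load(f)

missing = [s["image"] for s in samples
           if "image" in s and not os.path.exists(os.path.join("data_images", s["image"]))]
print(f"{len(samples)} samples, {len(missing)} missing image files")
```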
39 |
40 | ## Run Full Fine-tuning
41 | ```
42 | sh finetune_task.sh
43 | ```
44 |
45 | ## MathVista Evaluation
46 | You can download and unzip the MathVista images using the following commands:
47 | ```
48 | cd ./evaluation_mathvista/mathvista_data
49 | wget https://huggingface.co/datasets/AI4Math/MathVista/resolve/main/images.zip
50 | unzip images.zip
51 | ```
52 | Generate responses on the testmini subset:
53 | ```
54 | cd evaluation_mathvista
55 | python response.py --output_dir ./mathvista_outputs --output_file responses.json --model_path your/model/path --model_base None
56 | ```
57 | Extract the short answer text for score calculation using ChatGPT. This step requires an [OpenAI API key](https://platform.openai.com/account/api-keys).
58 | ```
59 | python extract_answer.py --output_file responses.json
60 | ```
61 | Calculate the final score:
62 | ```
63 | python calculate_score.py --output_file responses.json --score_file responses_score.json
64 | ```
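
The three steps above can also be chained in one go. The snippet below is a minimal sketch that simply wraps the documented commands with `subprocess`; run it from `evaluation_mathvista/`, replace `your/model/path` with your checkpoint, and make sure your OpenAI API key is configured before the extraction step.

```
import subprocess

# Chain the documented MathVista evaluation steps; flags mirror the commands above.
steps = [
    ["python", "response.py", "--output_dir", "./mathvista_outputs",
     "--output_file", "responses.json", "--model_path", "your/model/path",
     "--model_base", "None"],
    ["python", "extract_answer.py", "--output_file", "responses.json"],
    ["python", "calculate_score.py", "--output_file", "responses.json",
     "--score_file", "responses_score.json"],
]
for cmd in steps:
    subprocess.run(cmd, check=True)
```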
65 |
66 | ## MMMU Evaluation
67 | Generate the response:
68 | ```
69 | cd eval_mmmu
70 | python mmmu_response.py --output_path mmmu_eval_output.json --model_path your/model/path
71 | ```
72 | Calculate the score:
73 | ```
74 | python mmmu_only_eval.py --output_path mmmu_eval_output.json --answer_path ./answer_dict_val.json
75 | ```
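
Before reading the full per-category breakdown from `mmmu_only_eval.py`, you can run a quick sanity check on the multiple-choice items only. The sketch below relies on the formats used in this repository: `mmmu_response.py` writes a JSON dict mapping each `data_id` to its parsed prediction, and `answer_dict_val.json` maps each `data_id` to its `question_type` and `ground_truth`.

```
import json

# Quick exact-match check on multiple-choice items before the full scoring.
preds = json.load(open("mmmu_eval_output.json"))
answers = json.load(open("answer_dict_val.json"))

total = correct = 0
for data_id, pred in preds.items():
    gold = answers[data_id]
    if gold["question_type"] != "multiple-choice":
        continue
    gt = gold["ground_truth"]
    total += 1
    correct += int(pred in gt if isinstance(gt, list) else pred == gt)

print(f"multiple-choice exact match: {correct}/{total}")
```
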
76 | ## Results on MathVista
77 | Accuracy scores on the testmini subset:
78 |
79 | | Model | ALL | FQA |GPS |MWP |TQA |VQA |
80 | |-----------------------|--------|--------|--------|--------|--------|--------|
81 | | miniGPT4-7B |**23.1**|**18.6**|**26.0**|**13.4**|**30.4**|**30.2**|
82 | | InstructBLIP-7B |**25.3**|**23.1**|**20.7**|**18.3**|**32.3**|**35.2**|
83 | | LLaVA-13B |**26.1**|**26.8**|**29.3**|**16.1**|**32.3**|**26.3**|
84 | | SPHINX-V1-13B |**27.5**|**23.4**|**23.1**|**21.5**|**39.9**|**34.1**|
85 | | LLaVA-1.5-13B |**27.6**|**-**|**-**|**-**|**-**|**-**|
86 | | OmniLMM-12B |**34.9**|**45.0**|**17.8**|**26.9**|**44.9**|**39.1**|
87 | | Math-LLaVA-13B |**46.6**|**37.2**|**57.7**|**56.5**|**51.3**|**33.5**|
88 |
89 |
90 |
91 | ## Results on MMMU
92 | Accuracy scores on the validation set:
93 |
94 | | Model | ALL |
95 | |-----------------------|--------|
96 | | miniGPT4-7B |**26.8**|
97 | | mPLUG-Owl-7B |**32.7**|
98 | | InstructBLIP-7B |**32.9**|
99 | | SPHINX-13B |**32.9**|
100 | | LLaVA-1.5-13B |**36.4**|
101 | | Math-LLaVA-13B |**38.3**|
102 |
103 | ## Results on MATH-V
104 | We also test on [MATH-V](https://github.com/mathvision-cuhk/MATH-V), a more challenging dataset:
105 |
106 | | Model | ALL |
107 | |-----------------------|--------|
108 | | Qwen-VL-Plus |**10.72**|
109 | | LLaVA-1.5-13B |**11.12**|
110 | | ShareGPT4V-13B |**11.88**|
111 | | InternLM-XComposer2-VL|**14.54**|
112 | | Math-LLaVA-13B |**15.69**|
113 |
114 | ## Acknowledgement
115 | The project is built on top of the amazing [LLaVA](https://github.com/haotian-liu/LLaVA) repository, [MathVista](https://github.com/lupantech/MathVista) and [MMMU](https://github.com/MMMU-Benchmark/MMMU). Thanks for their contributions!
116 |
117 |
118 | If you find our code and dataset helpful to your research, please consider citing us with this BibTeX:
119 | ```bibtex
120 | @misc{shihu2024mathllava,
121 | title={Math-LLaVA: Bootstrapping Mathematical Reasoning for Multimodal Large Language Models},
122 | author={Wenhao Shi and Zhiqiang Hu and Yi Bin and Junhua Liu and Yang Yang and See-Kiong Ng and Lidong Bing and Roy Ka-Wei Lee},
123 | year={2024},
124 | eprint={2406.17294},
125 | archivePrefix={arXiv},
126 | primaryClass={cs.CL}
127 | }
128 | ```
129 |
130 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 |
7 | python_version: "3.11"
8 |
9 | python_packages:
10 | - "torch==2.0.1"
11 | - "accelerate==0.21.0"
12 | - "bitsandbytes==0.41.0"
13 | - "deepspeed==0.9.5"
14 | - "einops-exts==0.0.4"
15 | - "einops==0.6.1"
16 | - "gradio==3.35.2"
17 | - "gradio_client==0.2.9"
18 | - "httpx==0.24.0"
19 | - "markdown2==2.4.10"
20 | - "numpy==1.26.0"
21 | - "peft==0.4.0"
22 | - "scikit-learn==1.2.2"
23 | - "sentencepiece==0.1.99"
24 | - "shortuuid==1.0.11"
25 | - "timm==0.6.13"
26 | - "tokenizers==0.13.3"
27 | - "torch==2.0.1"
28 | - "torchvision==0.15.2"
29 | - "transformers==4.31.0"
30 | - "wandb==0.15.12"
31 | - "wavedrom==2.0.3.post3"
32 | - "Pygments==2.16.1"
33 | run:
34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
35 |
36 | # predict.py defines how predictions are run on your model
37 | predict: "predict.py:Predictor"
38 |
--------------------------------------------------------------------------------
/eval_mmmu/configs/llava1.5.yaml:
--------------------------------------------------------------------------------
1 | #task_instructions:
2 | #- ""
3 | #multi_choice_example_format:
4 | #- "{}
5 | #
6 | #{}
7 | #
8 | #Answer with the option's letter from the given choices directly."
9 | #
10 | #short_ans_example_format:
11 | #- "{}
12 | #
13 | #Answer the question using a single word or phrase."
14 | #temperature:
15 | #- 0
16 | task_instructions:
17 | - ""
18 | multi_choice_example_format:
19 | - "Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
20 |
21 | {}
22 |
23 | {}"
24 |
25 | short_ans_example_format:
26 | - "Hint: Please answer the question and provide the final answer at the end.
27 |
28 | {}"
29 | temperature:
30 | - 0
--------------------------------------------------------------------------------
/eval_mmmu/mmmu_only_eval.py:
--------------------------------------------------------------------------------
1 | """Parse and Evalate"""
2 | import os
3 | import json
4 |
5 | import pdb
6 | from argparse import ArgumentParser
7 | from utils.data_utils import save_json, CAT_SHORT2LONG, DOMAIN_CAT2SUB_CAT
8 | from utils.eval_utils import evaluate, parse_multi_choice_response, parse_open_response, calculate_ins_level_acc, eval_open
9 |
10 | if __name__ == '__main__':
11 |
12 | parser = ArgumentParser()
13 | parser.add_argument('--output_path', type=str, default="mmmu_eval_output.json",
14 | help="The path to model output file.")
15 | parser.add_argument('--answer_path', type=str, default="./answer_dict_val.json", help="Answer file path.")
16 | args = parser.parse_args()
17 |
18 | output_dict = json.load(open(args.output_path))
19 | answer_dict = json.load(open(args.answer_path))
20 |
21 | # group by category
22 | output_dict_w_cat = {}
23 | for data_id, parsed_pred in output_dict.items():
24 | category = "_".join(data_id.split("_")[1:-1])
25 | if category not in output_dict_w_cat:
26 | output_dict_w_cat.update({category: {}})
27 | output_dict_w_cat[category].update({data_id: parsed_pred})
28 |
29 | # group by category
30 | answer_dict_w_cat = {}
31 | for data_id, parsed_pred in answer_dict.items():
32 | category = "_".join(data_id.split("_")[1:-1])
33 | if category not in answer_dict_w_cat:
34 | answer_dict_w_cat.update({category: {}})
35 | answer_dict_w_cat[category].update({data_id: parsed_pred})
36 |
37 | evaluation_result = {}
38 |
39 | for category in CAT_SHORT2LONG.values():
40 | print("Evaluating: {}".format(category))
41 | # get cat_outputs and cat_answers
42 | try:
43 | cat_outputs = output_dict_w_cat[category]
44 | cat_answers = answer_dict_w_cat[category]
45 | except KeyError:
46 | print("Skipping {} for not found".format(category))
47 | continue
48 |
49 |         examples_to_eval = []
50 | for data_id, parsed_pred in cat_outputs.items():
51 | question_type = cat_answers[data_id]['question_type']
52 | if question_type != 'multiple-choice':
53 | parsed_pred = parse_open_response(parsed_pred)
54 | # print(parsed_pred)
55 | # print(cat_answers[data_id]['ground_truth'])
56 | correct = eval_open(cat_answers[data_id]['ground_truth'], parsed_pred)
57 | else:
58 | parsed_pred = parsed_pred
59 |
60 |
61 |             examples_to_eval.append({
62 | "id": data_id,
63 | "question_type": question_type,
64 | "answer": cat_answers[data_id]['ground_truth'],
65 | "parsed_pred": parsed_pred
66 | })
67 |
68 |         judge_dict, metric_dict = evaluate(examples_to_eval)
69 |         metric_dict.update({"num_example": len(examples_to_eval)})
70 |
71 | evaluation_result[category] = metric_dict
72 |
73 | printable_results = {}
74 | # pdb.set_trace()
75 | # add domain Subject
76 | for domain, in_domain_cats in DOMAIN_CAT2SUB_CAT.items():
77 | in_domain_cat_results = {}
78 | for cat_name in in_domain_cats: # use the order in DOMAIN_CAT2SUB_CAT
79 | if cat_name in evaluation_result.keys():
80 | in_domain_cat_results[cat_name] = evaluation_result[cat_name]
81 | else:
82 | pass
83 | in_domain_ins_acc = calculate_ins_level_acc(in_domain_cat_results)
84 | in_domain_data_num = sum([cat_results['num_example'] for cat_results in in_domain_cat_results.values()])
85 | printable_results['Overall-' + domain] = {"num": int(in_domain_data_num),
86 | "acc": round(in_domain_ins_acc, 3)
87 | }
88 | # add sub category
89 | for cat_name, cat_results in in_domain_cat_results.items():
90 | printable_results[cat_name] = {"num": int(cat_results['num_example']),
91 | "acc": round(cat_results['acc'], 3)
92 | }
93 |
94 | # table.append(["-----------------------------", "-----", "----"])
95 | all_ins_acc = calculate_ins_level_acc(evaluation_result)
96 | printable_results['Overall'] = {
97 | "num": sum([cat_results['num_example'] for cat_results in evaluation_result.values()]),
98 | "acc": round(all_ins_acc, 3)
99 | }
100 |
101 | print(printable_results)
102 |
--------------------------------------------------------------------------------
/eval_mmmu/mmmu_response.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import random
4 |
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from datasets import load_dataset, concatenate_datasets
9 | from llava.model.builder import load_pretrained_model
10 | from llava.mm_utils import get_model_name_from_path
11 |
12 | from argparse import ArgumentParser
13 |
14 | from utils.data_utils import load_yaml, construct_prompt, save_json, process_single_sample, CAT_SHORT2LONG
15 | from utils.model_utils import call_llava_engine_df, llava_image_processor
16 | from utils.eval_utils import parse_multi_choice_response, parse_open_response
17 |
18 |
19 | def run_model(args, samples, model, call_model_engine_fn=None, tokenizer=None, processor=None):
20 | out_samples = dict()
21 | with torch.no_grad():
22 | for sample in tqdm(samples):
23 | response = call_model_engine_fn(args, sample, model, tokenizer, processor)
24 | #print(response)
25 | if sample['question_type'] == 'multiple-choice':
26 | pred_ans = parse_multi_choice_response(response, sample['all_choices'], sample['index2ans'])
27 | else: # open question
28 | pred_ans = response
29 | out_samples[sample['id']] = pred_ans
30 | return out_samples
31 |
32 | def set_seed(seed_value):
33 | """
34 | Set the seed for PyTorch (both CPU and CUDA), Python, and NumPy for reproducible results.
35 |
36 | :param seed_value: An integer value to be used as the seed.
37 | """
38 | torch.manual_seed(seed_value)
39 | if torch.cuda.is_available():
40 | torch.cuda.manual_seed(seed_value)
41 | torch.cuda.manual_seed_all(seed_value) # For multi-GPU setups
42 | random.seed(seed_value)
43 | np.random.seed(seed_value)
44 | torch.backends.cudnn.deterministic = True
45 | torch.backends.cudnn.benchmark = False
46 |
47 | def main():
48 | parser = ArgumentParser()
49 | parser.add_argument('--output_path', type=str, default='mmmu_eval_output.json',
50 | help='name of saved json')
51 | parser.add_argument('--config_path', type=str, default="configs/llava1.5.yaml")
52 | parser.add_argument('--data_path', type=str, default="MMMU/MMMU") # hf dataset path.
53 | parser.add_argument('--model_path', type=str, default="")
54 | parser.add_argument('--split', type=str, default='validation')
55 | parser.add_argument('--seed', type=int, default=42)
56 |
57 | args = parser.parse_args()
58 | device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
59 | set_seed(args.seed)
60 |
61 | print('llava_initializing...')
62 | processor = None
63 | call_model_engine = call_llava_engine_df
64 | vis_process_func = llava_image_processor
65 |
66 | # load config and process to one value
67 | args.config = load_yaml(args.config_path)
68 | for key, value in args.config.items():
69 | if key != 'eval_params' and type(value) == list:
70 | assert len(value) == 1, 'key {} has more than one value'.format(key)
71 | args.config[key] = value[0]
72 |
73 | # run for each subject
74 | sub_dataset_list = []
75 | for subject in CAT_SHORT2LONG.values():
76 | sub_dataset = load_dataset(args.data_path, subject, split=args.split)
77 | sub_dataset_list.append(sub_dataset)
78 |
79 | # merge all dataset
80 | dataset = concatenate_datasets(sub_dataset_list)
81 |
82 |
83 | # load model
84 | model_name = get_model_name_from_path(args.model_path)
85 |
86 | tokenizer, model, vis_processors, _ = load_pretrained_model(model_path=args.model_path,
87 | model_base=None,
88 | model_name=model_name)
89 |
90 | samples = []
91 | for sample in dataset:
92 | sample = process_single_sample(sample)
93 | #print(sample)
94 | sample = construct_prompt(sample, args.config)
95 | if sample['image']:
96 | sample['image'] = vis_process_func(sample['image'], vis_processors).to(device)
97 | samples.append(sample)
98 |
99 |     # run inference on all samples
100 | out_samples = run_model(args, samples, model, call_model_engine, tokenizer, processor)
101 |
102 | save_json(args.output_path, out_samples)
103 | # metric_dict.update({"num_example": len(out_samples)})
104 | # save_json(save_result_path, metric_dict)
105 |
106 |
107 | if __name__ == '__main__':
108 | main()
109 |
--------------------------------------------------------------------------------
/eval_mmmu/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for data load, save, and process (e.g., prompt construction)"""
2 |
3 | import os
4 | import json
5 | import yaml
6 | import re
7 |
8 |
9 | DOMAIN_CAT2SUB_CAT = {
10 | 'Art and Design': ['Art', 'Art_Theory', 'Design', 'Music'],
11 | 'Business': ['Accounting', 'Economics', 'Finance', 'Manage','Marketing'],
12 | 'Science': ['Biology', 'Chemistry', 'Geography', 'Math', 'Physics',],
13 | 'Health and Medicine': ['Basic_Medical_Science', 'Clinical_Medicine', 'Diagnostics_and_Laboratory_Medicine', 'Pharmacy', 'Public_Health'],
14 | 'Humanities and Social Science': ['History', 'Literature', 'Sociology', 'Psychology'],
15 | 'Tech and Engineering': ['Agriculture', 'Architecture_and_Engineering', 'Computer_Science', 'Electronics', 'Energy_and_Power', 'Materials', 'Mechanical_Engineering'],
16 | }
17 |
18 |
19 | CAT_SHORT2LONG = {
20 | 'acc': 'Accounting',
21 | 'agri': 'Agriculture',
22 | 'arch': 'Architecture_and_Engineering',
23 | 'art': 'Art',
24 | 'art_theory': 'Art_Theory',
25 | 'bas_med': 'Basic_Medical_Science',
26 | 'bio': 'Biology',
27 | 'chem': 'Chemistry',
28 | 'cli_med': 'Clinical_Medicine',
29 | 'cs': 'Computer_Science',
30 | 'design': 'Design',
31 | 'diag_med': 'Diagnostics_and_Laboratory_Medicine',
32 | 'econ': 'Economics',
33 | 'elec': 'Electronics',
34 | 'ep': 'Energy_and_Power',
35 | 'fin': 'Finance',
36 | 'geo': 'Geography',
37 | 'his': 'History',
38 | 'liter': 'Literature',
39 | 'manage': 'Manage',
40 | 'mark': 'Marketing',
41 | 'mate': 'Materials',
42 | 'math': 'Math',
43 | 'mech': 'Mechanical_Engineering',
44 | 'music': 'Music',
45 | 'phar': 'Pharmacy',
46 | 'phys': 'Physics',
47 | 'psy': 'Psychology',
48 | 'pub_health': 'Public_Health',
49 | 'socio': 'Sociology'
50 | }
51 |
52 | # DATA SAVING
53 | def save_json(filename, ds):
54 | with open(filename, 'w') as f:
55 | json.dump(ds, f, indent=4)
56 |
57 |
58 | def get_multi_choice_info(options):
59 | """
60 | Given the list of options for multiple choice question
61 | Return the index2ans and all_choices
62 | """
63 |
64 | start_chr = 'A'
65 | all_choices = []
66 | index2ans = {}
67 | for i, option in enumerate(options):
68 | index2ans[chr(ord(start_chr) + i)] = option
69 | all_choices.append(chr(ord(start_chr) + i))
70 |
71 | return index2ans, all_choices
72 |
73 | def load_yaml(file_path):
74 | with open(file_path, 'r') as stream:
75 | try:
76 | yaml_dict = yaml.safe_load(stream)
77 | except yaml.YAMLError as exc:
78 | print(exc)
79 |
80 | return yaml_dict
81 |
82 |
83 | def parse_img_path(text):
84 | matches = re.findall("
", text)
85 | return matches
86 |
87 | def process_single_sample(data):
88 | question = data['question']
89 | o_imgs_paths = []
90 | for option in data['options']:
91 | current_o_imgs_paths = parse_img_path(option)
92 | for img_path in current_o_imgs_paths:
93 | o_imgs_paths.append(img_path)
94 |
95 | if len(o_imgs_paths) > 1: # multiple images in options, used for random selection
96 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
97 | 'image': None, 'question_type': data['question_type']}
98 | else:
99 | return {'id': data['id'], 'question': question, 'options': data['options'], 'answer': data['answer'],
100 | 'image': data['image_1'], 'question_type': data['question_type']}
101 |
102 |
103 | # DATA SAVING
104 | def save_json(filename, ds):
105 | with open(filename, 'w') as f:
106 | json.dump(ds, f, indent=4)
107 |
108 | def save_jsonl(filename, data):
109 | """
110 | Save a dictionary of data to a JSON Lines file with the filename as key and caption as value.
111 |
112 | Args:
113 | filename (str): The path to the file where the data should be saved.
114 | data (dict): The dictionary containing the data to save where key is the image path and value is the caption.
115 | """
116 | with open(filename, 'w', encoding='utf-8') as f:
117 | for img_path, caption in data.items():
118 | # Extract the base filename without the extension
119 | base_filename = os.path.basename(img_path)
120 | # Create a JSON object with the filename as the key and caption as the value
121 | json_record = json.dumps({base_filename: caption}, ensure_ascii=False)
122 | # Write the JSON object to the file, one per line
123 | f.write(json_record + '\n')
124 |
125 | def save_args(args, path_dir):
126 | argsDict = args.__dict__
127 | with open(path_dir + 'setting.txt', 'w') as f:
128 | f.writelines('------------------ start ------------------' + '\n')
129 | for eachArg, value in argsDict.items():
130 | f.writelines(eachArg + ' : ' + str(value) + '\n')
131 | f.writelines('------------------- end -------------------')
132 |
133 |
134 |
135 | # DATA PROCESSING
136 | def construct_prompt(sample, config):
137 | question = sample['question']
138 | options = eval(sample['options'])
139 | example = ""
140 | if sample['question_type'] == 'multiple-choice':
141 | start_chr = 'A'
142 | prediction_range = []
143 | index2ans = {}
144 | for option in options:
145 | prediction_range.append(start_chr)
146 | example += f"({start_chr}) {option}\n"
147 | index2ans[start_chr] = option
148 | start_chr = chr(ord(start_chr) + 1)
149 | empty_prompt_sample_structure = config['multi_choice_example_format']
150 | empty_prompt = empty_prompt_sample_structure.format(question, example)
151 | res_dict = {}
152 | res_dict['index2ans'] = index2ans
153 | res_dict['correct_choice'] = sample['answer']
154 | res_dict['all_choices'] = prediction_range
155 | res_dict['empty_prompt'] = empty_prompt
156 | if config['task_instructions']:
157 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
158 | else:
159 | res_dict['final_input_prompt'] = empty_prompt
160 |
161 | res_dict['gt_content'] = options[ord(sample['answer'].upper()) - ord('A')]
162 | else:
163 | empty_prompt_sample_structure = config['short_ans_example_format']
164 | empty_prompt = empty_prompt_sample_structure.format(question)
165 | res_dict = {}
166 | res_dict['empty_prompt'] = empty_prompt
167 | if config['task_instructions']:
168 | res_dict['final_input_prompt'] = config['task_instructions'].strip() + '\n\n' + empty_prompt
169 | else:
170 | res_dict['final_input_prompt'] = empty_prompt
171 | res_dict['gt_content'] = sample['answer']
172 |
173 | res_dict.update(sample)
174 | return res_dict
--------------------------------------------------------------------------------
/eval_mmmu/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | """Response Parsing and Evaluation for various models"""
2 | from typing import Dict
3 |
4 | import re
5 | import random
6 | random.seed(42)
7 | import numpy as np
8 |
9 | # ----------- Process Multi-choice -------------
10 | def parse_multi_choice_response(response, all_choices, index2ans):
11 | """
12 | Parse the prediction from the generated response.
13 | Return the predicted index e.g., A, B, C, D.
14 | """
15 | for char in [',', '.', '!', '?', ';', ':', "'"]:
16 | response = response.strip(char)
17 | response = " " + response + " " # add space to avoid partial match
18 |
19 | index_ans = True
20 | ans_with_brack = False
21 | candidates = []
22 | for choice in all_choices: # e.g., (A) (B) (C) (D)
23 | if f'({choice})' in response:
24 | candidates.append(choice)
25 | ans_with_brack = True
26 |
27 | if len(candidates) == 0:
28 | for choice in all_choices: # e.g., A B C D
29 | if f' {choice} ' in response:
30 | candidates.append(choice)
31 |
32 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example
33 | if len(candidates) == 0 and len(response.split()) > 5:
34 | for index, ans in index2ans.items():
35 | if ans.lower() in response.lower():
36 | candidates.append(index)
37 | index_ans = False # it's content ans.
38 |
39 | if len(candidates) == 0: # still not get answer, randomly choose one.
40 | pred_index = random.choice(all_choices)
41 | elif len(candidates) > 1:
42 | start_indexes = []
43 | if index_ans:
44 | if ans_with_brack:
45 | for can in candidates:
46 | index = response.rfind(f'({can})')
47 | start_indexes.append(index) # -1 will be ignored anyway
48 | # start_indexes = [generated_response.index(f'({can})') for can in candidates]
49 | else:
50 | for can in candidates:
51 | index = response.rfind(f" {can} ")
52 | start_indexes.append(index)
53 | else:
54 | for can in candidates:
55 | index = response.lower().rfind(index2ans[can].lower())
56 | start_indexes.append(index)
57 | # get the last one
58 | pred_index = candidates[np.argmax(start_indexes)]
59 | else: # if only one candidate, use it.
60 | pred_index = candidates[0]
61 |
62 | return pred_index
63 |
64 | # ----------- Process Open -------------
65 | def check_is_number(string):
66 | """
67 |     Check if the given string is a number.
68 | """
69 | try:
70 | float(string.replace(',', ''))
71 | return True
72 | except ValueError:
73 | # check if there's comma inside
74 | return False
75 |
76 | def normalize_str(string):
77 | """
78 | Normalize the str to lower case and make them float numbers if possible.
79 | """
80 | # check if characters in the string
81 |
82 | # if number, numerize it.
83 | string = string.strip()
84 |
85 | is_number = check_is_number(string)
86 |
87 | if is_number:
88 | string = string.replace(',', '')
89 | string = float(string)
90 | # leave 2 decimal
91 | string = round(string, 2)
92 | return [string]
93 | else: # it's likely to be a string
94 | # lower it
95 | string = string.lower()
96 | if len(string) == 1:
97 | return [" " + string, string + " "] # avoid trivial matches
98 | return [string]
99 |
100 | def extract_numbers(string):
101 | """
102 |     Extract all forms of numbers from a string with regex.
103 | """
104 | # Pattern for numbers with commas
105 | pattern_commas = r'-?\b\d{1,3}(?:,\d{3})+\b'
106 | # Pattern for scientific notation
107 | pattern_scientific = r'-?\d+(?:\.\d+)?[eE][+-]?\d+'
108 | # Pattern for simple numbers without commas
109 | pattern_simple = r'-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])'
110 |
111 | # Extract numbers with commas
112 | numbers_with_commas = re.findall(pattern_commas, string)
113 | # Extract numbers in scientific notation
114 | numbers_scientific = re.findall(pattern_scientific, string)
115 | # Extract simple numbers without commas
116 | numbers_simple = re.findall(pattern_simple, string)
117 |
118 | # Combine all extracted numbers
119 | all_numbers = numbers_with_commas + numbers_scientific + numbers_simple
120 | return all_numbers
121 |
122 | def parse_open_response(response):
123 | """
124 | Parse the prediction from the generated response.
125 | Return a list of predicted strings or numbers.
126 | """
127 | # content = content.strip("\n").strip(".").strip(" ")
128 | def get_key_subresponses(response):
129 | key_responses = []
130 | response = response.strip().strip(".").lower()
131 | sub_responses = re.split(r'\.\s(?=[A-Z])|\n', response)
132 | indicators_of_keys = ['could be ', 'so ', 'is ',
133 | 'thus ', 'therefore ', 'final ', 'answer ', 'result ']
134 | key_responses = []
135 | for index, resp in enumerate(sub_responses):
136 | # if last one, accept it's an equation (the entire response can be just one sentence with equation)
137 | if index == len(sub_responses) - 1:
138 | indicators_of_keys.extend(['='])
139 | shortest_key_response = None # the shortest response that may contain the answer (tail part of the response)
140 | for indicator in indicators_of_keys:
141 | if indicator in resp:
142 | if not shortest_key_response:
143 | shortest_key_response = resp.split(indicator)[-1].strip()
144 | else:
145 | if len(resp.split(indicator)[-1].strip()) < len(shortest_key_response):
146 | shortest_key_response = resp.split(indicator)[-1].strip()
147 | # key_responses.append(resp.split(indicator)[1].strip())
148 |
149 | if shortest_key_response:
150 | # and it's not trivial
151 | if shortest_key_response.strip() not in [":", ",", ".", "!", "?", ";", ":", "'"]:
152 | key_responses.append(shortest_key_response)
153 |         if len(key_responses) == 0: # did not find any
154 | return [response]
155 | return key_responses
156 | # pdb.set_trace()
157 | key_responses = get_key_subresponses(response)
158 |
159 | pred_list = key_responses.copy() # keep the original string response
160 | for resp in key_responses:
161 | pred_list.extend(extract_numbers(resp))
162 |
163 | tmp_pred_list = []
164 | for i in range(len(pred_list)):
165 | tmp_pred_list.extend(normalize_str(pred_list[i]))
166 | pred_list = tmp_pred_list
167 |
168 | # remove duplicates
169 | pred_list = list(set(pred_list))
170 |
171 | return pred_list
172 |
173 | # ----------- Evaluation -------------
174 |
175 | def eval_multi_choice(gold_i, pred_i):
176 | """
177 | Evaluate a multiple choice instance.
178 | """
179 | correct = False
180 | # only they are exactly the same, we consider it as correct
181 | if isinstance(gold_i, list):
182 | for answer in gold_i:
183 | if answer == pred_i:
184 | correct = True
185 | break
186 | else: # gold_i is a string
187 | if gold_i == pred_i:
188 | correct = True
189 | return correct
190 |
191 | def eval_open(gold_i, pred_i):
192 | """
193 | Evaluate an open question instance
194 | """
195 | correct = False
196 | if isinstance(gold_i, list):
197 | # use float to avoid trivial matches
198 | norm_answers = []
199 | for answer in gold_i:
200 | norm_answers.extend(normalize_str(answer))
201 | else:
202 | norm_answers = normalize_str(gold_i)
203 | for pred in pred_i: # pred is already normalized in parse response phase
204 | if isinstance(pred, str): # if it's a string, then find if ans in the pred_i
205 | for norm_ans in norm_answers:
206 | # only see if the string answer in the string pred
207 | if isinstance(norm_ans, str) and norm_ans in pred:
208 | if not correct:
209 | correct = True
210 | break
211 | else: # it's a float number
212 | if pred in norm_answers:
213 | if not correct:
214 | correct = True
215 | break
216 | return correct
217 |
218 | # ----------- Batch Evaluation -------------
219 | def evaluate(samples):
220 | """
221 | Batch evaluation for multiple choice and open questions.
222 | """
223 | pred_correct = 0
224 | judge_dict = dict()
225 | for sample in samples:
226 | gold_i = sample['answer']
227 | pred_i = sample['parsed_pred']
228 | if sample['question_type'] == 'multiple-choice':
229 | correct = eval_multi_choice(gold_i, pred_i)
230 | else: # open question
231 | correct = eval_open(gold_i, pred_i)
232 |
233 | if correct:
234 | judge_dict[sample['id']] = 'Correct'
235 | pred_correct += 1
236 | else:
237 | judge_dict[sample['id']] = 'Wrong'
238 |
239 | if len(samples) == 0:
240 |         return judge_dict, {'acc': 0}
241 | return judge_dict, {'acc': pred_correct / len(samples)}
242 |
243 |
244 |
245 | # ----------- Calculate Accuracy -------------
246 | def calculate_ins_level_acc(results: Dict):
247 | """Calculate the instruction level accuracy for given Subject results"""
248 | acc = 0
249 | ins_num = 0
250 | for cat_results in results.values():
251 | acc += cat_results['acc'] * cat_results['num_example']
252 | ins_num += cat_results['num_example']
253 | if ins_num == 0:
254 | return 0
255 | return acc / ins_num
256 |
257 |
--------------------------------------------------------------------------------
/eval_mmmu/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 | def call_llava_engine_df(args, sample, model, tokenizer=None, processor=None):
5 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
6 | from llava.conversation import conv_templates, SeparatorStyle
7 |
8 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
9 |         prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
10 |
11 | def insert_separator(X, sep):
12 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
13 |
14 | input_ids = []
15 | offset = 0
16 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
17 | offset = 1
18 | input_ids.append(prompt_chunks[0][0])
19 |
20 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
21 | input_ids.extend(x[offset:])
22 |
23 | if return_tensors is not None:
24 | if return_tensors == 'pt':
25 | return torch.tensor(input_ids, dtype=torch.long)
26 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
27 | return input_ids
28 |
29 | def deal_with_prompt(input_text, mm_use_im_start_end):
30 | qs = input_text
31 | if mm_use_im_start_end:
32 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
33 | else:
34 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
35 | return qs
36 |
37 | prompt = sample['final_input_prompt']
38 | prompt = deal_with_prompt(prompt, model.config.mm_use_im_start_end)
39 | conv = conv_templates['vicuna_v1'].copy()
40 | conv.append_message(conv.roles[0], prompt)
41 | conv.append_message(conv.roles[1], None)
42 | prompt = conv.get_prompt()
43 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
44 | image = sample['image']
45 | if image is not None:
46 | output_ids = model.generate(
47 | input_ids,
48 | images=image.unsqueeze(0).half().cuda(),
49 | do_sample=True,
50 | temperature=1,
51 | top_p=None,
52 | num_beams=5,
53 | max_new_tokens=128,
54 | use_cache=True)
55 |
56 | input_token_len = input_ids.shape[1]
57 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
58 | if n_diff_input_output > 0:
59 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
60 | response = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
61 | else: # multiple images actually
62 | if sample['question_type'] == 'multiple-choice':
63 | all_choices = sample['all_choices']
64 | response = random.choice(all_choices)
65 | else:
66 | response = 'INVALID GENERATION FOR MULTIPLE IMAGE INPUTS'
67 |
68 | return response
69 |
70 |
71 | def llava_image_processor(raw_image, vis_processors=None):
72 | image_tensor = vis_processors.preprocess(raw_image, return_tensors='pt')['pixel_values'][0]
73 | return image_tensor
74 |
--------------------------------------------------------------------------------
/evaluation_mathvista/build_query.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # pids: 799, 681, 615
4 | shot_examples = [
5 | {
6 | "question": "How much money does Ruth need to buy a baking dish, a casserole dish, and an ice cream scoop? (Unit: $)",
7 | "caption": "The image shows a table with a variety of items on it, including a baking dish, ice cream scoop, casserole dish, and rolling pin. The text in the image says:\n\n```\nbaking dish\n$4.00\nice cream scoop\n$6.00\ncasserole dish\n$3.00\nrolling pin\n$4.00\n```",
8 | "ocr": "[([5, 3], 'baking dish'), ([177, 5], '$4.00'), ([7, 41], 'ice cream scoop'), ([177, 37], '$6.00'), ([9, 69], 'casserole dish'), ([177, 69], '$3.00'), ([5, 98], 'rolling pin'), ([177, 101], '$4.00')]",
9 | "solution": """
10 | Find the total cost of a baking dish, a casserole dish, and an ice cream scoop.\n\n$4.00 + $3.00 + $6.00 = $13.00\n\nRuth needs $13.00.
11 | """,
12 | "code": """
13 | baking_dish_price = 4.00
14 | casserole_dish_price = 3.00
15 | ice_cream_scoop_price = 6.00
16 |
17 | ans = baking_dish_price + casserole_dish_price + ice_cream_scoop_price
18 | print(ans)
19 | """
20 | },
21 |
22 | {
23 | "question": "What is the largest city in the nation where this plane is headquartered?",
24 | "choices": ['hong kong', 'osaka', 'shanghai', 'tokyo'],
25 | "caption": "The image shows a large passenger jet parked on a tarmac at an airport. The jet is white with red trim and has a red tail. It is sitting on top of a tarmac next to a building. The jet is being loaded with passengers and cargo. The text on the image says \"Japan. Endless Discovery\".",
26 | "solution": """
27 | The caption mentions that the text on the image says "Japan. Endless Discovery". This indicates that the plane is headquartered in Japan.
28 |
29 | Among the Japanese cities, Tokyo is the largest city.
30 |
31 | Thus, the answer is D (tokyo).
32 | """,
33 | "code": """
34 | def largest_city(caption, choices):
35 | countries_largest_cities = {
36 | 'Japan': 'tokyo',
37 | 'China': 'shanghai'
38 | }
39 |
40 | if "Japan" in caption:
41 | country = 'Japan'
42 | elif "China" in caption:
43 | country = 'China'
44 |
45 | for choice in choices:
46 | if choice == countries_largest_cities[country]:
47 | return choice
48 | return ""
49 |
50 | choices = ['hong kong', 'osaka', 'shanghai', 'tokyo']
51 | caption = "The image shows a large passenger jet parked on a tarmac at an airport. The jet is white with red trim and has a red tail. It is sitting on top of a tarmac next to a building. The jet is being loaded with passengers and cargo. The text on the image says 'Japan. Endless Discovery'."
52 |
53 | print(largest_city(caption, choices))
54 | """
55 | },
56 |
57 | {
58 | "question": "If two sides of a triangle measure 12 and 7, which of the following cannot be the perimeter of the triangle?",
59 | "choices": ['29', '34', '37', '38'],
60 | "caption": "The image shows a triangle with two sides labeled 7 and 12. The triangle is drawn on a white background. There is no text other than the labels.",
61 | "ocr": "[([70, 74], '7'), ([324, 74], '12')]",
62 | "solution": """
63 | To determine which of the given perimeters cannot be possible for the triangle, we apply the triangle inequality theorem. The sum of any two sides of a triangle must be greater than the third side.
64 |
65 | For the maximum possible value of the third side:
66 | 12 + 7 = 19
67 |
68 | The minimum possible value for the third side:
69 | 12 - 7 = 5
70 |
71 | The third side for each option:
72 | (A) 29 - 12 - 7 = 10 (valid)
73 | (B) 34 - 12 - 7 = 15 (valid)
74 | (C) 37 - 12 - 7 = 18 (valid)
75 | (D) 38 - 12 - 7 = 19 (invalid because it should be less than 19)
76 |
77 | Thus, the answer is D.
78 | """,
79 | "code": """
80 | def is_valid_triangle(a, b, perimeter):
81 | # Given a and b, find the third side
82 | third_side = perimeter - a - b
83 |
84 | # Check triangle inequality
85 | if (a + b > third_side) and (a + third_side > b) and (b + third_side > a):
86 | return True
87 | return False
88 |
89 | # Given sides
90 | a = 12
91 | b = 7
92 |
93 | # Given perimeters
94 | perimeters = [29, 34, 37, 38]
95 |
96 | # Check which perimeter is not valid
97 | for p in perimeters:
98 | if not is_valid_triangle(a, b, p):
99 | print(p)
100 | """,
101 |
102 | }
103 | ]
104 |
105 |
106 | def refine_caption(caption):
107 | if isinstance(caption, str):
108 | nonsense = ["Sure. ",
109 | "Sure, I can do that.",
110 | "Sorry, I can't help with images of people yet.",
111 | "I can't process this file.",
112 | "I'm unable to help you with that, as I'm only a language model and don't have the necessary information or abilities.",
113 | "I'm not programmed to assist with that.",
114 | "Please let me know if you have any other questions.",
115 | "I hope this is helpful!",
116 | "I hope this helps!"]
117 | for non in nonsense:
118 | caption = caption.replace(non, "").strip()
119 | caption = caption.replace(" ", " ").strip()
120 | else:
121 | caption = ""
122 | return caption
123 |
124 |
125 | def refine_ocr(ocr):
126 | """
127 | [ (
128 | [161, 39], [766, 39], [766, 120], [161, 120]],
129 | 'The spring force does',
130 | 0.912845069753024
131 | ),
132 | ]
133 | -->
134 | [ (
135 | [161, 39],
136 | 'The spring force does',
137 | ),
138 | ]
139 | """
140 | try:
141 | ocr = eval(ocr)
142 | if len(ocr) > 0:
143 | ocr = [([int(e[0][0][0]), int(e[0][0][1])], e[1]) for e in ocr]
144 | ocr = str(ocr)
145 | else:
146 | ocr = ""
147 | except:
148 | ocr = ""
149 | return ocr
150 |
151 |
152 | def create_one_query(problem, examples, shot_num, shot_type, use_caption, use_ocr):
153 |
154 |
155 | ### [1] Demo prompt
156 | if shot_num == 0:
157 | demo_prompt = ""
158 | else:
159 | demos = []
160 | shot_num = min(shot_num, len(examples))
161 | for example in examples[:shot_num]:
162 | prompt = ""
163 |
164 | # question
165 | prompt += f"Question: {example['question']}"
166 |
167 | # choices
168 | if "choices" in example:
169 | texts = ["Choices:"]
170 | for i, choice in enumerate(example['choices']):
171 | texts.append(f"({chr(ord('A')+i)}) {choice}")
172 | prompt += "\n" + "\n".join(texts)
173 |
174 | # caption
175 | if use_caption:
176 | caption = example['caption'] if 'caption' in example else ""
177 | if caption != "":
178 | prompt += "\n" + f"Image description: {caption}"
179 |
180 | # ocr
181 | if use_ocr:
182 | ocr = example['ocr'] if 'ocr' in example else ""
183 | if ocr != "":
184 | prompt += "\n" + f"Image detected text: {ocr}"
185 |
186 | # solution
187 | if shot_type == 'solution':
188 | solution = example['solution'].strip()
189 | prompt += "\n" + f"Solution: {solution}"
190 |
191 | # code
192 | if shot_type == 'code':
193 | code = example['code'].strip()
194 | prompt += "\n" + f"Python code: {code}"
195 |
196 | demos.append(prompt)
197 |
198 | demo_prompt = "\n\n".join(demos)
199 |
200 | ### [2] Test query
201 | # problem info
202 | question = problem['question']
203 | unit = problem['unit']
204 | choices = problem['choices']
205 | caption = problem['caption']
206 | ocr = problem['ocr']
207 | precision = problem['precision']
208 | question_type = problem['question_type']
209 | answer_type = problem['answer_type']
210 |
211 | # hint
212 | if shot_type == 'solution':
213 | if question_type == "multi_choice":
214 | assert answer_type == "text"
215 | hint_text = f"Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end."
216 | else:
217 | assert answer_type in ["integer", "float", "list"]
218 | if answer_type == "integer":
219 | hint_text = f"Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end."
220 |
221 | elif answer_type == "float" and precision == 1:
222 | hint_text = f"Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end."
223 |
224 | elif answer_type == "float" and precision == 2:
225 | hint_text = f"Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end."
226 |
227 | elif answer_type == "list":
228 | hint_text = f"Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end."
229 | else:
230 | assert shot_type == 'code'
231 | hint_text = "Hint: Please generate a python code to solve the problem"
232 |
233 | # question
234 | question_text = f"Question: {question}"
235 | if unit:
236 | question_text += f" (Unit: {unit})"
237 |
238 | # choices
239 | if choices:
240 | # choices: (A) 1.2 (B) 1.3 (C) 1.4 (D) 1.5
241 | texts = ["Choices:"]
242 | for i, choice in enumerate(choices):
243 | texts.append(f"({chr(ord('A')+i)}) {choice}")
244 | choices_text = "\n".join(texts)
245 | else:
246 | choices_text = ""
247 |
248 | # caption
249 | caption_text = ""
250 | if use_caption and caption != "":
251 | caption_text = f"Image description: {caption}"
252 |
253 | # ocr
254 | ocr_text = ""
255 | if use_ocr and ocr != "":
256 | ocr_text = f"Image detected text: {ocr}"
257 |
258 | # prompt
259 | if shot_type == 'solution':
260 | prompt = "Solution: "
261 | else:
262 | assert shot_type == 'code'
263 | prompt = "Python code: "
264 |
265 | elements = [question_text, choices_text, caption_text, ocr_text, hint_text, prompt]
266 | test_query = "\n".join([e for e in elements if e != ""])
267 |
268 | ### [3] Final query
269 | query = demo_prompt + "\n\n" + test_query
270 | query = query.strip()
271 | return query
272 | #2': 'Question: what is the total volume of the measuring cup? (Unit: g)\n
273 | # Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.\n
274 | # Solution:',
275 | # 3': 'Question: △ABC的两内角平分线OB、OC相交于点O,若∠A=110°,则∠BOC=()\nChoices:\n(A) 135°\n(B) 140°\n(C) 145°\n(D) 150°\n
276 | # Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.\n
277 | # Solution:',
278 |
279 | def create_query_data(data, caption_data, ocr_data, args):
280 | query_data = {}
281 |
282 | for pid, problem in data.items():
283 | if pid in caption_data:
284 | caption = caption_data[pid]
285 | caption = refine_caption(caption)
286 | problem['caption'] = caption
287 | else:
288 | problem['caption'] = ""
289 |
290 | if pid in ocr_data:
291 | ocr = ocr_data[pid]
292 | ocr = refine_ocr(ocr)
293 | problem['ocr'] = ocr
294 | else:
295 |             problem['ocr'] = ""  # keep the "no OCR" default consistent with the caption branch and the ocr != "" check in create_one_query
296 |
297 | query = create_one_query(
298 | problem = problem,
299 | examples = shot_examples,
300 | shot_num = args.shot_num,
301 | shot_type = args.shot_type,
302 | use_caption = args.use_caption,
303 | use_ocr = args.use_ocr
304 | )
305 | query_data[pid] = query
306 |
307 | return query_data
308 |
309 |
310 | if __name__ == "__main__":
311 | for example in shot_examples:
312 | print("----------------------------------------")
313 | print("\nQuestion:", example['question'])
314 | if "choices" in example:
315 | print("\nChoices:", example['choices'])
316 | print("\nCaption:", example['caption'])
317 | if "ocr" in example:
318 | print("\nOCR:", example['ocr'])
319 | print("\nSolution:", example['solution'])
320 | print("\nCode:", example['code'])
321 |
322 | # execute code
323 | exec(example['code'])
324 |
325 |
326 |
--------------------------------------------------------------------------------
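Note: a minimal usage sketch for build_query.py (not part of the repository). It assumes the module, and the shot_examples list it defines earlier in the file, can be imported from the evaluation_mathvista directory; the pid and field values below are made up for illustration.

    from argparse import Namespace
    from build_query import create_query_data

    # A toy problem record with the fields create_one_query reads.
    data = {
        "1": {
            "question": "What is the total volume of the measuring cup?",
            "unit": "g",
            "choices": None,
            "precision": None,
            "question_type": "free_form",
            "answer_type": "integer",
        }
    }
    # The same flags that response.py forwards from its argparse namespace.
    args = Namespace(shot_num=0, shot_type="solution", use_caption=False, use_ocr=False)
    query_data = create_query_data(data, caption_data={}, ocr_data={}, args=args)
    print(query_data["1"])  # "Question: ... (Unit: g)\nHint: ...\nSolution: "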
/evaluation_mathvista/calculate_score.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import argparse
4 | import pandas as pd
5 |
6 | # !pip install python-Levenshtein
7 | from Levenshtein import distance
8 |
9 | import sys
10 |
11 | sys.path.append('../')
12 | from utilities import *
13 |
14 |
15 | def get_most_similar(prediction, choices):
16 | """
17 | Use the Levenshtein distance (or edit distance) to determine which of the choices is most similar to the given prediction
18 | """
19 | distances = [distance(prediction, choice) for choice in choices]
20 | ind = distances.index(min(distances))
21 | return choices[ind]
22 | # return min(choices, key=lambda choice: distance(prediction, choice))
23 |
24 |
25 | def normalize_extracted_answer(extraction, choices, question_type, answer_type, precision):
26 | """
27 | Normalize the extracted answer to match the answer type
28 | """
29 | if question_type == 'multi_choice':
30 | # make sure the extraction is a string
31 | if isinstance(extraction, str):
32 | extraction = extraction.strip()
33 | else:
34 | try:
35 | extraction = str(extraction)
36 | except:
37 | extraction = ""
38 |
39 | # extract "A" from "(A) text"
40 | letter = re.findall(r'\(([a-zA-Z])\)', extraction)
41 | if len(letter) > 0:
42 | extraction = letter[0].upper()
43 |
44 | options = [chr(ord('A') + i) for i in range(len(choices))]
45 |
46 | if extraction in options:
47 | # convert option letter to text, e.g. "A" -> "text"
48 | ind = options.index(extraction)
49 | extraction = choices[ind]
50 | else:
51 | # select the most similar option
52 | extraction = get_most_similar(extraction, choices)
53 | assert extraction in choices
54 |
55 | elif answer_type == 'integer':
56 | try:
57 | extraction = str(int(float(extraction)))
58 | except:
59 | extraction = None
60 |
61 | elif answer_type == 'float':
62 | try:
63 | extraction = str(round(float(extraction), precision))
64 | except:
65 | extraction = None
66 |
67 | elif answer_type == 'list':
68 | try:
69 | extraction = str(extraction)
70 | except:
71 | extraction = None
72 |
73 | return extraction
74 |
75 |
76 | def safe_equal(prediction, answer):
77 | """
78 | Check if the prediction is equal to the answer, even if they are of different types
79 | """
80 | try:
81 | if prediction == answer:
82 | return True
83 | return False
84 | except Exception as e:
85 | print(e)
86 | return False
87 |
88 |
89 | def get_acc_with_condition(res_pd, key, value):
90 |     if key == 'skills':
91 |         # 'skills' holds a list per problem, so filter by membership
92 |         total_pd = res_pd[res_pd[key].apply(lambda x: value in x)]
93 |     else:
94 |         total_pd = res_pd[res_pd[key] == value]
95 |     if len(total_pd) == 0: return 0, 0, "0.00"  # guard against division by zero for unseen values
96 |     correct_pd = total_pd[total_pd['true_false'] == True]
97 |     acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100)
98 |     return len(correct_pd), len(total_pd), acc
99 |
100 |
101 | if __name__ == '__main__':
102 |
103 |
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument('--output_dir', type=str, default='./mathvista_outputs')
106 | parser.add_argument('--output_file', type=str, default='responses.json')
107 | parser.add_argument('--score_file', type=str, default='responses_score.json')
108 | parser.add_argument('--gt_file', type=str, default='./mathvista_data/testmini.json',
109 | help='ground truth file')
110 | parser.add_argument('--number', type=int, default=-1, help='number of problems to run')
111 | parser.add_argument('--rerun', default=True, help='rerun the evaluation')
112 |     parser.add_argument('--calculate_gain', default=False, help='calculate the score gains over random guess')
113 | parser.add_argument('--random_file', type=str, default='score_random_guess.json')
114 | args = parser.parse_args()
115 |
116 | # args
117 | output_file = os.path.join(args.output_dir, args.output_file)
118 |
119 |
120 |
121 | # read json
122 | print(f"Reading {output_file}...")
123 | results = read_json(output_file)
124 |
125 | # read ground truth
126 | print(f"Reading {args.gt_file}...")
127 | gts = read_json(args.gt_file)
128 |
129 | # full pids
130 | full_pids = list(results.keys())
131 | if args.number > 0:
132 | full_pids = full_pids[:min(args.number, len(full_pids))]
133 | print("Number of testing problems:", len(full_pids))
134 |
135 | ## [1] Evaluate if the prediction is true or false
136 | print("\nEvaluating the predictions...")
137 | update_json_flag = False
138 | for pid in full_pids:
139 | problem = results[pid]
140 | # print(problem)
141 |
142 | if args.rerun:
143 | if 'prediction' in problem:
144 | del problem['prediction']
145 | if 'true_false' in problem:
146 | del problem['true_false']
147 |
148 | choices = problem['choices']
149 | question_type = problem['question_type']
150 | answer_type = problem['answer_type']
151 | precision = problem['precision']
152 | extraction = problem['extraction']
153 |
154 | if 'answer' in problem:
155 | answer = problem['answer']
156 | else:
157 | answer = gts[pid]['answer']
158 | problem['answer'] = answer
159 |
160 | # normalize the extracted answer to match the answer type
161 | prediction = normalize_extracted_answer(extraction, choices, question_type, answer_type, precision)
162 |
163 | # verify the prediction is true or false
164 | true_false = safe_equal(prediction, answer)
165 |
166 | # update the problem
167 | if "true_false" not in problem:
168 | update_json_flag = True
169 |
170 | elif true_false != problem['true_false']:
171 | update_json_flag = True
172 |
173 | if "prediction" not in problem:
174 | update_json_flag = True
175 |
176 | elif prediction != problem['prediction']:
177 | update_json_flag = True
178 |
179 | problem['prediction'] = prediction
180 | problem['true_false'] = true_false
181 |
182 | # save the updated json
183 | if update_json_flag:
184 | print("\n!!!Some problems are updated.!!!")
185 | print(f"\nSaving {output_file}...")
186 | save_json(results, output_file)
187 |
188 | ## [2] Calculate the average accuracy
189 | total = len(full_pids)
190 | correct = 0
191 | for pid in full_pids:
192 | if results[pid]['true_false']:
193 | correct += 1
194 | accuracy = str(round(correct / total * 100, 2))
195 | print(f"\nCorrect: {correct}, Total: {total}, Accuracy: {accuracy}%")
196 |
197 | scores = {"average": {"accuracy": accuracy, "correct": correct, "total": total}}
198 |
199 | ## [3] Calculate the fine-grained accuracy scores
200 |
201 | # merge the 'metadata' attribute into the data
202 | for pid in results:
203 | results[pid].update(results[pid].pop('metadata'))
204 |
205 | # convert the data to a pandas DataFrame
206 | df = pd.DataFrame(results).T
207 |
208 | print(len(df))
209 | print("Number of test problems:", len(df))
210 | # assert len(df) == 1000 # Important!!!
211 |
212 |     # assign the target keys for evaluation
213 | target_keys = ['question_type', 'answer_type', 'language', 'source', 'category', 'task', 'context', 'grade',
214 | 'skills']
215 |
216 | for key in target_keys:
217 | print(f"\nType: [{key}]")
218 | # get the unique values of the key
219 | if key == 'skills':
220 | # the value is a list
221 | values = []
222 | for i in range(len(df)):
223 | values += df[key][i]
224 | values = list(set(values))
225 | else:
226 | values = df[key].unique()
227 | # print(values)
228 |
229 | # calculate the accuracy for each value
230 | scores[key] = {}
231 | for value in values:
232 |             correct, total, acc = get_acc_with_condition(df, key, value)
233 | if total > 0:
234 | print(f"[{value}]: {acc}% ({correct}/{total})")
235 | scores[key][value] = {"accuracy": acc, "correct": correct, "total": total}
236 |
237 | # sort the scores by accuracy
238 | scores[key] = dict(sorted(scores[key].items(), key=lambda item: float(item[1]['accuracy']), reverse=True))
239 |
240 | # save the scores
241 | scores_file = os.path.join(args.output_dir, args.score_file)
242 | print(f"\nSaving {scores_file}...")
243 | save_json(scores, scores_file)
244 | print("\nDone!")
245 |
246 | # [4] Calculate the score gains over random guess
247 |     if args.calculate_gain:
248 | random_file = os.path.join(args.output_dir, args.random_file)
249 | random_scores = json.load(open(random_file))
250 |
251 | print("\nCalculating the score gains...")
252 | for key in scores:
253 | if key == 'average':
254 | gain = round(float(scores[key]['accuracy']) - float(random_scores[key]['accuracy']), 2)
255 | scores[key]['acc_gain'] = gain
256 | else:
257 | for sub_key in scores[key]:
258 | gain = round(
259 | float(scores[key][sub_key]['accuracy']) - float(random_scores[key][sub_key]['accuracy']), 2)
260 | scores[key][sub_key]['acc_gain'] = str(gain)
261 |
262 | # save the score gains
263 | print(f"\nSaving {scores_file}...")
264 | save_json(scores, scores_file)
265 | print("\nDone!")
--------------------------------------------------------------------------------
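Note: a minimal sketch (not part of the repository) of the per-problem scoring path in calculate_score.py: normalize_extracted_answer maps a raw extraction onto the choice list, then safe_equal compares it with the ground-truth answer string.

    from calculate_score import normalize_extracted_answer, safe_equal

    choices = ["135°", "140°", "145°", "150°"]
    prediction = normalize_extracted_answer(
        extraction="(B) 140°",
        choices=choices,
        question_type="multi_choice",
        answer_type="text",
        precision=None,
    )
    print(prediction)                      # "140°"
    print(safe_equal(prediction, "140°"))  # True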
/evaluation_mathvista/ext_ans.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | # pids = 852, 104, 824, 506, 540
4 |
5 | demo_prompt = """
6 | Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.
7 |
8 | Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.
9 | Question: Which number is missing?
10 |
11 | Model response: The number missing in the sequence is 14.
12 |
13 | Extracted answer: 14
14 |
15 | Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.
16 | Question: What is the fraction of females facing the camera?
17 |
18 | Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.
19 |
20 | Extracted answer: 0.6
21 |
22 | Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.
23 | Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)
24 |
25 | Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.
26 |
27 | Extracted answer: 1.45
28 |
29 | Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
30 | Question: Between which two years does the line graph saw its maximum peak?
31 |
32 | Model response: The line graph saw its maximum peak between 2007 and 2008.
33 |
34 | Extracted answer: [2007, 2008]
35 |
36 | Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
37 | Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5
38 |
39 | Model response: The correct answer is (B) 8/11.
40 |
41 | Extracted answer: B
42 | """
--------------------------------------------------------------------------------
/evaluation_mathvista/extract_answer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 | import argparse
5 |
6 | from tqdm import tqdm
7 |
8 | import sys
9 |
10 | sys.path.append('../')
11 | from utilities import *
12 |
13 | # OpenAI
14 | import openai
15 |
16 |
17 | api_key = ""  # fill in your OpenAI API key before running the LLM-based extraction
18 | headers = {
19 | "Content-Type": "application/json",
20 | "Authorization": f"Bearer {api_key}"
21 | }
22 | # load demo prompt
23 | from ext_ans import demo_prompt
24 |
25 |
26 | def verify_extraction(extraction):
27 | extraction = extraction.strip()
28 | if extraction == "" or extraction == None:
29 | return False
30 | return True
31 |
32 |
33 | def create_test_prompt(demo_prompt, query, response):  # build the few-shot extraction prompt
34 | demo_prompt = demo_prompt.strip()
35 | test_prompt = f"{query}\n\n{response}"
36 | full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: "
37 | return full_prompt
38 |
39 |
40 | def extract_answer(response, problem, quick_extract=False, llm_engine="gpt-3.5-turbo"):  # llm_engine needs a default to follow a defaulted argument
41 | question_type = problem['question_type']
42 | answer_type = problem['answer_type']
43 | choices = problem['choices']
44 | query = problem['query']
45 |
46 | if response == "":
47 | return ""
48 |
49 | if question_type == 'multi_choice' and response in choices:
50 | return response
51 |
52 | if answer_type == "integer":
53 | try:
54 | extraction = int(response)
55 | return str(extraction)
56 | except:
57 | pass
58 |
59 | if answer_type == "float":
60 | try:
61 | extraction = str(float(response))
62 | return extraction
63 | except:
64 | pass
65 |
66 | # quick extraction
67 | if quick_extract:
68 | print("Quickly extracting answer...")
69 | try:
70 | result = re.search(r'The answer is "(.*)"\.', response)
71 | if result:
72 | extraction = result.group(1)
73 | return extraction
74 | except:
75 | pass
76 |
77 | # general extraction
78 | try:
79 | full_prompt = create_test_prompt(demo_prompt, query, response)
80 | extraction = get_chat_response_new(full_prompt, headers, model=llm_engine)
81 | return extraction
82 | except Exception as e:
83 | print(e)
84 |         print(f"Error in extracting answer for problem {problem.get('pid', '')}")
85 |
86 | return response
87 |
88 |
89 | def extract_answer_quick(response, problem, quick_extract=False):
90 | question_type = problem['question_type']
91 | answer_type = problem['answer_type']
92 | choices = problem['choices']
93 | query = problem['query']
94 |
95 | if response == "":
96 | return ""
97 |
98 | if question_type == 'multi_choice' and response in choices:
99 | return response
100 |
101 | if answer_type == "integer":
102 | try:
103 | extraction = int(response)
104 | return str(extraction)
105 | except:
106 | pass
107 |
108 | if answer_type == "float":
109 | try:
110 | extraction = str(float(response))
111 | return extraction
112 | except:
113 | pass
114 |
115 | # quick extraction
116 | if quick_extract:
117 | print("Quickly extracting answer...")
118 | try:
119 | result = response.split('The answer is ')
120 | if result:
121 | #extraction = result.group(1)
122 | extraction = result[1]
123 | return extraction
124 | except:
125 | pass
126 |
127 |
128 | return response
129 |
130 |
131 | if __name__ == '__main__':
132 |
133 | parser = argparse.ArgumentParser()
134 | # input
135 | parser.add_argument('--output_dir', type=str, default='./mathvista_outputs')
136 | parser.add_argument('--output_file', type=str, default='responses.json')
137 | parser.add_argument('--response_label', type=str, default='response',
138 | help='response label for the input file')
139 | # model
140 | parser.add_argument('--llm_engine', type=str, default='gpt-3.5-turbo', help='llm engine',
141 | choices=['gpt-3.5-turbo', 'gpt-3.5', 'gpt-4', 'gpt-4-0314', 'gpt-4-0613'])
142 | parser.add_argument('--number', type=int, default=-1, help='number of problems to run')
143 | parser.add_argument('--quick_extract', default=False, help='use rules to extract answer for some problems')
144 | parser.add_argument('--rerun', default=False, help='rerun the answer extraction')
145 | # output
146 | parser.add_argument('--save_every', type=int, default=10, help='save every n problems')
147 | parser.add_argument('--output_label', type=str, default='', help='label for the output file')
148 | args = parser.parse_args()
149 |
150 | # args
151 | label = args.response_label
152 | result_file = os.path.join(args.output_dir, args.output_file)
153 |
154 | if args.output_label != '':
155 | output_file = result_file.replace('.json', f'_{args.output_label}.json')
156 | else:
157 | output_file = result_file
158 |
159 | # read results
160 | print(f"Reading {result_file}...")
161 | results = read_json(result_file)
162 |
163 | # full pids
164 | full_pids = list(results.keys())
165 | if args.number > 0:
166 | full_pids = full_pids[:min(args.number, len(full_pids))]
167 | print("Number of testing problems:", len(full_pids))
168 |
169 | # test pids
170 | if args.rerun:
171 | test_pids = full_pids
172 | else:
173 | test_pids = []
174 | for pid in full_pids:
175 | # print(pid)
176 | if 'extraction' not in results[pid] or not verify_extraction(results[pid]['extraction']):
177 | test_pids.append(pid)
178 |
179 | test_num = len(test_pids)
180 | print("Number of problems to run:", test_num)
181 | # print(test_pids)
182 |
183 | # tqdm, enumerate results
184 | for i, pid in enumerate(tqdm(test_pids)):
185 | problem = results[pid]
186 |
187 | assert label in problem
188 | response = problem[label]
189 |
190 | extraction = extract_answer(response, problem, args.quick_extract, args.llm_engine)
191 | results[pid]['extraction'] = extraction
192 |
193 | if i % args.save_every == 0 or i == test_num - 1:
194 | print(f"Saving results to {output_file}...")
195 | save_json(results, output_file)
196 | print(f"Results saved.")
197 |
--------------------------------------------------------------------------------
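Note: a minimal sketch (not part of the repository) of the prompt that extract_answer() falls back to when the rule-based shortcuts fail. Importing extract_answer pulls in utilities.py, so the evaluation dependencies (openai, opencv-python, word2number) need to be installed even though no API call is made here.

    from ext_ans import demo_prompt
    from extract_answer import create_test_prompt, verify_extraction

    query = ("Hint: Please answer the question requiring an integer answer and provide "
             "the final value, e.g., 1, 2, 3, at the end.\nQuestion: Which number is missing?")
    response = "The number missing in the sequence is 14."
    full_prompt = create_test_prompt(demo_prompt, query, response)
    print(full_prompt.endswith("Extracted answer: "))  # True

    print(verify_extraction("14"))  # True
    print(verify_extraction(""))    # False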
/evaluation_mathvista/response.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | import io
5 | import time
6 | import argparse
7 |
8 | from tqdm import tqdm
9 |
10 | import sys
11 |
12 | sys.path.append('../')
13 | from utilities import *
14 |
15 | from build_query import create_query_data
16 | from llava.model.builder import load_pretrained_model
17 | from llava.mm_utils import get_model_name_from_path
18 | from llava.eval.run_llava import eval_model, evalmodel
19 | from llava.utils import disable_torch_init
20 |
21 |
22 | def verify_response(response):
23 | if isinstance(response, str):
24 | response = response.strip()
25 | if response == "" or response == None:
26 | return False
27 | if "Response Error" in response:
28 | return False
29 | return True
30 |
31 |
32 | def evaluate_code(code_string):
33 | # execute_code_and_capture_output
34 | # Backup the original stdout
35 | old_stdout = sys.stdout
36 |
37 | # Redirect stdout to capture the output
38 | new_stdout = io.StringIO()
39 | sys.stdout = new_stdout
40 |
41 | # Try executing the code and capture any exception
42 | error = None
43 | try:
44 | exec(code_string)
45 | except Exception as e:
46 | error = e
47 |
48 | # Restore the original stdout
49 | sys.stdout = old_stdout
50 |
51 | # Get the captured output
52 | captured_output = new_stdout.getvalue()
53 | if isinstance(captured_output, str):
54 | captured_output = captured_output.strip()
55 |
56 | # Return the captured output or error
57 | return captured_output, error
58 |
59 |
60 | if __name__ == '__main__':
61 | parser = argparse.ArgumentParser()
62 | # input
63 | parser.add_argument('--data_dir', type=str, default='./mathvista_data')
64 | parser.add_argument('--input_file', type=str, default='testmini.json')
65 | # output
66 | parser.add_argument('--output_dir', type=str, default='./mathvista_outputs')
67 | parser.add_argument('--output_file', type=str, default='responses.json')
68 | # model
69 |     parser.add_argument('--model_path', type=str, default='liuhaotian/llava-v1.5-13b', help='path to a LoRA or fully finetuned model')
70 |     parser.add_argument('--model_base', type=str, default=None, help='base model (e.g., liuhaotian/llava-v1.5-13b) when model_path is a LoRA checkpoint; None for a full model')
71 | # query
72 | parser.add_argument('--query_file', type=str, default='query.json')
73 | parser.add_argument('--caption_file', type=str, default=None)
74 | parser.add_argument('--ocr_file', type=str, default=None)
75 | parser.add_argument('--shot_type', type=str, default='solution', help='shot type',
76 | choices=['solution', 'code'])
77 | parser.add_argument('--shot_num', type=int, default=0, help='number of shot examples')
78 | parser.add_argument('--use_caption', default=False, help='use caption data')
79 | parser.add_argument('--use_ocr', default=False, help='use ocr data')
80 | # other settings
81 |     parser.add_argument('--rerun', default=False, help='regenerate responses for all problems')
82 | parser.add_argument('--debug', default=False, help='debug mode')
83 | args = parser.parse_args()
84 |
85 | # load data
86 | input_file = os.path.join(args.data_dir, args.input_file)
87 | print(f"Reading {input_file}...")
88 | data = read_json(input_file)
89 |
90 | # load or create query data
91 | if args.query_file:
92 | query_file = os.path.join(args.data_dir, args.query_file)
93 | if os.path.exists(query_file):
94 | print(f"Loading existing {query_file}...")
95 | query_data = read_json(query_file)
96 | else:
97 | print("\nCreating new query...")
98 | # load caption
99 | caption_data = {}
100 | if args.use_caption:
101 | caption_file = args.caption_file
102 | if os.path.exists(caption_file):
103 | print(f"Reading {caption_file}...")
104 | try:
105 | caption_data = read_json(caption_file)["texts"]
106 | print("Caption data loaded.")
107 | except:
108 | print("Caption data not found!! Please Check.")
109 | # load ocr
110 | ocr_data = {}
111 | if args.use_ocr:
112 | ocr_file = args.ocr_file
113 | if os.path.exists(ocr_file):
114 | print(f"Reading {ocr_file}...")
115 | try:
116 | ocr_data = read_json(ocr_file)["texts"]
117 | print("OCR data loaded.")
118 | except:
119 | print("OCR data not found!! Please Check.")
120 | # create query
121 | query_data = create_query_data(data, caption_data, ocr_data, args)
122 |
123 | #print(query_data)
124 |
125 | # output file
126 | os.makedirs(args.output_dir, exist_ok=True)
127 | output_file = os.path.join(args.output_dir, args.output_file)
128 |
129 | # load results
130 | if os.path.exists(output_file):
131 | print("\nResults already exist.")
132 | print(f"Reading {output_file}...")
133 | results = read_json(output_file)
134 | else:
135 | results = {}
136 |
137 | model_path = args.model_path
138 | model_base = args.model_base
139 |
140 | tokenizer, model, image_processor, context_len = load_pretrained_model(
141 | model_path=model_path,
142 | model_base=model_base,
143 | model_name=get_model_name_from_path(model_path)
144 | )
145 |
146 | disable_torch_init()
147 | model_name = get_model_name_from_path(model_path)
148 | ##
149 |
150 | # build final test pid list
151 | test_pids = list(data.keys())
152 | print("\nNumber of test problems in total:", len(test_pids))
153 |
154 | skip_pids = []
155 | if not args.rerun:
156 | print("\nRemoving problems with existing valid response...")
157 | for pid in test_pids:
158 | # print(f"Checking {pid}...")
159 | if pid in results and 'response' in results[pid]:
160 | response = results[pid]['response']
161 | if verify_response(response):
162 | # print(f"Valid response found for {pid}.")
163 | skip_pids.append(pid)
164 | else:
165 |         print("\nRerunning response generation for all problems...")
166 |
167 | test_pids = [pid for pid in test_pids if pid not in skip_pids]
168 | print("Number of test problems to run:", len(test_pids))
169 | # print(test_pids)
170 |
171 | # tqdm, enumerate results
172 | for _, pid in enumerate(tqdm(test_pids)):
173 | problem = data[pid]
174 | query = query_data[pid]
175 | image = problem['image']
176 | image_path = os.path.join(args.data_dir, image)
177 |
178 | if args.debug:
179 | print("--------------------------------------------------------------")
180 | print(f"\nGenerating response for {pid}...")
181 | try:
182 |
183 | args_llava = type('Args', (), {
184 | "model_path": model_path,
185 | "model_base": None,
186 | "model_name": get_model_name_from_path(model_path),
187 | "query": query,
188 | "conv_mode": None,
189 | "image_file": image_path,
190 | "sep": ",",
191 | "temperature": 0.2,
192 | "top_p": None,
193 | "num_beams": 1,
194 | "max_new_tokens": 512
195 | })()
196 | response = evalmodel(args_llava, model_name, tokenizer, model, image_processor, context_len)
197 | results[pid] = problem
198 | results[pid]['query'] = query
199 | if args.shot_type == 'solution':
200 | results[pid]['response'] = response
201 | else:
202 | output, error = evaluate_code(response)
203 | results[pid]['response'] = response
204 | results[pid]['execution'] = output
205 | results[pid]['error'] = str(error)
206 | if args.debug:
207 | print(f"\n#Query: \n{query}")
208 | print(f"\n#Response: \n{response}")
209 | except Exception as e:
210 | print(e)
211 |             print(f"Error in generating response for {pid}")
212 |             results.setdefault(pid, problem)['error'] = str(e)  # keep the saved entry JSON-serializable even if evalmodel failed
213 |
214 | try:
215 | print(f"Saving results to {output_file}...")
216 | save_json(results, output_file)
217 | print(f"Results saved.")
218 | except Exception as e:
219 | print(e)
220 | print(f"Error in saving {output_file}")
221 |
--------------------------------------------------------------------------------
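Note: a minimal sketch (not part of the repository) of evaluate_code from response.py, which runs model-generated Python when shot_type == 'code' and captures its stdout plus any exception. Importing response.py loads the llava package, so torch and transformers must be installed, although this helper itself needs no GPU.

    from response import evaluate_code

    output, error = evaluate_code("print(1.45 + 2.55)")
    print(output, error)  # 4.0 None

    output, error = evaluate_code("print(undefined_name)")
    print(error)          # name 'undefined_name' is not defined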
/evaluation_mathvista/utilities.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import json
4 | import time
5 | import pickle
6 | import openai
7 | import re
8 |
9 | import requests
10 | from word2number import w2n
11 |
12 |
13 |
14 | def create_dir(output_dir):
15 | if not os.path.exists(output_dir):
16 | os.makedirs(output_dir)
17 |
18 |
19 | def read_csv(file):
20 | data = []
21 | with open(file, 'r') as f:
22 | for line in f:
23 | data.append(line.strip())
24 | return data
25 |
26 |
27 | def read_pandas_csv(csv_path):
28 | # read a pandas csv sheet
29 | import pandas as pd
30 | df = pd.read_csv(csv_path)
31 | return df
32 |
33 |
34 | def read_json(path):
35 | with open(path, 'r', encoding='utf-8') as f:
36 | return json.load(f)
37 |
38 |
39 | def read_jsonl(file):
40 | with open(file, 'r') as f:
41 | data = [json.loads(line) for line in f]
42 | return data
43 |
44 |
45 | def read_pickle(path):
46 | with open(path, 'rb') as f:
47 | return pickle.load(f)
48 |
49 |
50 | def save_json(data, path):
51 | with open(path, 'w') as f:
52 | json.dump(data, f, indent=4)
53 |
54 |
55 | def save_array_img(path, image):
56 | cv2.imwrite(path, image)
57 |
58 |
59 | def contains_digit(text):
60 | # check if text contains a digit
61 | if any(char.isdigit() for char in text):
62 | return True
63 | return False
64 |
65 |
66 | def contains_number_word(text):
67 | # check if text contains a number word
68 | ignore_words = ["a", "an", "point"]
69 | words = re.findall(r'\b\w+\b', text) # This regex pattern matches any word in the text
70 | for word in words:
71 | if word in ignore_words:
72 | continue
73 | try:
74 | w2n.word_to_num(word)
75 | return True # If the word can be converted to a number, return True
76 | except ValueError:
77 | continue # If the word can't be converted to a number, continue with the next word
78 |
79 | # check if text contains a digit
80 | if any(char.isdigit() for char in text):
81 | return True
82 |
83 | return False # If none of the words could be converted to a number, return False
84 |
85 |
86 | def contains_quantity_word(text, special_keep_words=[]):
87 | # check if text contains a quantity word
88 |     quantity_words = ["most", "least", "fewest",
89 | "more", "less", "fewer",
90 | "largest", "smallest", "greatest",
91 | "larger", "smaller", "greater",
92 | "highest", "lowest", "higher", "lower",
93 | "increase", "decrease",
94 | "minimum", "maximum", "max", "min",
95 | "mean", "average", "median",
96 | "total", "sum", "add", "subtract",
97 | "difference", "quotient", "gap",
98 | "half", "double", "twice", "triple",
99 | "square", "cube", "root",
100 | "approximate", "approximation",
101 | "triangle", "rectangle", "circle", "square", "cube", "sphere", "cylinder", "cone", "pyramid",
102 | "multiply", "divide",
103 | "percentage", "percent", "ratio", "proportion", "fraction", "rate",
104 | ]
105 |
106 | quantity_words += special_keep_words # dataset specific words
107 |
108 | words = re.findall(r'\b\w+\b', text) # This regex pattern matches any word in the text
109 | if any(word in quantity_words for word in words):
110 | return True
111 |
112 | return False # If none of the words could be converted to a number, return False
113 |
114 |
115 | def is_bool_word(text):
116 | if text in ["Yes", "No", "True", "False",
117 | "yes", "no", "true", "false",
118 | "YES", "NO", "TRUE", "FALSE"]:
119 | return True
120 | return False
121 |
122 |
123 | def is_digit_string(text):
124 | # remove ".0000"
125 | text = text.strip()
126 | text = re.sub(r'\.0+$', '', text)
127 | try:
128 | int(text)
129 | return True
130 | except ValueError:
131 | return False
132 |
133 |
134 | def is_float_string(text):
135 | # text is a float string if it contains a "." and can be converted to a float
136 | if "." in text:
137 | try:
138 | float(text)
139 | return True
140 | except ValueError:
141 | return False
142 | return False
143 |
144 |
145 | def copy_image(image_path, output_image_path):
146 | from shutil import copyfile
147 | copyfile(image_path, output_image_path)
148 |
149 |
150 | def copy_dir(src_dir, dst_dir):
151 | from shutil import copytree
152 | # copy the source directory to the target directory
153 | copytree(src_dir, dst_dir)
154 |
155 |
156 | import PIL.Image as Image
157 |
158 |
159 | def get_image_size(img_path):
160 | img = Image.open(img_path)
161 | width, height = img.size
162 | return width, height
163 |
164 |
165 |
166 | def get_chat_response(prompt, api_key, model="gpt-3.5-turbo", temperature=0, max_tokens=256, n=1, patience=10000000,
167 |                       sleep_time=0):
168 |     # note: avoid printing api_key here; it would leak the secret into logs
169 |     messages = [
170 |         {"role": "user", "content": prompt},
171 | ]
172 | # print("I am here")
173 | while patience > 0:
174 | patience -= 1
175 | try:
176 | response = openai.ChatCompletion.create(model=model,
177 | messages=messages,
178 | api_key=api_key,
179 | temperature=temperature,
180 | max_tokens=max_tokens,
181 | n=n)
182 | if n == 1:
183 | prediction = response['choices'][0]['message']['content'].strip()
184 | if prediction != "" and prediction != None:
185 | return prediction
186 | else:
187 | prediction = [choice['message']['content'].strip() for choice in response['choices']]
188 | if prediction[0] != "" and prediction[0] != None:
189 | return prediction
190 |
191 | except Exception as e:
192 | if "Rate limit" not in str(e):
193 | print(e)
194 |
195 | if "Please reduce the length of the messages" in str(e):
196 |                     print("!!Reduce prompt size")
197 |                     # reduce input prompt and keep the tail
198 |                     new_size = int(len(prompt) * 0.9)
199 |                     new_start = len(prompt) - new_size
200 |                     prompt = prompt[new_start:]
201 |                     messages = [
202 |                         {"role": "user", "content": prompt},
203 | ]
204 |
205 | if sleep_time > 0:
206 | time.sleep(sleep_time)
207 | return ""
208 |
209 |
210 |
211 | def get_chat_response_new(prompt, headers, model="gpt-3.5-turbo", temperature=0, max_tokens=256, n=1, patience=10000000,
212 |                           sleep_time=0):
213 |     messages = [
214 |         {"role": "user", "content": prompt},
215 | ]
216 |
217 | while patience > 0:
218 | patience -= 1
219 | try:
220 |
221 |                 data = {"model": model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "n": n}
222 |
223 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data)
224 | response = response.json()
225 |
226 | if n == 1:
227 | prediction = response['choices'][0]['message']['content'].strip()
228 | if prediction != "" and prediction != None:
229 | return prediction
230 | else:
231 | prediction = [choice['message']['content'].strip() for choice in response['choices']]
232 | if prediction[0] != "" and prediction[0] != None:
233 | return prediction
234 |
235 | except Exception as e:
236 | if "Rate limit" not in str(e):
237 | print(e)
238 |
239 | if "Please reduce the length of the messages" in str(e):
240 |                     print("!!Reduce prompt size")
241 |                     # reduce input prompt and keep the tail
242 |                     new_size = int(len(prompt) * 0.9)
243 |                     new_start = len(prompt) - new_size
244 |                     prompt = prompt[new_start:]
245 |                     messages = [
246 |                         {"role": "user", "content": prompt},
247 | ]
248 |
249 | if sleep_time > 0:
250 | time.sleep(sleep_time)
251 | return ""
252 |
--------------------------------------------------------------------------------
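Note: a minimal sketch (not part of the repository) exercising the small answer-type helpers in utilities.py; the module imports cv2, openai and word2number at load time, so those packages must be installed.

    from utilities import contains_digit, contains_number_word, is_digit_string, is_float_string

    print(contains_digit("There are 3 apples"))            # True
    print(contains_number_word("There are three apples"))  # True
    print(is_digit_string("14.000"))                       # True (the trailing ".000" is stripped first)
    print(is_float_string("1.45"))                         # True
    print(is_float_string("14"))                           # False (no decimal point)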
/finetune_task.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | deepspeed llava/train/train_mem.py \
3 | --deepspeed ./scripts/zero3.json \
4 | --model_name_or_path liuhaotian/llava-v1.5-13b \
5 | --version v1 \
6 | --data_path ./train_samples_all_tuning.json \
7 | --image_folder ./data_images \
8 | --vision_tower openai/clip-vit-large-patch14-336 \
9 | --mm_projector_type mlp2x_gelu \
10 | --mm_vision_select_layer -2 \
11 | --mm_use_im_start_end False \
12 | --mm_use_im_patch_token False \
13 | --image_aspect_ratio pad \
14 | --group_by_modality_length True \
15 | --bf16 True \
16 | --output_dir ./checkpoints/llava-v1.5-13b-full-finetune \
17 | --num_train_epochs 2 \
18 | --per_device_train_batch_size 16 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 1 \
21 | --evaluation_strategy "no" \
22 | --save_strategy "steps" \
23 | --save_steps 50000 \
24 | --save_total_limit 1 \
25 | --learning_rate 2e-5 \
26 | --weight_decay 0. \
27 | --warmup_ratio 0.03 \
28 | --lr_scheduler_type "cosine" \
29 | --logging_steps 1 \
30 | --tf32 True \
31 | --model_max_length 2048 \
32 | --gradient_checkpointing True \
33 | --dataloader_num_workers 4 \
34 | --lazy_preprocess True \
35 | --report_to wandb
36 |
--------------------------------------------------------------------------------
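Note: a back-of-the-envelope sketch (not part of the repository) of the effective global batch size implied by finetune_task.sh; the GPU count is an assumption, since deepspeed uses every visible GPU unless told otherwise.

    per_device_train_batch_size = 16
    gradient_accumulation_steps = 1
    num_gpus = 8  # assumed; depends on the machine the script is launched on
    print(per_device_train_batch_size * gradient_accumulation_steps * num_gpus)  # 128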
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
/llava/eval/__pycache__/run_llava.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HZQ950419/Math-LLaVA/3480ba7d2305de2c6f76941b091c2925641fc751/llava/eval/__pycache__/run_llava.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/eval/run_llava.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import (
5 | IMAGE_TOKEN_INDEX,
6 | DEFAULT_IMAGE_TOKEN,
7 | DEFAULT_IM_START_TOKEN,
8 | DEFAULT_IM_END_TOKEN,
9 | IMAGE_PLACEHOLDER,
10 | )
11 | from llava.conversation import conv_templates, SeparatorStyle
12 | from llava.model.builder import load_pretrained_model
13 | from llava.utils import disable_torch_init
14 | from llava.mm_utils import (
15 | process_images,
16 | tokenizer_image_token,
17 | get_model_name_from_path,
18 | KeywordsStoppingCriteria,
19 | )
20 |
21 | from PIL import Image
22 |
23 | import requests
24 |
25 | from io import BytesIO
26 | import re
27 |
28 |
29 | def image_parser(args):
30 | out = args.image_file.split(args.sep)
31 | return out
32 |
33 |
34 | def load_image(image_file):
35 | if image_file.startswith("http") or image_file.startswith("https"):
36 | response = requests.get(image_file)
37 | image = Image.open(BytesIO(response.content)).convert("RGB")
38 | else:
39 | image = Image.open(image_file).convert("RGB")
40 | return image
41 |
42 |
43 | def load_images(image_files):
44 | out = []
45 | for image_file in image_files:
46 | image = load_image(image_file)
47 | out.append(image)
48 | return out
49 |
50 |
51 | def eval_model(args):
52 | # Model
53 | disable_torch_init()
54 |
55 | model_name = get_model_name_from_path(args.model_path)
56 | tokenizer, model, image_processor, context_len = load_pretrained_model(
57 | args.model_path, args.model_base, model_name
58 | )
59 |
60 | qs = args.query
61 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
62 | if IMAGE_PLACEHOLDER in qs:
63 | if model.config.mm_use_im_start_end:
64 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
65 | else:
66 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
67 | else:
68 | if model.config.mm_use_im_start_end:
69 | qs = image_token_se + "\n" + qs
70 | else:
71 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
72 |
73 | if "llama-2" in model_name.lower():
74 | conv_mode = "llava_llama_2"
75 | elif "v1" in model_name.lower():
76 | conv_mode = "llava_v1"
77 | elif "mpt" in model_name.lower():
78 | conv_mode = "mpt"
79 | else:
80 | conv_mode = "llava_v0"
81 |
82 | if args.conv_mode is not None and conv_mode != args.conv_mode:
83 | print(
84 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
85 | conv_mode, args.conv_mode, args.conv_mode
86 | )
87 | )
88 | else:
89 | args.conv_mode = conv_mode
90 |
91 | conv = conv_templates[args.conv_mode].copy()
92 | conv.append_message(conv.roles[0], qs)
93 | conv.append_message(conv.roles[1], None)
94 | prompt = conv.get_prompt()
95 |
96 | image_files = image_parser(args)
97 | images = load_images(image_files)
98 | images_tensor = process_images(
99 | images,
100 | image_processor,
101 | model.config
102 | ).to(model.device, dtype=torch.float16)
103 |
104 | input_ids = (
105 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
106 | .unsqueeze(0)
107 | .cuda()
108 | )
109 |
110 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
111 | keywords = [stop_str]
112 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
113 |
114 | with torch.inference_mode():
115 | output_ids = model.generate(
116 | input_ids,
117 | images=images_tensor,
118 | do_sample=True if args.temperature > 0 else False,
119 | temperature=args.temperature,
120 | top_p=args.top_p,
121 | num_beams=args.num_beams,
122 | max_new_tokens=args.max_new_tokens,
123 | use_cache=True,
124 | stopping_criteria=[stopping_criteria],
125 | )
126 |
127 | input_token_len = input_ids.shape[1]
128 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
129 | if n_diff_input_output > 0:
130 | print(
131 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
132 | )
133 | outputs = tokenizer.batch_decode(
134 | output_ids[:, input_token_len:], skip_special_tokens=True
135 | )[0]
136 | outputs = outputs.strip()
137 | if outputs.endswith(stop_str):
138 | outputs = outputs[: -len(stop_str)]
139 | outputs = outputs.strip()
140 | print(outputs)
141 |
142 | def evalmodel(args, model_name, tokenizer, model, image_processor, context_len):
143 | # Model
144 | #disable_torch_init()
145 |
146 | # model_name = get_model_name_from_path(args.model_path)
147 | # tokenizer, model, image_processor, context_len = load_pretrained_model(
148 | # args.model_path, args.model_base, model_name
149 | # )
150 |
151 | qs = args.query
152 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
153 | if IMAGE_PLACEHOLDER in qs:
154 | if model.config.mm_use_im_start_end:
155 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
156 | else:
157 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
158 | else:
159 | if model.config.mm_use_im_start_end:
160 | qs = image_token_se + "\n" + qs
161 | else:
162 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
163 |
164 | if "llama-2" in model_name.lower():
165 | conv_mode = "llava_llama_2"
166 | elif "v1" in model_name.lower():
167 | conv_mode = "llava_v1"
168 | elif "mpt" in model_name.lower():
169 | conv_mode = "mpt"
170 | else:
171 | conv_mode = "llava_v0"
172 |
173 | if args.conv_mode is not None and conv_mode != args.conv_mode:
174 | print(
175 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
176 | conv_mode, args.conv_mode, args.conv_mode
177 | )
178 | )
179 | else:
180 | args.conv_mode = conv_mode
181 |
182 | conv = conv_templates[args.conv_mode].copy()
183 | conv.append_message(conv.roles[0], qs)
184 | conv.append_message(conv.roles[1], None)
185 | prompt = conv.get_prompt()
186 |
187 | image_files = image_parser(args)
188 | images = load_images(image_files)
189 | images_tensor = process_images(
190 | images,
191 | image_processor,
192 | model.config
193 | ).to(model.device, dtype=torch.float16)
194 |
195 | input_ids = (
196 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
197 | .unsqueeze(0)
198 | .cuda()
199 | )
200 |
201 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
202 | keywords = [stop_str]
203 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
204 |
205 | with torch.inference_mode():
206 | output_ids = model.generate(
207 | input_ids,
208 | images=images_tensor,
209 | do_sample=True if args.temperature > 0 else False,
210 | temperature=args.temperature,
211 | top_p=args.top_p,
212 | num_beams=args.num_beams,
213 | max_new_tokens=args.max_new_tokens,
214 | use_cache=True,
215 | stopping_criteria=[stopping_criteria],
216 | )
217 |
218 | input_token_len = input_ids.shape[1]
219 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
220 | if n_diff_input_output > 0:
221 | print(
222 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
223 | )
224 | outputs = tokenizer.batch_decode(
225 | output_ids[:, input_token_len:], skip_special_tokens=True
226 | )[0]
227 | outputs = outputs.strip()
228 | if outputs.endswith(stop_str):
229 | outputs = outputs[: -len(stop_str)]
230 | outputs = outputs.strip()
231 |
232 | #print(outputs)
233 | return outputs
234 |
235 |
236 | if __name__ == "__main__":
237 | parser = argparse.ArgumentParser()
238 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
239 | parser.add_argument("--model-base", type=str, default=None)
240 | parser.add_argument("--image-file", type=str, required=True)
241 | parser.add_argument("--query", type=str, required=True)
242 | parser.add_argument("--conv-mode", type=str, default=None)
243 | parser.add_argument("--sep", type=str, default=",")
244 | parser.add_argument("--temperature", type=float, default=0.2)
245 | parser.add_argument("--top_p", type=float, default=None)
246 | parser.add_argument("--num_beams", type=int, default=1)
247 | parser.add_argument("--max_new_tokens", type=int, default=512)
248 | args = parser.parse_args()
249 |
250 | eval_model(args)
251 |
--------------------------------------------------------------------------------
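Note: a minimal sketch (not part of the repository) of calling eval_model directly, mirroring the ad-hoc args object that response.py builds; the checkpoint and image paths are placeholders, and a CUDA GPU with enough memory for the weights is required. eval_model reloads the model on every call, which is why response.py instead keeps the model in memory and calls evalmodel per problem.

    import argparse
    from llava.eval.run_llava import eval_model

    args = argparse.Namespace(
        model_path="liuhaotian/llava-v1.5-13b",      # or a local finetuned checkpoint
        model_base=None,
        query="What is the total volume of the measuring cup?",
        conv_mode=None,
        image_file="./mathvista_data/images/1.jpg",  # placeholder path
        sep=",",
        temperature=0.2,
        top_p=None,
        num_beams=1,
        max_new_tokens=512,
    )
    eval_model(args)  # prints the decoded answer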
/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 |
5 | import torch
6 | from transformers import StoppingCriteria
7 | from llava.constants import IMAGE_TOKEN_INDEX
8 |
9 |
10 | def load_image_from_base64(image):
11 | return Image.open(BytesIO(base64.b64decode(image)))
12 |
13 |
14 | def expand2square(pil_img, background_color):
15 | width, height = pil_img.size
16 | if width == height:
17 | return pil_img
18 | elif width > height:
19 | result = Image.new(pil_img.mode, (width, width), background_color)
20 | result.paste(pil_img, (0, (width - height) // 2))
21 | return result
22 | else:
23 | result = Image.new(pil_img.mode, (height, height), background_color)
24 | result.paste(pil_img, ((height - width) // 2, 0))
25 | return result
26 |
27 |
28 | def process_images(images, image_processor, model_cfg):
29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30 | new_images = []
31 | if image_aspect_ratio == 'pad':
32 | for image in images:
33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
35 | new_images.append(image)
36 | else:
37 | return image_processor(images, return_tensors='pt')['pixel_values']
38 | if all(x.shape == new_images[0].shape for x in new_images):
39 | new_images = torch.stack(new_images, dim=0)
40 | return new_images
41 |
42 |
43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
44 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
45 |
46 | def insert_separator(X, sep):
47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
48 |
49 | input_ids = []
50 | offset = 0
51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
52 | offset = 1
53 | input_ids.append(prompt_chunks[0][0])
54 |
55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
56 | input_ids.extend(x[offset:])
57 |
58 | if return_tensors is not None:
59 | if return_tensors == 'pt':
60 | return torch.tensor(input_ids, dtype=torch.long)
61 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
62 | return input_ids
63 |
64 |
65 | def get_model_name_from_path(model_path):
66 | model_path = model_path.strip("/")
67 | model_paths = model_path.split("/")
68 | if model_paths[-1].startswith('checkpoint-'):
69 | return model_paths[-2] + "_" + model_paths[-1]
70 | else:
71 | return model_paths[-1]
72 |
73 | class KeywordsStoppingCriteria(StoppingCriteria):
74 | def __init__(self, keywords, tokenizer, input_ids):
75 | self.keywords = keywords
76 | self.keyword_ids = []
77 | self.max_keyword_len = 0
78 | for keyword in keywords:
79 | cur_keyword_ids = tokenizer(keyword).input_ids
80 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
81 | cur_keyword_ids = cur_keyword_ids[1:]
82 | if len(cur_keyword_ids) > self.max_keyword_len:
83 | self.max_keyword_len = len(cur_keyword_ids)
84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
85 | self.tokenizer = tokenizer
86 | self.start_len = input_ids.shape[1]
87 |
88 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
89 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
90 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
91 | for keyword_id in self.keyword_ids:
92 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
93 | return True
94 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
95 | for keyword in self.keywords:
96 | if keyword in outputs:
97 | return True
98 | return False
99 |
100 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
101 | outputs = []
102 | for i in range(output_ids.shape[0]):
103 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
104 | return all(outputs)
105 |
--------------------------------------------------------------------------------
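Note: a minimal sketch (not part of the repository) of two helpers from mm_utils.py: get_model_name_from_path, and expand2square, the padding applied when image_aspect_ratio == 'pad'. The paths are made up; Pillow and the llava package (with its torch/transformers dependencies) must be installed.

    from PIL import Image
    from llava.mm_utils import get_model_name_from_path, expand2square

    print(get_model_name_from_path("liuhaotian/llava-v1.5-13b"))            # llava-v1.5-13b
    print(get_model_name_from_path("./checkpoints/my-lora/checkpoint-50"))  # my-lora_checkpoint-50

    img = Image.new("RGB", (200, 100), color=(255, 0, 0))
    padded = expand2square(img, background_color=(127, 127, 127))
    print(padded.size)  # (200, 200): the original image is centered vertically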
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
3 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.apply_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/vicuna-7b --delta-path lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/llava/model/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import warnings
18 | import shutil
19 |
20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21 | import torch
22 | from llava.model import *
23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24 |
25 |
26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
27 | kwargs = {"device_map": device_map, **kwargs}
28 |
29 | if device != "cuda":
30 | kwargs['device_map'] = {"": device}
31 |
32 | if load_8bit:
33 | kwargs['load_in_8bit'] = True
34 | elif load_4bit:
35 | kwargs['load_in_4bit'] = True
36 | kwargs['quantization_config'] = BitsAndBytesConfig(
37 | load_in_4bit=True,
38 | bnb_4bit_compute_dtype=torch.float16,
39 | bnb_4bit_use_double_quant=True,
40 | bnb_4bit_quant_type='nf4'
41 | )
42 | else:
43 | kwargs['torch_dtype'] = torch.float16
44 |
45 | if 'llava' in model_name.lower():
46 | # Load LLaVA model
47 | if 'lora' in model_name.lower() and model_base is None:
48 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
49 | if 'lora' in model_name.lower() and model_base is not None:
50 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
51 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
52 | print('Loading LLaVA from base model...')
53 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
54 |             token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
55 |             if model.lm_head.weight.shape[0] != token_num:
56 |                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
57 |                 model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
58 |
59 | print('Loading additional LLaVA weights...')
60 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
61 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
62 | else:
63 | # this is probably from HF Hub
64 | from huggingface_hub import hf_hub_download
65 | def load_from_hf(repo_id, filename, subfolder=None):
66 | cache_file = hf_hub_download(
67 | repo_id=repo_id,
68 | filename=filename,
69 | subfolder=subfolder)
70 | return torch.load(cache_file, map_location='cpu')
71 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
72 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
73 | if any(k.startswith('model.model.') for k in non_lora_trainables):
74 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
75 | model.load_state_dict(non_lora_trainables, strict=False)
76 |
77 | from peft import PeftModel
78 | print('Loading LoRA weights...')
79 | model = PeftModel.from_pretrained(model, model_path)
80 | print('Merging LoRA weights...')
81 | model = model.merge_and_unload()
82 | print('Model is loaded...')
83 | elif model_base is not None:
84 | # this may be mm projector only
85 | print('Loading LLaVA from base model...')
86 | if 'mpt' in model_name.lower():
87 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
88 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
89 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
90 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
91 | model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
92 | else:
93 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
94 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
95 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
96 |
97 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
98 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
99 | model.load_state_dict(mm_projector_weights, strict=False)
100 | else:
101 | if 'mpt' in model_name.lower():
102 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
103 | model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
104 | else:
105 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
106 | model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
107 | else:
108 | # Load language model
109 | if model_base is not None:
110 | # PEFT model
111 | from peft import PeftModel
112 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
113 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
114 | print(f"Loading LoRA weights from {model_path}")
115 | model = PeftModel.from_pretrained(model, model_path)
116 | print(f"Merging weights")
117 | model = model.merge_and_unload()
118 | print('Convert to FP16...')
119 | model.to(torch.float16)
120 | else:
121 | use_fast = False
122 | if 'mpt' in model_name.lower():
123 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
124 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
125 | else:
126 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
127 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
128 |
129 | image_processor = None
130 |
131 | if 'llava' in model_name.lower():
132 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
133 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
134 | if mm_use_im_patch_token:
135 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
136 | if mm_use_im_start_end:
137 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
138 | model.resize_token_embeddings(len(tokenizer))
139 |
140 | vision_tower = model.get_vision_tower()
141 | if not vision_tower.is_loaded:
142 | vision_tower.load_model()
143 | vision_tower.to(device=device, dtype=torch.float16)
144 | image_processor = vision_tower.image_processor
145 |
146 | if hasattr(model.config, "max_sequence_length"):
147 | context_len = model.config.max_sequence_length
148 | else:
149 | context_len = 2048
150 |
151 | return tokenizer, model, image_processor, context_len
152 |
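153 | # Minimal usage sketch for the loader above (it assumes the enclosing function is this
154 | # module's `load_pretrained_model` and that `get_model_name_from_path` from
155 | # llava.mm_utils is available; the checkpoint id is a placeholder):
156 | #
157 | #   from llava.model.builder import load_pretrained_model
158 | #   from llava.mm_utils import get_model_name_from_path
159 | #
160 | #   model_path = "liuhaotian/llava-v1.5-7b"  # placeholder checkpoint
161 | #   tokenizer, model, image_processor, context_len = load_pretrained_model(
162 | #       model_path=model_path,
163 | #       model_base=None,  # set to the base LLM path when model_path holds LoRA or projector-only weights
164 | #       model_name=get_model_name_from_path(model_path),
165 | #   )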
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import AutoConfig, AutoModelForCausalLM, \
22 | LlamaConfig, LlamaModel, LlamaForCausalLM
23 |
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 |
26 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
27 |
28 |
29 | class LlavaConfig(LlamaConfig):
30 | model_type = "llava"
31 |
32 |
33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
34 | config_class = LlavaConfig
35 |
36 | def __init__(self, config: LlamaConfig):
37 | super(LlavaLlamaModel, self).__init__(config)
38 |
39 |
40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
41 | config_class = LlavaConfig
42 |
43 | def __init__(self, config):
44 | super(LlamaForCausalLM, self).__init__(config)
45 | self.model = LlavaLlamaModel(config)
46 | self.pretraining_tp = config.pretraining_tp
47 | self.vocab_size = config.vocab_size
48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.model
55 |
56 | def forward(
57 | self,
58 | input_ids: torch.LongTensor = None,
59 | attention_mask: Optional[torch.Tensor] = None,
60 | position_ids: Optional[torch.LongTensor] = None,
61 | past_key_values: Optional[List[torch.FloatTensor]] = None,
62 | inputs_embeds: Optional[torch.FloatTensor] = None,
63 | labels: Optional[torch.LongTensor] = None,
64 | use_cache: Optional[bool] = None,
65 | output_attentions: Optional[bool] = None,
66 | output_hidden_states: Optional[bool] = None,
67 | images: Optional[torch.FloatTensor] = None,
68 | return_dict: Optional[bool] = None,
69 | ) -> Union[Tuple, CausalLMOutputWithPast]:
70 |
71 | if inputs_embeds is None:
72 | (
73 | input_ids,
74 | position_ids,
75 | attention_mask,
76 | past_key_values,
77 | inputs_embeds,
78 | labels
79 | ) = self.prepare_inputs_labels_for_multimodal(
80 | input_ids,
81 | position_ids,
82 | attention_mask,
83 | past_key_values,
84 | labels,
85 | images
86 | )
87 |
88 | return super().forward(
89 | input_ids=input_ids,
90 | attention_mask=attention_mask,
91 | position_ids=position_ids,
92 | past_key_values=past_key_values,
93 | inputs_embeds=inputs_embeds,
94 | labels=labels,
95 | use_cache=use_cache,
96 | output_attentions=output_attentions,
97 | output_hidden_states=output_hidden_states,
98 | return_dict=return_dict
99 | )
100 |
101 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
102 | images = kwargs.pop("images", None)
103 | _inputs = super().prepare_inputs_for_generation(
104 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
105 | )
106 | if images is not None:
107 | _inputs['images'] = images
108 | return _inputs
109 |
110 | AutoConfig.register("llava", LlavaConfig)
111 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
112 |
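113 | # Minimal sketch of what the two register() calls above enable (the checkpoint id is a
114 | # placeholder): once "llava" is a known model_type, the generic Auto* classes resolve
115 | # LLaVA checkpoints to the classes defined in this file.
116 | #
117 | #   from transformers import AutoConfig, AutoModelForCausalLM
118 | #   import llava.model.language_model.llava_llama  # executes the register() calls
119 | #
120 | #   cfg = AutoConfig.from_pretrained("liuhaotian/llava-v1.5-7b")  # cfg.model_type == "llava"
121 | #   model = AutoModelForCausalLM.from_config(cfg)                 # -> LlavaLlamaForCausalLM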
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple
17 | import warnings
18 |
19 | import torch
20 | import torch.nn.functional as F
21 | import math
22 |
23 | from transformers import AutoConfig, AutoModelForCausalLM
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 |
26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaMPTConfig(MPTConfig):
31 | model_type = "llava_mpt"
32 |
33 |
34 | class LlavaMPTModel(LlavaMetaModel, MPTModel):
35 | config_class = LlavaMPTConfig
36 |
37 | def __init__(self, config: MPTConfig):
38 | config.hidden_size = config.d_model
39 | super(LlavaMPTModel, self).__init__(config)
40 |
41 | def embed_tokens(self, x):
42 | return self.wte(x)
43 |
44 |
45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
46 | config_class = LlavaMPTConfig
47 | supports_gradient_checkpointing = True
48 |
49 | def __init__(self, config):
50 | super(MPTForCausalLM, self).__init__(config)
51 |
52 | if not config.tie_word_embeddings:
53 | raise ValueError('MPTForCausalLM only supports tied word embeddings')
54 | self.transformer = LlavaMPTModel(config)
55 | self.logit_scale = None
56 | if config.logit_scale is not None:
57 | logit_scale = config.logit_scale
58 | if isinstance(logit_scale, str):
59 | if logit_scale == 'inv_sqrt_d_model':
60 | logit_scale = 1 / math.sqrt(config.d_model)
61 | else:
62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
63 | self.logit_scale = logit_scale
64 |
65 | def get_model(self):
66 | return self.transformer
67 |
68 | def _set_gradient_checkpointing(self, module, value=False):
69 | if isinstance(module, LlavaMPTModel):
70 | module.gradient_checkpointing = value
71 |
72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
73 | return_dict = return_dict if return_dict is not None else self.config.return_dict
74 | use_cache = use_cache if use_cache is not None else self.config.use_cache
75 |
76 | input_ids, _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images)
77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
80 | if self.logit_scale is not None:
81 | if self.logit_scale == 0:
82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
83 | logits *= self.logit_scale
84 | loss = None
85 | if labels is not None:
86 | labels = torch.roll(labels, shifts=-1)
87 | labels[:, -1] = -100
88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
90 |
91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
92 | if inputs_embeds is not None:
93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
94 | attention_mask = kwargs['attention_mask'].bool()
95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]:
96 | raise NotImplementedError('MPT does not support generation with right padding.')
97 | if self.transformer.attn_uses_sequence_id and self.training:
98 | sequence_id = torch.zeros_like(input_ids[:1])
99 | else:
100 | sequence_id = None
101 | if past_key_values is not None:
102 | input_ids = input_ids[:, -1].unsqueeze(-1)
103 | if self.transformer.prefix_lm:
104 | prefix_mask = torch.ones_like(attention_mask)
105 | if kwargs.get('use_cache') == False:
106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
107 | else:
108 | prefix_mask = None
109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
110 |
111 |
112 | AutoConfig.register("llava_mpt", LlavaMPTConfig)
113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
114 |
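115 |
116 | if __name__ == "__main__":
117 |     # Small self-contained illustration of the label shift used in forward() above:
118 |     # position i is trained to predict token i + 1, and the wrapped-around last position
119 |     # is set to -100 so F.cross_entropy ignores it (-100 is its default ignore_index).
120 |     demo_labels = torch.tensor([[11, 12, 13, 14]])
121 |     shifted = torch.roll(demo_labels, shifts=-1)
122 |     shifted[:, -1] = -100
123 |     print(shifted)  # tensor([[  12,   13,   14, -100]])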
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/adapt_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4 | NUM_SENTINEL_TOKENS: int = 100
5 |
6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7 | """Adds sentinel tokens and padding token (if missing).
8 |
9 | Expands the tokenizer vocabulary to include sentinel tokens
10 | used in mixture-of-denoiser tasks as well as a padding token.
11 |
12 | All added tokens are added as special tokens. No tokens are
13 | added if sentinel tokens and padding token already exist.
14 | """
15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17 | if tokenizer.pad_token is None:
18 | tokenizer.add_tokens('<pad>', special_tokens=True)
19 | tokenizer.pad_token = '<pad>'
20 | assert tokenizer.pad_token_id is not None
21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23 | tokenizer.sentinel_token_ids = _sentinel_token_ids
24 |
25 | class AutoTokenizerForMOD(AutoTokenizer):
26 | """AutoTokenizer + Adaptation for MOD.
27 |
28 | A simple wrapper around AutoTokenizer to make instantiating
29 | an MOD-adapted tokenizer a bit easier.
30 |
31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
32 | a padding token, and a property to get the token ids of the
33 | sentinel tokens.
34 | """
35 |
36 | @classmethod
37 | def from_pretrained(cls, *args, **kwargs):
38 | """See `AutoTokenizer.from_pretrained` docstring."""
39 | tokenizer = super().from_pretrained(*args, **kwargs)
40 | adapt_tokenizer_for_denoising(tokenizer)
41 | return tokenizer
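42 |
43 | # Minimal usage sketch (the repo id is a placeholder; MPT itself reuses the GPT-NeoX
44 | # tokenizer): AutoTokenizerForMOD adds the 100 <extra_id_*> sentinels and a <pad>
45 | # token on load.
46 | #
47 | #   tokenizer = AutoTokenizerForMOD.from_pretrained("EleutherAI/gpt-neox-20b")
48 | #   assert tokenizer.pad_token == '<pad>'
49 | #   assert len(tokenizer.sentinel_token_ids) == NUM_SENTINEL_TOKENS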
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 | from .attention import ATTN_CLASS_REGISTRY
6 | from .norm import NORM_CLASS_REGISTRY
7 |
8 | class MPTMLP(nn.Module):
9 |
10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11 | super().__init__()
12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13 | self.act = nn.GELU(approximate='none')
14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15 | self.down_proj._is_residual = True
16 |
17 | def forward(self, x):
18 | return self.down_proj(self.act(self.up_proj(x)))
19 |
20 | class MPTBlock(nn.Module):
21 |
22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23 | del kwargs
24 | super().__init__()
25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27 | self.norm_1 = norm_class(d_model, device=device)
28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29 | self.norm_2 = norm_class(d_model, device=device)
30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33 |
34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35 | a = self.norm_1(x)
36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37 | x = x + self.resid_attn_dropout(b)
38 | m = self.norm_2(x)
39 | n = self.ffn(m)
40 | x = x + self.resid_ffn_dropout(n)
41 | return (x, attn_weights, past_key_value)
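42 |
43 | # Sketch of how a block is wired (pre-norm residual): given hidden states x of shape
44 | # (batch, seq_len, d_model), the block computes
45 | #   x = x + dropout(attn(norm_1(x)))
46 | #   x = x + dropout(ffn(norm_2(x)))
47 | # For a quick CPU check one could construct it with the pure-torch attention path
48 | # (the default 'triton' impl requires the triton package); the values below are illustrative:
49 | #
50 | #   cfg = dict(attn_type='multihead_attention', attn_pdrop=0.0, attn_impl='torch',
51 | #              qk_ln=False, clip_qkv=None, softmax_scale=None, prefix_lm=False,
52 | #              attn_uses_sequence_id=False, alibi=False, alibi_bias_max=8)
53 | #   block = MPTBlock(d_model=64, n_heads=4, expansion_ratio=2, attn_config=cfg)
54 | #   y, _, _ = block(torch.randn(2, 8, 64))  # y.shape == (2, 8, 64)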
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/configuration_mpt.py:
--------------------------------------------------------------------------------
1 | """A HuggingFace-style model configuration."""
2 | from typing import Dict, Optional, Union
3 | from transformers import PretrainedConfig
4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
6 |
7 | class MPTConfig(PretrainedConfig):
8 | model_type = 'mpt'
9 |
10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
11 | """The MPT configuration class.
12 |
13 | Args:
14 | d_model (int): The size of the embedding dimension of the model.
15 | n_heads (int): The number of attention heads.
16 | n_layers (int): The number of layers in the model.
17 | expansion_ratio (int): The ratio of the up/down scale in the MLP.
18 | max_seq_len (int): The maximum sequence length of the model.
19 | vocab_size (int): The size of the vocabulary.
20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual.
21 | emb_pdrop (float): The dropout probability for the embedding layer.
22 | learned_pos_emb (bool): Whether to use learned positional embeddings
23 | attn_config (Dict): A dictionary used to configure the model's attention module:
24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention
25 | attn_pdrop (float): The dropout probability for the attention layers.
26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'.
27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer.
28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to
29 | this value.
30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None,
31 | use the default scale of ``1/sqrt(d_keys)``.
32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an
33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix
34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention.
35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id.
36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates
37 | which sub-sequence each token belongs to.
38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
39 | alibi (bool): Whether to use the alibi bias instead of position embeddings.
40 | alibi_bias_max (int): The maximum value of the alibi bias.
41 | init_device (str): The device to use for parameter initialization.
42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value.
43 | no_bias (bool): Whether to use bias in all layers.
44 | verbose (int): The verbosity level. 0 is silent.
45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by.
46 | norm_type (str): choose type of norm to use
47 | multiquery_attention (bool): Whether to use multiquery attention implementation.
48 | use_cache (bool): Whether or not the model should return the last key/values attentions
49 | init_config (Dict): A dictionary used to configure the model initialization:
50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_',
51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or
52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch.
53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True.
54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer.
55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution
56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``.
57 | init_std (float): The standard deviation of the normal distribution used to initialize the model,
58 | if using the baseline_ parameter initialization scheme.
59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes.
60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes.
61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes.
62 | ---
63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options
64 | """
65 | self.d_model = d_model
66 | self.n_heads = n_heads
67 | self.n_layers = n_layers
68 | self.expansion_ratio = expansion_ratio
69 | self.max_seq_len = max_seq_len
70 | self.vocab_size = vocab_size
71 | self.resid_pdrop = resid_pdrop
72 | self.emb_pdrop = emb_pdrop
73 | self.learned_pos_emb = learned_pos_emb
74 | self.attn_config = attn_config
75 | self.init_device = init_device
76 | self.logit_scale = logit_scale
77 | self.no_bias = no_bias
78 | self.verbose = verbose
79 | self.embedding_fraction = embedding_fraction
80 | self.norm_type = norm_type
81 | self.use_cache = use_cache
82 | self.init_config = init_config
83 | if 'name' in kwargs:
84 | del kwargs['name']
85 | if 'loss_fn' in kwargs:
86 | del kwargs['loss_fn']
87 | super().__init__(**kwargs)
88 | self._validate_config()
89 |
90 | def _set_config_defaults(self, config, config_defaults):
91 | for (k, v) in config_defaults.items():
92 | if k not in config:
93 | config[k] = v
94 | return config
95 |
96 | def _validate_config(self):
97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
99 | if self.d_model % self.n_heads != 0:
100 | raise ValueError('d_model must be divisible by n_heads')
101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
108 | raise NotImplementedError('alibi only implemented with torch and triton attention.')
109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
115 | if self.init_config.get('name', None) is None:
116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
117 | if not self.learned_pos_emb and (not self.attn_config['alibi']):
118 | raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
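119 |
120 | if __name__ == "__main__":
121 |     # Minimal sketch (illustrative values only): construct a toy config and let
122 |     # _validate_config() apply the nested attention/init defaults documented above.
123 |     cfg = MPTConfig(d_model=256, n_heads=8, n_layers=4, max_seq_len=512, vocab_size=32000)
124 |     print(cfg.attn_config['attn_impl'])  # 'triton' (the default)
125 |     print(cfg.init_config['name'])       # 'kaiming_normal_'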
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/custom_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 |
6 | class SharedEmbedding(nn.Embedding):
7 |
8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
9 | if unembed:
10 | return F.linear(input, self.weight)
11 | return super().forward(input)
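12 |
13 | if __name__ == "__main__":
14 |     # Minimal sketch of the weight-tying trick above: the same matrix embeds token ids
15 |     # and, with unembed=True, projects hidden states back to vocabulary logits.
16 |     emb = SharedEmbedding(num_embeddings=10, embedding_dim=4)
17 |     hidden = emb(torch.tensor([[1, 2, 3]]))  # (1, 3, 4) embeddings
18 |     logits = emb(hidden, unembed=True)       # (1, 3, 10) logits via F.linear
19 |     print(hidden.shape, logits.shape)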
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/meta_init_context.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import torch
3 | import torch.nn as nn
4 |
5 | @contextmanager
6 | def init_empty_weights(include_buffers: bool=False):
7 | """Meta initialization context manager.
8 |
9 | A context manager under which models are initialized with all parameters
10 | on the meta device, therefore creating an empty model. Useful when just
11 | initializing the model would blow the available RAM.
12 |
13 | Args:
14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
15 | not to also put all buffers on the meta device while initializing.
16 |
17 | Example:
18 | ```python
19 | import torch.nn as nn
20 |
21 | # Initialize a model with 100 billion parameters in no time and without using any RAM.
22 | with init_empty_weights():
23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
24 | ```
25 |
26 | <Tip warning={true}>
27 |
28 | Any model created under this context manager has no weights. As such you can't do something like
29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
30 |
31 | </Tip>
32 | """
33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34 | yield f
35 |
36 | @contextmanager
37 | def init_on_device(device: torch.device, include_buffers: bool=False):
38 | """Device initialization context manager.
39 |
40 | A context manager under which models are initialized with all parameters
41 | on the specified device.
42 |
43 | Args:
44 | device (`torch.device`): Device to initialize all parameters on.
45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
46 | not to also put all buffers on the meta device while initializing.
47 |
48 | Example:
49 | ```python
50 | import torch.nn as nn
51 |
52 | with init_on_device(device=torch.device("cuda")):
53 | tst = nn.Linear(100, 100) # on `cuda` device
54 | ```
55 | """
56 | old_register_parameter = nn.Module.register_parameter
57 | if include_buffers:
58 | old_register_buffer = nn.Module.register_buffer
59 |
60 | def register_empty_parameter(module, name, param):
61 | old_register_parameter(module, name, param)
62 | if param is not None:
63 | param_cls = type(module._parameters[name])
64 | kwargs = module._parameters[name].__dict__
65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
66 |
67 | def register_empty_buffer(module, name, buffer):
68 | old_register_buffer(module, name, buffer)
69 | if buffer is not None:
70 | module._buffers[name] = module._buffers[name].to(device)
71 | if include_buffers:
72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
73 | else:
74 | tensor_constructors_to_patch = {}
75 |
76 | def patch_tensor_constructor(fn):
77 |
78 | def wrapper(*args, **kwargs):
79 | kwargs['device'] = device
80 | return fn(*args, **kwargs)
81 | return wrapper
82 | try:
83 | nn.Module.register_parameter = register_empty_parameter
84 | if include_buffers:
85 | nn.Module.register_buffer = register_empty_buffer
86 | for torch_function_name in tensor_constructors_to_patch.keys():
87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
88 | yield
89 | finally:
90 | nn.Module.register_parameter = old_register_parameter
91 | if include_buffers:
92 | nn.Module.register_buffer = old_register_buffer
93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94 | setattr(torch, torch_function_name, old_torch_function)
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _cast_if_autocast_enabled(tensor):
4 | if torch.is_autocast_enabled():
5 | if tensor.device.type == 'cuda':
6 | dtype = torch.get_autocast_gpu_dtype()
7 | elif tensor.device.type == 'cpu':
8 | dtype = torch.get_autocast_cpu_dtype()
9 | else:
10 | raise NotImplementedError()
11 | return tensor.to(dtype=dtype)
12 | return tensor
13 |
14 | class LPLayerNorm(torch.nn.LayerNorm):
15 |
16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18 |
19 | def forward(self, x):
20 | module_device = x.device
21 | downcast_x = _cast_if_autocast_enabled(x)
22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24 | with torch.autocast(enabled=False, device_type=module_device.type):
25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26 |
27 | def rms_norm(x, weight=None, eps=1e-05):
28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29 | if weight is not None:
30 | return output * weight
31 | return output
32 |
33 | class RMSNorm(torch.nn.Module):
34 |
35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
36 | super().__init__()
37 | self.eps = eps
38 | if weight:
39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
40 | else:
41 | self.register_parameter('weight', None)
42 |
43 | def forward(self, x):
44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45 |
46 | class LPRMSNorm(RMSNorm):
47 |
48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
50 |
51 | def forward(self, x):
52 | downcast_x = _cast_if_autocast_enabled(x)
53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54 | with torch.autocast(enabled=False, device_type=x.device.type):
55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
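57 |
58 | if __name__ == "__main__":
59 |     # Minimal sketch: check that rms_norm() matches its definition
60 |     # x / sqrt(mean(x**2, dim=-1) + eps), scaled by an optional weight.
61 |     x = torch.randn(2, 3, 8)
62 |     norm = NORM_CLASS_REGISTRY['rmsnorm'](normalized_shape=8)
63 |     expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-05)
64 |     print(torch.allclose(norm(x), expected, atol=1e-6))  # True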
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/param_init_fns.py:
--------------------------------------------------------------------------------
1 | import math
2 | import warnings
3 | from collections.abc import Sequence
4 | from functools import partial
5 | from typing import Optional, Tuple, Union
6 | import torch
7 | from torch import nn
8 | from .norm import NORM_CLASS_REGISTRY
9 |
10 | def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
11 | del kwargs
12 | if verbose > 1:
13 | warnings.warn(f"Initializing network using module's reset_parameters attribute")
14 | if hasattr(module, 'reset_parameters'):
15 | module.reset_parameters()
16 |
17 | def fused_init_helper_(module: nn.Module, init_fn_):
18 | _fused = getattr(module, '_fused', None)
19 | if _fused is None:
20 | raise RuntimeError(f'Internal logic error')
21 | (dim, splits) = _fused
22 | splits = (0, *splits, module.weight.size(dim))
23 | for (s, e) in zip(splits[:-1], splits[1:]):
24 | slice_indices = [slice(None)] * module.weight.ndim
25 | slice_indices[dim] = slice(s, e)
26 | init_fn_(module.weight[slice_indices])
27 |
28 | def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
29 | del kwargs
30 | if verbose > 1:
31 | warnings.warn(f'If model has bias parameters they are initialized to 0.')
32 | init_div_is_residual = init_div_is_residual
33 | if init_div_is_residual is False:
34 | div_is_residual = 1.0
35 | elif init_div_is_residual is True:
36 | div_is_residual = math.sqrt(2 * n_layers)
37 | elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
38 | div_is_residual = init_div_is_residual
39 | elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
40 | div_is_residual = float(init_div_is_residual)
41 | else:
42 | div_is_residual = 1.0
43 | raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
44 | if init_div_is_residual is not False:
45 | if verbose > 1:
46 | warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
47 | if isinstance(module, nn.Linear):
48 | if hasattr(module, '_fused'):
49 | fused_init_helper_(module, init_fn_)
50 | else:
51 | init_fn_(module.weight)
52 | if module.bias is not None:
53 | torch.nn.init.zeros_(module.bias)
54 | if init_div_is_residual is not False and getattr(module, '_is_residual', False):
55 | with torch.no_grad():
56 | module.weight.div_(div_is_residual)
57 | elif isinstance(module, nn.Embedding):
58 | if emb_init_std is not None:
59 | std = emb_init_std
60 | if std == 0:
61 | warnings.warn(f'Embedding layer initialized to 0.')
62 | emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
63 | if verbose > 1:
64 | warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
65 | elif emb_init_uniform_lim is not None:
66 | lim = emb_init_uniform_lim
67 | if isinstance(lim, Sequence):
68 | if len(lim) > 2:
69 | raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
70 | if lim[0] == lim[1]:
71 | warnings.warn(f'Embedding layer initialized to {lim[0]}.')
72 | else:
73 | if lim == 0:
74 | warnings.warn(f'Embedding layer initialized to 0.')
75 | lim = [-lim, lim]
76 | (a, b) = lim
77 | emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
78 | if verbose > 1:
79 | warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
80 | else:
81 | emb_init_fn_ = init_fn_
82 | emb_init_fn_(module.weight)
83 | elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
84 | if verbose > 1:
85 | warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
86 | if hasattr(module, 'weight') and module.weight is not None:
87 | torch.nn.init.ones_(module.weight)
88 | if hasattr(module, 'bias') and module.bias is not None:
89 | torch.nn.init.zeros_(module.bias)
90 | elif isinstance(module, nn.MultiheadAttention):
91 | if module._qkv_same_embed_dim:
92 | assert module.in_proj_weight is not None
93 | assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
94 | assert d_model is not None
95 | _d = d_model
96 | splits = (0, _d, 2 * _d, 3 * _d)
97 | for (s, e) in zip(splits[:-1], splits[1:]):
98 | init_fn_(module.in_proj_weight[s:e])
99 | else:
100 | assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
101 | assert module.in_proj_weight is None
102 | init_fn_(module.q_proj_weight)
103 | init_fn_(module.k_proj_weight)
104 | init_fn_(module.v_proj_weight)
105 | if module.in_proj_bias is not None:
106 | torch.nn.init.zeros_(module.in_proj_bias)
107 | if module.bias_k is not None:
108 | torch.nn.init.zeros_(module.bias_k)
109 | if module.bias_v is not None:
110 | torch.nn.init.zeros_(module.bias_v)
111 | init_fn_(module.out_proj.weight)
112 | if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
113 | with torch.no_grad():
114 | module.out_proj.weight.div_(div_is_residual)
115 | if module.out_proj.bias is not None:
116 | torch.nn.init.zeros_(module.out_proj.bias)
117 | else:
118 | for _ in module.parameters(recurse=False):
119 | raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
120 |
121 | def _normal_init_(std, mean=0.0):
122 | return partial(torch.nn.init.normal_, mean=mean, std=std)
123 |
124 | def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
125 | del kwargs
126 | init_fn_ = _normal_init_(std=std)
127 | if verbose > 1:
128 | warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
129 | generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
130 |
131 | def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
132 | del kwargs
133 | if init_std is None:
134 | raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
135 | _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
136 |
137 | def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
138 | del kwargs
139 | std = math.sqrt(2 / (5 * d_model))
140 | _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
141 |
142 | def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
143 | """From section 2.3.1 of GPT-NeoX-20B:
144 |
145 | An Open-Source Autoregressive Language Model, Black et al. (2022)
146 | see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
147 | and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
148 | """
149 | del kwargs
150 | residual_div = n_layers / math.sqrt(10)
151 | if verbose > 1:
152 | warnings.warn(f'setting init_div_is_residual to {residual_div}')
153 | small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
154 |
155 | def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
156 | del kwargs
157 | if verbose > 1:
158 | warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
159 | kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
160 | generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
161 |
162 | def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
163 | del kwargs
164 | if verbose > 1:
165 | warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
166 | kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
167 | generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
168 |
169 | def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
170 | del kwargs
171 | xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
172 | if verbose > 1:
173 | warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
174 | generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
175 |
176 | def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
177 | xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
178 | if verbose > 1:
179 | warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
180 | generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
181 | MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
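182 |
183 | if __name__ == "__main__":
184 |     # Minimal sketch: apply the 'small_init_' scheme from the registry above to a single
185 |     # linear layer; weights are drawn from N(0, sqrt(2 / (5 * d_model))) and biases zeroed.
186 |     layer = nn.Linear(128, 128)
187 |     MODEL_INIT_REGISTRY['small_init_'](module=layer, n_layers=4, d_model=128)
188 |     print(layer.weight.std().item())      # roughly sqrt(2 / (5 * 128)) ~= 0.056
189 |     print(layer.bias.abs().sum().item())  # 0.0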
--------------------------------------------------------------------------------
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10 |
11 | raise ValueError(f'Unknown vision tower: {vision_tower}')
12 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16 |
17 | if not delay_load:
18 | self.load_model()
19 | else:
20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
21 |
22 | def load_model(self):
23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
25 | self.vision_tower.requires_grad_(False)
26 |
27 | self.is_loaded = True
28 |
29 | def feature_select(self, image_forward_outs):
30 | image_features = image_forward_outs.hidden_states[self.select_layer]
31 | if self.select_feature == 'patch':
32 | image_features = image_features[:, 1:]
33 | elif self.select_feature == 'cls_patch':
34 | image_features = image_features
35 | else:
36 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
37 | return image_features
38 |
39 | @torch.no_grad()
40 | def forward(self, images):
41 | if type(images) is list:
42 | image_features = []
43 | for image in images:
44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
45 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
46 | image_features.append(image_feature)
47 | else:
48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
49 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
50 |
51 | return image_features
52 |
53 | @property
54 | def dummy_feature(self):
55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
56 |
57 | @property
58 | def dtype(self):
59 | return self.vision_tower.dtype
60 |
61 | @property
62 | def device(self):
63 | return self.vision_tower.device
64 |
65 | @property
66 | def config(self):
67 | if self.is_loaded:
68 | return self.vision_tower.config
69 | else:
70 | return self.cfg_only
71 |
72 | @property
73 | def hidden_size(self):
74 | return self.config.hidden_size
75 |
76 | @property
77 | def num_patches(self):
78 | return (self.config.image_size // self.config.patch_size) ** 2
79 |
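80 | # Sizing sketch: with the CLIP tower LLaVA typically uses (ViT-L/14 at 336px,
81 | # e.g. "openai/clip-vit-large-patch14-336"), num_patches is (336 // 14) ** 2 = 576,
82 | # so with select_feature='patch' the forward pass returns 576 feature vectors of
83 | # hidden_size 1024 per image:
84 | #
85 | #   tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args)  # args needs mm_vision_select_layer
86 | #   feats = tower(images)  # images: (B, 3, 336, 336) -> feats: (B, 576, 1024)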
--------------------------------------------------------------------------------
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
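53 |
54 | if __name__ == "__main__":
55 |     # Minimal sketch (illustrative sizes): the 'mlp2x_gelu' projector type used by
56 |     # LLaVA-1.5 expands to Linear -> GELU -> Linear, mapping vision features of
57 |     # mm_hidden_size into the LLM's hidden_size.
58 |     from types import SimpleNamespace
59 |     cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1024, hidden_size=4096)
60 |     projector = build_vision_projector(cfg)
61 |     print(projector)  # Sequential of Linear(1024, 4096) -> GELU -> Linear(4096, 4096)
62 |     print(projector(torch.randn(2, 576, 1024)).shape)  # torch.Size([2, 576, 4096])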
--------------------------------------------------------------------------------
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel, because flash attention
97 | # requires the attention mask to be the same as the key_padding_mask.
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 |             "Flash attention is only supported on A100 or H100 GPUs during training due to the head-dim > 64 backward pass. "
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 |
--------------------------------------------------------------------------------
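A toy sketch (not part of the repository) of the cumulative sequence-length tensor built in the unpadded branch of the patched forward above: with bsz=2 sequences of q_len=4 tokens each, cu_q_lens marks where each sequence starts in the flattened (bsz * q_len) token dimension consumed by flash_attn_unpadded_qkvpacked_func.

import torch

bsz, q_len = 2, 4
cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
print(cu_q_lens)  # tensor([0, 4, 8], dtype=torch.int32)

--------------------------------------------------------------------------------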
/llava/train/llama_xformers_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | Code copied directly from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py, with some adjustments.
3 | """
4 |
5 | import logging
6 | import math
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | import transformers.models.llama.modeling_llama
11 | from torch import nn
12 |
13 | try:
14 | import xformers.ops
15 | except ImportError:
16 | logging.error("xformers not found! Please install it before trying to use it.")
17 |
18 |
19 | def replace_llama_attn_with_xformers_attn():
20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21 |
22 |
23 | def xformers_forward(
24 | self,
25 | hidden_states: torch.Tensor,
26 | attention_mask: Optional[torch.Tensor] = None,
27 | position_ids: Optional[torch.LongTensor] = None,
28 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
29 | output_attentions: bool = False,
30 | use_cache: bool = False,
31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32 | # pylint: disable=duplicate-code
33 | bsz, q_len, _ = hidden_states.size()
34 |
35 | query_states = (
36 | self.q_proj(hidden_states)
37 | .view(bsz, q_len, self.num_heads, self.head_dim)
38 | .transpose(1, 2)
39 | )
40 | key_states = (
41 | self.k_proj(hidden_states)
42 | .view(bsz, q_len, self.num_heads, self.head_dim)
43 | .transpose(1, 2)
44 | )
45 | value_states = (
46 | self.v_proj(hidden_states)
47 | .view(bsz, q_len, self.num_heads, self.head_dim)
48 | .transpose(1, 2)
49 | )
50 |
51 | kv_seq_len = key_states.shape[-2]
52 | if past_key_value is not None:
53 | kv_seq_len += past_key_value[0].shape[-2]
54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55 | (
56 | query_states,
57 | key_states,
58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59 | query_states, key_states, cos, sin, position_ids
60 | )
61 | # [bsz, nh, t, hd]
62 |
63 | if past_key_value is not None:
64 | # reuse k, v, self_attention
65 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
66 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
67 |
68 | past_key_value = (key_states, value_states) if use_cache else None
69 |
70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix
71 | if not output_attentions:
72 | query_states = query_states.transpose(1, 2)
73 | key_states = key_states.transpose(1, 2)
74 | value_states = value_states.transpose(1, 2)
75 |
76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
80 | attn_output = xformers.ops.memory_efficient_attention(
81 | query_states, key_states, value_states, attn_bias=None
82 | )
83 | else:
84 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
85 | attn_output = xformers.ops.memory_efficient_attention(
86 | query_states,
87 | key_states,
88 | value_states,
89 | attn_bias=xformers.ops.LowerTriangularMask(),
90 | )
91 | attn_weights = None
92 | else:
93 | attn_weights = torch.matmul(
94 | query_states, key_states.transpose(2, 3)
95 | ) / math.sqrt(self.head_dim)
96 |
97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98 | raise ValueError(
99 |                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
100 | f" {attn_weights.size()}"
101 | )
102 |
103 | if attention_mask is not None:
104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105 | raise ValueError(
106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107 | )
108 | attn_weights = attn_weights + attention_mask
109 | attn_weights = torch.max(
110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111 | )
112 |
113 | # upcast attention to fp32
114 | attn_weights = nn.functional.softmax(
115 | attn_weights, dim=-1, dtype=torch.float32
116 | ).to(query_states.dtype)
117 | attn_output = torch.matmul(attn_weights, value_states)
118 |
119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120 | raise ValueError(
121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122 | f" {attn_output.size()}"
123 | )
124 |
125 | attn_output = attn_output.transpose(1, 2)
126 |
127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128 | attn_output = self.o_proj(attn_output)
129 | return attn_output, attn_weights, past_key_value
130 |
--------------------------------------------------------------------------------
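A hedged shape sketch for the xformers call used in the patch above: inputs are laid out as (batch, seq_len, num_heads, head_dim), and LowerTriangularMask() requests causal attention. The dimensions below are arbitrary toy values and assume an fp16-capable CUDA device.

import torch
import xformers.ops

q = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
k = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
v = torch.randn(1, 8, 4, 64, dtype=torch.float16, device="cuda")
out = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=xformers.ops.LowerTriangularMask()
)
assert out.shape == q.shape  # (1, 8, 4, 64)

--------------------------------------------------------------------------------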
/llava/train/llava_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from torch.utils.data import Sampler
5 |
6 | from transformers import Trainer
7 | from transformers.trainer import (
8 | is_sagemaker_mp_enabled,
9 | get_parameter_names,
10 | has_length,
11 | ALL_LAYERNORM_LAYERS,
12 | ShardedDDPOption,
13 | logger,
14 | )
15 | from typing import List, Optional
16 |
17 |
18 | def maybe_zero_3(param, ignore_status=False, name=None):
19 | from deepspeed import zero
20 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
21 | if hasattr(param, "ds_id"):
22 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
23 | if not ignore_status:
24 | print(name, 'no ignore status')
25 | with zero.GatheredParameters([param]):
26 | param = param.data.detach().cpu().clone()
27 | else:
28 | param = param.detach().cpu().clone()
29 | return param
30 |
31 |
32 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
33 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
34 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
35 | return to_return
36 |
37 |
38 | def split_to_even_chunks(indices, lengths, num_chunks):
39 | """
40 |     Split a list of indices into `num_chunks` chunks of roughly equal total length.
41 | """
42 |
43 | if len(indices) % num_chunks != 0:
44 | return [indices[i::num_chunks] for i in range(num_chunks)]
45 |
46 | num_indices_per_chunk = len(indices) // num_chunks
47 |
48 | chunks = [[] for _ in range(num_chunks)]
49 | chunks_lengths = [0 for _ in range(num_chunks)]
50 | for index in indices:
51 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
52 | chunks[shortest_chunk].append(index)
53 | chunks_lengths[shortest_chunk] += lengths[index]
54 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
55 | chunks_lengths[shortest_chunk] = float("inf")
56 |
57 | return chunks
58 |
59 |
60 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
61 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
62 | assert all(l != 0 for l in lengths), "Should not have zero length."
63 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
64 | # all samples are in the same modality
65 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
66 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
67 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
68 |
69 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
70 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
71 | megabatch_size = world_size * batch_size
72 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
73 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
74 |
75 | last_mm = mm_megabatches[-1]
76 | last_lang = lang_megabatches[-1]
77 | additional_batch = last_mm + last_lang
78 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
79 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
80 | megabatches = [megabatches[i] for i in megabatch_indices]
81 |
82 | if len(additional_batch) > 0:
83 | megabatches.append(sorted(additional_batch))
84 |
85 | return [i for megabatch in megabatches for i in megabatch]
86 |
87 |
88 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
89 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
90 | indices = torch.randperm(len(lengths), generator=generator)
91 | megabatch_size = world_size * batch_size
92 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
93 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
94 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
95 |
96 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
97 |
98 |
99 | class LengthGroupedSampler(Sampler):
100 | r"""
101 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
102 | keeping a bit of randomness.
103 | """
104 |
105 | def __init__(
106 | self,
107 | batch_size: int,
108 | world_size: int,
109 | lengths: Optional[List[int]] = None,
110 | generator=None,
111 | group_by_modality: bool = False,
112 | ):
113 | if lengths is None:
114 | raise ValueError("Lengths must be provided.")
115 |
116 | self.batch_size = batch_size
117 | self.world_size = world_size
118 | self.lengths = lengths
119 | self.generator = generator
120 | self.group_by_modality = group_by_modality
121 |
122 | def __len__(self):
123 | return len(self.lengths)
124 |
125 | def __iter__(self):
126 | if self.group_by_modality:
127 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
128 | else:
129 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
130 | return iter(indices)
131 |
132 |
133 | class LLaVATrainer(Trainer):
134 |
135 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
136 | if self.train_dataset is None or not has_length(self.train_dataset):
137 | return None
138 |
139 | if self.args.group_by_modality_length:
140 | lengths = self.train_dataset.modality_lengths
141 | return LengthGroupedSampler(
142 | self.args.train_batch_size,
143 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
144 | lengths=lengths,
145 | group_by_modality=True,
146 | )
147 | else:
148 | return super()._get_train_sampler()
149 |
150 | def create_optimizer(self):
151 | """
152 | Setup the optimizer.
153 |
154 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
155 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
156 | """
157 | if is_sagemaker_mp_enabled():
158 | return super().create_optimizer()
159 | if self.sharded_ddp == ShardedDDPOption.SIMPLE:
160 | return super().create_optimizer()
161 |
162 | opt_model = self.model
163 |
164 | if self.optimizer is None:
165 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
166 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
167 | if self.args.mm_projector_lr is not None:
168 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
169 | optimizer_grouped_parameters = [
170 | {
171 | "params": [
172 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
173 | ],
174 | "weight_decay": self.args.weight_decay,
175 | },
176 | {
177 | "params": [
178 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
179 | ],
180 | "weight_decay": 0.0,
181 | },
182 | {
183 | "params": [
184 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
185 | ],
186 | "weight_decay": self.args.weight_decay,
187 | "lr": self.args.mm_projector_lr,
188 | },
189 | {
190 | "params": [
191 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
192 | ],
193 | "weight_decay": 0.0,
194 | "lr": self.args.mm_projector_lr,
195 | },
196 | ]
197 | else:
198 | optimizer_grouped_parameters = [
199 | {
200 | "params": [
201 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
202 | ],
203 | "weight_decay": self.args.weight_decay,
204 | },
205 | {
206 | "params": [
207 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
208 | ],
209 | "weight_decay": 0.0,
210 | },
211 | ]
212 |
213 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
214 |
215 | if self.sharded_ddp == ShardedDDPOption.SIMPLE:
216 | self.optimizer = OSS(
217 | params=optimizer_grouped_parameters,
218 | optim=optimizer_cls,
219 | **optimizer_kwargs,
220 | )
221 | else:
222 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
223 | if optimizer_cls.__name__ == "Adam8bit":
224 | import bitsandbytes
225 |
226 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
227 |
228 | skipped = 0
229 | for module in opt_model.modules():
230 |                         if isinstance(module, torch.nn.Embedding):  # use torch.nn directly; `nn` is not imported in this file
231 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
232 | logger.info(f"skipped {module}: {skipped/2**20}M params")
233 | manager.register_module_override(module, "weight", {"optim_bits": 32})
234 | logger.debug(f"bitsandbytes: will optimize {module} in fp32")
235 | logger.info(f"skipped: {skipped/2**20}M params")
236 |
237 | return self.optimizer
238 |
239 | def _save_checkpoint(self, model, trial, metrics=None):
240 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
241 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
242 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
243 |
244 | run_dir = self._get_output_dir(trial=trial)
245 | output_dir = os.path.join(run_dir, checkpoint_folder)
246 |
247 | # Only save Adapter
248 | keys_to_match = ['mm_projector', 'vision_resampler']
249 | if getattr(self.args, "use_im_start_end", False):
250 | keys_to_match.extend(['embed_tokens', 'embed_in'])
251 |
252 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
253 |
254 | if self.args.local_rank == 0 or self.args.local_rank == -1:
255 | self.model.config.save_pretrained(output_dir)
256 |                 torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
257 | else:
258 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
259 |
260 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
261 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
262 | pass
263 | else:
264 | super(LLaVATrainer, self)._save(output_dir, state_dict)
265 |
--------------------------------------------------------------------------------
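A toy illustration (values made up) of the greedy balancing done by split_to_even_chunks above: each index is assigned to the currently shortest chunk, so the per-chunk total token length stays roughly even across world_size workers.

from llava.train.llava_trainer import split_to_even_chunks

indices = [0, 1, 2, 3, 4, 5]
lengths = [10, 1, 1, 10, 1, 1]   # per-sample token lengths (made up)
print(split_to_even_chunks(indices, lengths, num_chunks=2))
# -> [[0, 4, 5], [1, 2, 3]]  (both chunks sum to a total length of 12)

--------------------------------------------------------------------------------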
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2 |
3 | # Need to call this before importing transformers.
4 | from llava.train.llama_xformers_attn_monkey_patch import (
5 | replace_llama_attn_with_xformers_attn,
6 | )
7 |
8 | replace_llama_attn_with_xformers_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 |     Check whether the text violates the OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 |     # Let requests serialize the body so quotes and backslashes in the text are
111 |     # escaped properly, instead of hand-assembling a JSON string.
112 |     try:
113 |         ret = requests.post(url, headers=headers, json={"input": text}, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
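A brief, hedged usage sketch for disable_torch_init above: it is meant to be called once, before constructing a model whose weights will immediately be overwritten by a checkpoint load, so torch skips its default Linear/LayerNorm initialization.

from llava.utils import disable_torch_init

disable_torch_init()
# ... build / load the pretrained model afterwards; its weights come from the checkpoint ...

--------------------------------------------------------------------------------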
/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HZQ950419/Math-LLaVA/3480ba7d2305de2c6f76941b091c2925641fc751/pipeline.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava"
7 | version = "1.1.3"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.0.1", "torchvision==0.15.2",
17 | "transformers==4.31.0", "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid",
18 | "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
19 | "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20 | "gradio==3.35.2", "gradio_client==0.2.9",
21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi",
22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
23 | ]
24 |
25 | [project.optional-dependencies]
26 | train = ["deepspeed==0.9.5", "ninja", "wandb"]
27 |
28 | [project.urls]
29 | "Homepage" = "https://llava-vl.github.io"
30 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
31 |
32 | [tool.setuptools.packages.find]
33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
34 |
35 | [tool.wheel]
36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
37 |
--------------------------------------------------------------------------------
/scripts/extract_mm_projector.py:
--------------------------------------------------------------------------------
1 | """
2 | This is just a utility that I use to extract the projector for quantized models.
3 | It is NOT necessary for training, running inference, or serving demos.
4 | Use this script ONLY if you fully understand its implications.
5 | """
6 |
7 |
8 | import os
9 | import argparse
10 | import torch
11 | import json
12 | from collections import defaultdict
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights')
17 | parser.add_argument('--model-path', type=str, help='model folder')
18 | parser.add_argument('--output', type=str, help='output file')
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | if __name__ == '__main__':
24 | args = parse_args()
25 |
26 | keys_to_match = ['mm_projector']
27 | ckpt_to_key = defaultdict(list)
28 | try:
29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
30 | for k, v in model_indices['weight_map'].items():
31 | if any(key_match in k for key_match in keys_to_match):
32 | ckpt_to_key[v].append(k)
33 | except FileNotFoundError:
34 | # Smaller models or model checkpoints saved by DeepSpeed.
35 | v = 'pytorch_model.bin'
36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
37 | if any(key_match in k for key_match in keys_to_match):
38 | ckpt_to_key[v].append(k)
39 |
40 | loaded_weights = {}
41 |
42 | for ckpt_name, weight_keys in ckpt_to_key.items():
43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
44 | for k in weight_keys:
45 | loaded_weights[k] = ckpt[k]
46 |
47 | torch.save(loaded_weights, args.output)
48 |
--------------------------------------------------------------------------------
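A hedged sketch of inspecting the projector weights written by the script above; "mm_projector.bin" stands in for whatever path was passed via --output.

import torch

weights = torch.load("mm_projector.bin", map_location="cpu")  # placeholder --output path
for name, tensor in weights.items():
    print(name, tuple(tensor.shape))

--------------------------------------------------------------------------------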
/scripts/merge_lora_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from llava.model.builder import load_pretrained_model
3 | from llava.mm_utils import get_model_name_from_path
4 |
5 |
6 | def merge_lora(args):
7 | model_name = get_model_name_from_path(args.model_path)
8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
9 |
10 | model.save_pretrained(args.save_model_path)
11 | tokenizer.save_pretrained(args.save_model_path)
12 |
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model-path", type=str, required=True)
17 | parser.add_argument("--model-base", type=str, required=True)
18 | parser.add_argument("--save-model-path", type=str, required=True)
19 |
20 | args = parser.parse_args()
21 |
22 | merge_lora(args)
23 |
--------------------------------------------------------------------------------
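A hedged sketch of calling merge_lora programmatically instead of through the CLI; it assumes the repository root is on sys.path so that scripts/ is importable, and all three paths are placeholders.

from types import SimpleNamespace

from scripts.merge_lora_weights import merge_lora  # assumes repo root on sys.path

args = SimpleNamespace(
    model_path="./checkpoints/llava-lora",         # LoRA checkpoint (placeholder)
    model_base="./checkpoints/llava-v1.5-base",    # base model (placeholder)
    save_model_path="./checkpoints/llava-merged",  # output directory (placeholder)
)
merge_lora(args)

--------------------------------------------------------------------------------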
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
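A hedged sketch of how the DeepSpeed JSON configs above are typically consumed: the "auto" fields are resolved by the HuggingFace Trainer when the config path is passed through TrainingArguments. The argument values below are placeholders.

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints/llava-finetune",   # placeholder
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,
    bf16=True,
    deepspeed="./scripts/zero3.json",            # or zero2.json / zero3_offload.json
)

--------------------------------------------------------------------------------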
/train_samples_all_tuning.json.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HZQ950419/Math-LLaVA/3480ba7d2305de2c6f76941b091c2925641fc751/train_samples_all_tuning.json.zip
--------------------------------------------------------------------------------