├── .gitignore ├── Blank.jpg ├── LICENSE.md ├── README.md ├── data_samples ├── image_caption_pairs.json └── images │ ├── 0ae1b33f14e0aff353de_300.jpg │ ├── 214320-6696759ae28745f696a6830cb67c954d.jpg │ ├── 34393_1719273861489_1227549162_31979863_8203919_n-410x307.jpg │ ├── 4484931-a3606015e6614481afc6a433fe49387d.jpg │ ├── 51178860_creamy-slaw_1x1.jpg │ ├── 5241179-01bbc4f291a3480c8e04c0616a000011.jpg │ ├── 942578-80eff85811684394bc64592dee071d56.jpg │ ├── 970916-518557352d32499fb9d66ba00cc35963.jpg │ ├── Chocolate-Covered-Strawberry-Cheesecake_WD0156-S-29da911d0e1a414b9e1e7b2c800f6bde.jpg │ ├── Pumpkin-Waffles-with-Bacon-Maple-Butter-by-Gerry-Speirs-410x273.jpeg │ ├── Serrano-Spinach-Stuffed-Tomatoes-410x316.jpg │ ├── Small-Plates-No-Boil-Baked-Pasta-inset-in-pot-11012019.jpg │ ├── __opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__20120513-206173-spicy-brown-mustard-step-2-37f1f2516f214b6dbae208bbeef10380.jpg │ ├── classic-beef-stew-dutch-oven-recipe-FT-BLOG1019-c50074af5c1e4c3e865bf0a4d2c36266.jpg │ ├── garlicky-green-beans1-410x329.jpg │ ├── picCYxfUw.jpg │ ├── picIfFx2r.jpg │ ├── picJbKQVD.jpg │ ├── picKiyZk5.jpg │ ├── picVapP1X.jpg │ ├── picbgkmgN.jpg │ ├── pumpkinspice-cake-FG2-410x273.jpg │ ├── tWEWzgEySPijBjppP0bw-IMG_20140420_121943178.jpg │ └── tuna9-410x274.jpg ├── docs ├── Evaluation.md ├── Post_Train.md └── Synthesis.md ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ └── llava_mpt.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ └── train_xformers.py └── utils.py ├── process ├── read_compre_pt.py └── syn_utils.py ├── pyproject.toml ├── scripts ├── dataset_info.json ├── post_train_mllm.sh ├── tune_synthesizer.sh ├── zero2.json ├── zero3.json └── zero3_offload.json └── vllm_inference ├── eval_predictions.py ├── format_post_train_data.py ├── inference.py ├── merge_predictions.py ├── run_inference.sh ├── run_synthesis.sh └── utils ├── cache_util.py ├── consistency_filter_prompt.txt ├── conversation.py ├── food101_name_to_label_map.json ├── foodSeg103_id2label.json ├── llava_med ├── evaluate_metrics.py ├── glossary.py └── utils.py ├── metric.py ├── nutrition5k_ingredients.py └── task.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | 11 | # Data 12 | data 13 | 14 | # Editor 15 | .idea 16 | *.swp 17 | 18 | # Other 19 | .DS_Store 20 | wandb 21 | output 22 | eval_results 23 | *dummy* -------------------------------------------------------------------------------- /Blank.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/Blank.jpg 
-------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | LICENSE AGREEMENT 2 | Last revision: Sep, 2023 3 | You are granted the right to use the code and/or Database under the following terms, as enlisted in this document (“Beijing Institute for General Artificial Intelligence BIGAI License Agreement”): 4 | · The code and/or data is strictly for non-commercial academic research only. 5 | · Any commercial use of the code or data requires prior contact and negotiation of licensing fees with the original authors or Beijing Institute for General Artificial Intelligence (BIGAI). 6 | · Any new access to the code and/or data shall be established through this form or the official method of distributing the code and/or data. The code and/or data may not be redistributed, in whole or part, or in any format without written prior permission. A reference to the code and/or data or this License Agreement must be made if you publish information. 7 | · The code and/or data is provided as is. No liability or responsibility assumed for the authors. 8 | · The right to revise this License Agreement, in whole or part, at any time without prior notice is reserved by the authors. 9 | · You warrant that you have the authorization to enter into this License Agreement. 10 | · You comply with the terms enforced by the corporates whose products were used in collecting the code and/or data. The terms unanimously enforce, including but not limited to, restricting the use of the code and/or data to non-commercial academic research. 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adapting Multimodal Large Language Models to Domains via Post-Training 2 | 3 | This repository provides models, code and data of our paper: [On Domain-Specific Post-Training for Multimodal Large Language Models](https://arxiv.org/abs/2411.19930). 4 | 5 | We investigate domain adaptation of MLLMs through post-training, focusing on data synthesis, training pipelines, and task evaluation. 6 | **(1) Data Synthesis**: Using open-source models, we develop a visual instruction synthesizer that effectively generates diverse visual instruction tasks from domain-specific image-caption pairs. **Our synthetic tasks surpass those generated by manual rules, GPT-4, and GPT-4V in enhancing the domain-specific performance of MLLMs.** 7 | **(2) Training Pipeline**: While the two-stage training--initially on image-caption pairs followed by visual instruction tasks--is commonly adopted for developing general MLLMs, we apply a single-stage training pipeline to enhance task diversity for domain-specific post-training. 8 | **(3) Task Evaluation**: We conduct experiments in two domains, biomedicine and food, by post-training MLLMs of different sources and scales (e.g., Qwen2-VL-2B, LLaVA-v1.6-8B, Llama-3.2-11B), and then evaluating MLLM performance on various domain-specific tasks. 
9 | 10 | 11 | *********************** *Updates* ************************* 12 | - [2024/1/4] Updated **ALL** models, code and data to reproduce our results 13 | - [2024/1/3] Released the [post-training guide](docs/Post_Train.md) 14 | - [2024/12/16] Released the [data synthesis guide](docs/Synthesis.md) 15 | - [2024/12/13] Released the [evaluation guide](docs/Evaluation.md) 16 | - [2024/11/29] Released our paper 17 | 18 | ## Resources 19 | #### Domain-Specific Training Data 20 | - [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) 21 | - [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) 22 | 23 | 24 | #### Domain-Specific Evaluation Benchmark 25 | 26 | - [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) 27 | - [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) 28 | 29 | #### Domain-Specific Models 30 | | Model | Repo ID in HF 🤗 | Domain | Base Model | Training Data | Evaluation Benchmark | 31 | |:----------------------------------------------------------------------------|:--------------------------------------------|:--------------|:-------------------------|:------------------------------------------------------------------------------------------------|-----------------------| 32 | | [AdaMLLM-med-2B](https://huggingface.co/AdaptLLM/biomed-Qwen2-VL-2B-Instruct) | AdaptLLM/biomed-Qwen2-VL-2B-Instruct | Biomedicine | Qwen2-VL-2B-Instruct | [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) | [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) | 33 | | [AdaMLLM-food-2B](https://huggingface.co/AdaptLLM/food-Qwen2-VL-2B-Instruct) | AdaptLLM/food-Qwen2-VL-2B-Instruct | Food | Qwen2-VL-2B-Instruct | [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) | [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) | 34 | | [AdaMLLM-med-8B](https://huggingface.co/AdaptLLM/biomed-LLaVA-NeXT-Llama3-8B) | AdaptLLM/biomed-LLaVA-NeXT-Llama3-8B | Biomedicine | open-llava-next-llama3-8b | [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) | [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) | 35 | | [AdaMLLM-food-8B](https://huggingface.co/AdaptLLM/food-LLaVA-NeXT-Llama3-8B) |AdaptLLM/food-LLaVA-NeXT-Llama3-8B | Food | open-llava-next-llama3-8b | [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) | [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) | 36 | | [AdaMLLM-med-11B](https://huggingface.co/AdaptLLM/biomed-Llama-3.2-11B-Vision-Instruct) | AdaptLLM/biomed-Llama-3.2-11B-Vision-Instruct | Biomedicine | Llama-3.2-11B-Vision-Instruct | [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) | [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) | 37 | | [AdaMLLM-food-11B](https://huggingface.co/AdaptLLM/food-Llama-3.2-11B-Vision-Instruct) | AdaptLLM/food-Llama-3.2-11B-Vision-Instruct | Food | Llama-3.2-11B-Vision-Instruct | [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) | [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) | 38 | 39 | 40 | ## Setup 41 | We create two separate conda environments. 
42 | 43 | #### Env 1: To fine-tune the visual instruction synthesizer and post-train LLaVA 44 | 45 | 1. Clone this repo: 46 | ```bash 47 | git clone https://github.com/bigai-ai/QA-Synthesizer.git 48 | cd QA-Synthesizer 49 | ``` 50 | 51 | 2. Install the package: 52 | ```bash 53 | conda create -n adamllm python=3.10 -y 54 | conda activate adamllm 55 | pip install --upgrade pip 56 | pip install -e . 57 | ``` 58 | 59 | 3. Install additional packages for training: 60 | ```bash 61 | pip install -e ".[train]" 62 | pip install flash-attn --no-build-isolation 63 | conda deactivate 64 | ``` 65 | 66 | #### Env 2: To synthesize visual instruction tasks and evaluate models on domain-specific tasks 67 | Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source). 68 | 69 | As recommended in the official vLLM documentation, install vLLM in a **fresh new** conda environment: 70 | 71 | ```bash 72 | conda create -n vllm python=3.10 -y 73 | conda activate vllm 74 | pip install vllm # Ensure vllm>=0.6.2 for compatibility with llama3.2; if llama-3.2 is not used, vllm==0.6.1 is sufficient 75 | 76 | conda deactivate 77 | ``` 78 | 79 | ## Domain-Specific Visual Instruction Synthesis 80 | The steps in [Synthesis.md](docs/Synthesis.md) reproduce our visual instruction synthesizer and our synthetic data. 81 | 82 | ## Domain-Specific Single-Stage Post-Training 83 | The steps in [Post-Train.md](docs/Post_Train.md) reproduce our domain-adapted models. 84 | 85 | ## Domain-Specific Task Evaluation 86 | 87 | See [Evaluation.md](docs/Evaluation.md) to reproduce our results and evaluate any MLLMs compatible with vLLM. 88 | 89 | 90 | ## License 91 | 92 | ```text 93 | LICENSE AGREEMENT 94 | Last revision: Sep, 2023 95 | You are granted the right to use the code and/or Database under the following terms, as enlisted in this document (“Beijing Institute for General Artificial Intelligence BIGAI License Agreement”): 96 | · The code and/or data is strictly for non-commercial academic research only. 97 | · Any commercial use of the code or data requires prior contact and negotiation of licensing fees with the original authors or Beijing Institute for General Artificial Intelligence (BIGAI). 98 | · Any new access to the code and/or data shall be established through this form or the official method of distributing the code and/or data. The code and/or data may not be redistributed, in whole or part, or in any format without written prior permission. A reference to the code and/or data or this License Agreement must be made if you publish information. 99 | · The code and/or data is provided as is. No liability or responsibility assumed for the authors. 100 | · The right to revise this License Agreement, in whole or part, at any time without prior notice is reserved by the authors. 101 | · You warrant that you have the authorization to enter into this License Agreement. 102 | · You comply with the terms enforced by the corporates whose products were used in collecting the code and/or data. The terms unanimously enforce, including but not limited to, restricting the use of the code and/or data to non-commercial academic research. 
103 | ``` 104 | 105 | ## Citation 106 | If you find our work helpful, please cite us: 107 | 108 | ```bibtex 109 | @article{adamllm, 110 | title={On Domain-Specific Post-Training for Multimodal Large Language Models}, 111 | author={Cheng, Daixuan and Huang, Shaohan and Zhu, Ziyu and Zhang, Xintong and Zhao, Wayne Xin and Luan, Zhongzhi and Dai, Bo and Zhang, Zhenliang}, 112 | journal={arXiv preprint arXiv:2411.19930}, 113 | year={2024} 114 | } 115 | @inproceedings{ 116 | adaptllm, 117 | title={Adapting Large Language Models via Reading Comprehension}, 118 | author={Daixuan Cheng and Shaohan Huang and Furu Wei}, 119 | booktitle={The Twelfth International Conference on Learning Representations}, 120 | year={2024}, 121 | url={https://openreview.net/forum?id=y886UXPEZ0} 122 | } 123 | ``` 124 | -------------------------------------------------------------------------------- /data_samples/images/0ae1b33f14e0aff353de_300.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/0ae1b33f14e0aff353de_300.jpg -------------------------------------------------------------------------------- /data_samples/images/214320-6696759ae28745f696a6830cb67c954d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/214320-6696759ae28745f696a6830cb67c954d.jpg -------------------------------------------------------------------------------- /data_samples/images/34393_1719273861489_1227549162_31979863_8203919_n-410x307.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/34393_1719273861489_1227549162_31979863_8203919_n-410x307.jpg -------------------------------------------------------------------------------- /data_samples/images/4484931-a3606015e6614481afc6a433fe49387d.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/4484931-a3606015e6614481afc6a433fe49387d.jpg -------------------------------------------------------------------------------- /data_samples/images/51178860_creamy-slaw_1x1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/51178860_creamy-slaw_1x1.jpg -------------------------------------------------------------------------------- /data_samples/images/5241179-01bbc4f291a3480c8e04c0616a000011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/5241179-01bbc4f291a3480c8e04c0616a000011.jpg -------------------------------------------------------------------------------- /data_samples/images/942578-80eff85811684394bc64592dee071d56.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/942578-80eff85811684394bc64592dee071d56.jpg 
-------------------------------------------------------------------------------- /data_samples/images/970916-518557352d32499fb9d66ba00cc35963.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/970916-518557352d32499fb9d66ba00cc35963.jpg -------------------------------------------------------------------------------- /data_samples/images/Chocolate-Covered-Strawberry-Cheesecake_WD0156-S-29da911d0e1a414b9e1e7b2c800f6bde.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/Chocolate-Covered-Strawberry-Cheesecake_WD0156-S-29da911d0e1a414b9e1e7b2c800f6bde.jpg -------------------------------------------------------------------------------- /data_samples/images/Pumpkin-Waffles-with-Bacon-Maple-Butter-by-Gerry-Speirs-410x273.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/Pumpkin-Waffles-with-Bacon-Maple-Butter-by-Gerry-Speirs-410x273.jpeg -------------------------------------------------------------------------------- /data_samples/images/Serrano-Spinach-Stuffed-Tomatoes-410x316.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/Serrano-Spinach-Stuffed-Tomatoes-410x316.jpg -------------------------------------------------------------------------------- /data_samples/images/Small-Plates-No-Boil-Baked-Pasta-inset-in-pot-11012019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/Small-Plates-No-Boil-Baked-Pasta-inset-in-pot-11012019.jpg -------------------------------------------------------------------------------- /data_samples/images/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__20120513-206173-spicy-brown-mustard-step-2-37f1f2516f214b6dbae208bbeef10380.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/__opt__aboutcom__coeus__resources__content_migration__serious_eats__seriouseats.com__recipes__20120513-206173-spicy-brown-mustard-step-2-37f1f2516f214b6dbae208bbeef10380.jpg -------------------------------------------------------------------------------- /data_samples/images/classic-beef-stew-dutch-oven-recipe-FT-BLOG1019-c50074af5c1e4c3e865bf0a4d2c36266.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/classic-beef-stew-dutch-oven-recipe-FT-BLOG1019-c50074af5c1e4c3e865bf0a4d2c36266.jpg -------------------------------------------------------------------------------- /data_samples/images/garlicky-green-beans1-410x329.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/garlicky-green-beans1-410x329.jpg -------------------------------------------------------------------------------- /data_samples/images/picCYxfUw.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/picCYxfUw.jpg -------------------------------------------------------------------------------- /data_samples/images/picIfFx2r.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/picIfFx2r.jpg -------------------------------------------------------------------------------- /data_samples/images/picJbKQVD.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/picJbKQVD.jpg -------------------------------------------------------------------------------- /data_samples/images/picKiyZk5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/picKiyZk5.jpg -------------------------------------------------------------------------------- /data_samples/images/picVapP1X.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/picVapP1X.jpg -------------------------------------------------------------------------------- /data_samples/images/picbgkmgN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/picbgkmgN.jpg -------------------------------------------------------------------------------- /data_samples/images/pumpkinspice-cake-FG2-410x273.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/pumpkinspice-cake-FG2-410x273.jpg -------------------------------------------------------------------------------- /data_samples/images/tWEWzgEySPijBjppP0bw-IMG_20140420_121943178.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/tWEWzgEySPijBjppP0bw-IMG_20140420_121943178.jpg -------------------------------------------------------------------------------- /data_samples/images/tuna9-410x274.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/data_samples/images/tuna9-410x274.jpg -------------------------------------------------------------------------------- /docs/Evaluation.md: -------------------------------------------------------------------------------- 1 | # Domain-Specific Task Evaluation 2 | We provide all the resources necessary to reproduce our results and evaluate any 
MLLMs compatible with [vLLM](https://github.com/vllm-project/vllm). 3 | ## Model Zoo 4 | 5 | | Model | Repo ID in HF 🤗 | Domain | Base Model | Training Data | Evaluation Benchmark | 6 | |:----------------------------------------------------------------------------|:--------------------------------------------|:--------------|:-------------------------|:------------------------------------------------------------------------------------------------|-----------------------| 7 | | [AdaMLLM-med-2B](https://huggingface.co/AdaptLLM/biomed-Qwen2-VL-2B-Instruct) | AdaptLLM/biomed-Qwen2-VL-2B-Instruct | Biomedicine | Qwen2-VL-2B-Instruct | [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) | [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) | 8 | | [AdaMLLM-food-2B](https://huggingface.co/AdaptLLM/food-Qwen2-VL-2B-Instruct) | AdaptLLM/food-Qwen2-VL-2B-Instruct | Food | Qwen2-VL-2B-Instruct | [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) | [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) | 9 | | [AdaMLLM-med-8B](https://huggingface.co/AdaptLLM/biomed-LLaVA-NeXT-Llama3-8B) | AdaptLLM/biomed-LLaVA-NeXT-Llama3-8B | Biomedicine | open-llava-next-llama3-8b | [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) | [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) | 10 | | [AdaMLLM-food-8B](https://huggingface.co/AdaptLLM/food-LLaVA-NeXT-Llama3-8B) |AdaptLLM/food-LLaVA-NeXT-Llama3-8B | Food | open-llava-next-llama3-8b | [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) | [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) | 11 | | [AdaMLLM-med-11B](https://huggingface.co/AdaptLLM/biomed-Llama-3.2-11B-Vision-Instruct) | AdaptLLM/biomed-Llama-3.2-11B-Vision-Instruct | Biomedicine | Llama-3.2-11B-Vision-Instruct | [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) | [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) | 12 | | [AdaMLLM-food-11B](https://huggingface.co/AdaptLLM/food-Llama-3.2-11B-Vision-Instruct) | AdaptLLM/food-Llama-3.2-11B-Vision-Instruct | Food | Llama-3.2-11B-Vision-Instruct | [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) | [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) | 13 | 14 | 15 | ## Task Datasets 16 | To simplify the evaluation on domain-specific tasks, we have uploaded the templatized test sets for each task: 17 | 18 | - [biomed-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/biomed-VQA-benchmark) 19 | - [food-VQA-benchmark](https://huggingface.co/datasets/AdaptLLM/food-VQA-benchmark) 20 | 21 | The dataset loading script is embedded in the inference code, so you can directly run the following commands to evaluate MLLMs. 22 | 23 | ## Evaluate Any MLLM Compatible with vLLM 24 | 25 | Our code can directly evaluate models such as LLaVA-v1.6 ([open-source version](https://huggingface.co/Lin-Chen/open-llava-next-llama3-8b)), Qwen2-VL, and Llama-3.2-Vision. To evaluate other MLLMs, refer to [this guide](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference_vision_language.py) for modifying the `BaseTask` class in the [vllm_inference/utils/task.py](../vllm_inference/utils/task.py) file. 
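For orientation, the evaluation scripts drive models through vLLM's offline-inference API. The sketch below shows that interface in isolation (it is illustrative, not the repository's `inference.py`): the checkpoint, image path, and prompt are placeholder assumptions, and the prompt string must follow the chat/image-placeholder template of whichever model you load (the example uses Qwen2-VL's format).

```python
# Minimal vLLM multimodal offline-inference sketch (illustrative only).
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-VL-2B-Instruct", max_model_len=4096)  # any vLLM-supported MLLM
sampling = SamplingParams(temperature=0.0, max_tokens=256)

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
# Prompt in Qwen2-VL chat format; other models expect their own templates.
prompt = ("<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
          "What is shown in this image?<|im_end|>\n<|im_start|>assistant\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": image}},
    sampling_params=sampling,
)
print(outputs[0].outputs[0].text)
```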
Feel free to reach out to us for assistance! 26 | 27 | ### Setup 28 | ```bash 29 | conda activate vllm 30 | cd QA-Synthesizer/vllm_inference 31 | RESULTS_DIR=./eval_results # Directory for saving evaluation scores 32 | ``` 33 | 34 | ### Biomedicine Domain 35 | 36 | ```bash 37 | # Choose from ['med', 'PMC_VQA', 'VQA_RAD', 'SLAKE', 'PathVQA'] 38 | # 'med' runs inference on all biomedicine tasks; others run on a single task 39 | DOMAIN='med' 40 | 41 | # 1. LLaVA-v1.6-8B 42 | MODEL_TYPE='llava' 43 | MODEL=AdaptLLM/biomed-LLaVA-NeXT-Llama3-8B # HuggingFace repo ID for AdaMLLM-med-8B 44 | OUTPUT_DIR=./output/AdaMLLM-med-LLaVA-8B_${DOMAIN} 45 | 46 | # Run inference with data parallelism; adjust CUDA devices as needed 47 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 48 | 49 | # 2. Qwen2-VL-2B 50 | MODEL_TYPE='qwen2_vl' 51 | MODEL=Qwen/Qwen2-VL-2B-Instruct # HuggingFace repo ID for Qwen2-VL 52 | OUTPUT_DIR=./output/Qwen2-VL-2B-Instruct_${DOMAIN} 53 | 54 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 55 | 56 | MODEL=AdaptLLM/biomed-Qwen2-VL-2B-Instruct # HuggingFace repo ID for AdaMLLM-med-2B 57 | OUTPUT_DIR=./output/AdaMLLM-med-Qwen-2B_${DOMAIN} 58 | 59 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 60 | 61 | # 3. Llama-3.2-11B 62 | MODEL_TYPE='mllama' 63 | MODEL=meta-llama/Llama-3.2-11B-Vision-Instruct # HuggingFace repo ID for Llama3.2 64 | OUTPUT_DIR=./output/Llama-3.2-11B-Vision-Instruct_${DOMAIN} 65 | 66 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 67 | 68 | MODEL=AdaptLLM/biomed-Llama-3.2-11B-Vision-Instruct # HuggingFace repo ID for AdaMLLM-11B 69 | OUTPUT_DIR=./output/AdaMLLM-med-Llama3.2-11B_${DOMAIN} 70 | 71 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 72 | ``` 73 | 74 | ### Food Domain 75 | 76 | ```bash 77 | # Choose from ['food', 'Recipe1M', 'Nutrition5K', 'Food101', 'FoodSeg103'] 78 | # 'food' runs inference on all food tasks; others run on a single task 79 | DOMAIN='food' 80 | 81 | # 1. LLaVA-v1.6-8B 82 | MODEL_TYPE='llava' 83 | MODEL=AdaptLLM/food-LLaVA-NeXT-Llama3-8B # HuggingFace repo ID for AdaMLLM-food-8B 84 | OUTPUT_DIR=./output/AdaMLLM-food-LLaVA-8B_${DOMAIN} 85 | 86 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 87 | 88 | # 2. Qwen2-VL-2B 89 | MODEL_TYPE='qwen2_vl' 90 | MODEL=Qwen/Qwen2-VL-2B-Instruct # HuggingFace repo ID for Qwen2-VL 91 | OUTPUT_DIR=./output/Qwen2-VL-2B-Instruct_${DOMAIN} 92 | 93 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 94 | 95 | MODEL=AdaptLLM/food-Qwen2-VL-2B-Instruct # HuggingFace repo ID for AdaMLLM-food-2B 96 | OUTPUT_DIR=./output/AdaMLLM-food-Qwen-2B_${DOMAIN} 97 | 98 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 99 | 100 | # 3. 
Llama-3.2-11B 101 | MODEL_TYPE='mllama' 102 | MODEL=meta-llama/Llama-3.2-11B-Vision-Instruct # HuggingFace repo ID for Llama3.2 103 | OUTPUT_DIR=./output/Llama-3.2-11B-Vision-Instruct_${DOMAIN} 104 | 105 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 106 | 107 | MODEL=AdaptLLM/food-Llama-3.2-11B-Vision-Instruct # HuggingFace repo ID for AdaMLLM-food-11B 108 | OUTPUT_DIR=./output/AdaMLLM-food-Llama3.2-2B_${DOMAIN} 109 | 110 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_inference.sh ${MODEL} ${DOMAIN} ${MODEL_TYPE} ${OUTPUT_DIR} ${RESULTS_DIR} 111 | ``` 112 | 113 | ## Results 114 | 115 | The evaluation results are stored in `./eval_results`, and the model prediction outputs are in `./output`. -------------------------------------------------------------------------------- /docs/Post_Train.md: -------------------------------------------------------------------------------- 1 | # Post-Training General MLLMs 2 | ## Download or Reproduce Synthetic Training Datasets 3 | 4 | You may follow [Synthesis.md](./Synthesis.md) to reproduce our training datasets. 5 | Or you can skip it and download the resulting synthetic data from: 6 | 7 | - [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) 8 | - [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) 9 | 10 | ## To post-train LLava-v1.6-Llama3-8B 11 | Use the `adamllm` environment to post-train LLava-v1.6-Llama3-8B ([the open-source version](https://huggingface.co/Lin-Chen/open-llava-next-llama3-8b)) 12 | 13 | ```bash 14 | cd QA-Synthesizer 15 | conda activate adamllm 16 | 17 | # Biomedicine domain, using PMC^Refined caption 18 | DOMAIN=biomed 19 | DATASET=PATH_TO/biomed-visual-instructions/image_caption_and_synthetic_task.json 20 | IMAGE_FOLDER=PATH_TO/biomed-visual-instructions/images 21 | 22 | bash ./scripts/post_train_mllm.sh ${DOMAIN} ${DATASET} ${IMAGE_FOLDER} 23 | 24 | # Food domain 25 | DOMAIN=food 26 | DATASET=PATH_TO/food-visual-instructions/image_caption_and_synthetic_task.json 27 | IMAGE_FOLDER=PATH_TO/food-visual-instructions/images 28 | 29 | bash ./scripts/post_train_mllm.sh ${DOMAIN} ${DATASET} ${IMAGE_FOLDER} 30 | 31 | conda deactivate 32 | ``` 33 | 34 | ## To post-train Qwen2-VL-2B-Instruct and Llama-3.2-11B-Vision-Instruct 35 | 36 | ### Update Dataset Information 37 | 38 | 1. **Edit `dataset_info.json`**: 39 | Update the `file_name` field in the [dataset_info.json](../scripts/dataset_info.json) to point to the paths of your biomed and food training data. 40 | 41 | ```json 42 | { 43 | "biomed": { 44 | "file_name": "PATH_TO/biomed-visual-instructions/image_caption_and_synthetic_task.json", // Replace with your file path 45 | "formatting": "sharegpt", 46 | "columns": { 47 | "messages": "messages", 48 | "images": "images" 49 | }, 50 | "tags": { 51 | "role_tag": "role", 52 | "content_tag": "content", 53 | "user_tag": "user", 54 | "assistant_tag": "assistant" 55 | } 56 | }, 57 | "food": { 58 | "file_name": "PATH_TO/food-visual-instructions/image_caption_and_synthetic_task.json", // Replace with your file path 59 | "formatting": "sharegpt", 60 | "columns": { 61 | "messages": "messages", 62 | "images": "images" 63 | }, 64 | "tags": { 65 | "role_tag": "role", 66 | "content_tag": "content", 67 | "user_tag": "user", 68 | "assistant_tag": "assistant" 69 | } 70 | } 71 | } 72 | ``` 73 | 74 | 2. 
**Copy `dataset_info.json` to the respective image folders**: 75 | Use the following commands to copy `dataset_info.json` to the image folders for both biomed and food domains: 76 | 77 | ```bash 78 | # Biomed domain 79 | IMAGE_FOLDER=PATH_TO/biomed-visual-instructions/images 80 | cp ./scripts/dataset_info.json ${IMAGE_FOLDER}/ -v 81 | 82 | # Food domain 83 | IMAGE_FOLDER=PATH_TO/food-visual-instructions/images 84 | cp ./scripts/dataset_info.json ${IMAGE_FOLDER}/ -v 85 | ``` 86 | 87 | 88 | ### Set up environment for LLaMA-Factory and Qwen2-VL 89 | We utilize [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory?tab=readme-ov-file) for experiments on Qwen2-VL-2B-Instruct and Llama-3.2-11B-Vision-Instruct. 90 | 91 | ```bash 92 | git clone --depth 1 https://github.com/hiyouga/LLaMA-Factory.git 93 | cd LLaMA-Factory 94 | 95 | conda create -n llama-factory python=3.10 -y 96 | conda activate llama-factory 97 | pip install -e ".[torch,metrics]" 98 | 99 | # You may need to update the following packages for qwen-vl 100 | # pip install trl==0.9.6 accelerate==0.34.0 git+https://github.com/huggingface/transformers@21fac7abba2a37fae86106f87fcf9974fd1e3830 101 | 102 | pip install qwen-vl-utils 103 | ``` 104 | 105 | ### To Post-Train Qwen2-VL-2B 106 | 107 | ```bash 108 | BASE_MODEL=Qwen/Qwen2-VL-2B-Instruct 109 | DATASET=food # Choose from [biomed, food] 110 | IMAGE_FOLDER=PATH_TO/food-visual-instructions/images # Choose from [biomed-visual-instructions/images, food-visual-instructions/images] 111 | BATCH_SIZE=8 112 | GRADIENT_ACCU_STEPS=2 113 | OUTPUT_PATH=./exp/${DATASET}-Qwen2-VL-2B 114 | 115 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' python -m torch.distributed.launch --use_env --nproc_per_node=8 --master_port=12345 src/train.py \ 116 | --deepspeed examples/deepspeed/ds_z2_config.json \ 117 | --stage sft \ 118 | --do_train \ 119 | --model_name_or_path ${BASE_MODEL} \ 120 | --dataset ${DATASET} \ 121 | --template qwen2_vl \ 122 | --finetuning_type full \ 123 | --output_dir ${OUTPUT_PATH} \ 124 | --overwrite_cache \ 125 | --warmup_ratio 0.1 \ 126 | --weight_decay 0.1 \ 127 | --per_device_train_batch_size ${BATCH_SIZE} \ 128 | --gradient_accumulation_steps ${GRADIENT_ACCU_STEPS} \ 129 | --ddp_timeout 180000000 \ 130 | --learning_rate 1e-5 \ 131 | --lr_scheduler_type cosine \ 132 | --logging_steps 10 \ 133 | --cutoff_len 6144 \ 134 | --save_steps 500 \ 135 | --num_train_epochs 1 \ 136 | --bf16 \ 137 | --report_to none \ 138 | --save_total_limit 1 \ 139 | --preprocessing_num_workers 32 \ 140 | --dataset_dir ${IMAGE_FOLDER} 141 | 142 | # Copy configuration files for task evaluation 143 | cp ${BASE_MODEL}/chat_template.json ${OUTPUT_PATH}/chat_template.json -v 144 | cp ${BASE_MODEL}/merges.txt ${OUTPUT_PATH}/merges.txt -v 145 | cp ${BASE_MODEL}/tokenizer_config.json ${OUTPUT_PATH}/tokenizer_config.json -v 146 | cp ${BASE_MODEL}/tokenizer.json ${OUTPUT_PATH}/tokenizer.json -v 147 | cp ${BASE_MODEL}/vocab.json ${OUTPUT_PATH}/vocab.json -v 148 | ``` 149 | 150 | ### To Post-Train Llama-3.2-11B-Vision-Instruct 151 | 152 | ```bash 153 | BASE_MODEL=meta-llama/Llama-3.2-11B-Vision-Instruct 154 | DATASET=food # Choose from [biomed, food] 155 | IMAGE_FOLDER=PATH_TO/food-visual-instructions/images # Choose from [biomed-visual-instructions/images, food-visual-instructions/images] 156 | BATCH_SIZE=4 157 | GRADIENT_ACCU_STEPS=4 158 | OUTPUT_PATH=./exp/${DATASET}-Llama-3.2-11B-Vision 159 | 160 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' python -m torch.distributed.launch --use_env --nproc_per_node=8 --master_port=12345 
src/train.py \ 161 | --deepspeed examples/deepspeed/ds_z3_config.json \ 162 | --stage sft \ 163 | --do_train \ 164 | --model_name_or_path ${BASE_MODEL} \ 165 | --dataset ${DATASET} \ 166 | --template llama3_vl \ 167 | --finetuning_type full \ 168 | --output_dir ${OUTPUT_PATH} \ 169 | --overwrite_cache \ 170 | --warmup_ratio 0.1 \ 171 | --weight_decay 0.1 \ 172 | --per_device_train_batch_size ${BATCH_SIZE} \ 173 | --gradient_accumulation_steps ${GRADIENT_ACCU_STEPS} \ 174 | --ddp_timeout 180000000 \ 175 | --learning_rate 1e-5 \ 176 | --lr_scheduler_type cosine \ 177 | --logging_steps 10 \ 178 | --cutoff_len 6144 \ 179 | --save_steps 500 \ 180 | --num_train_epochs 1 \ 181 | --bf16 \ 182 | --report_to none \ 183 | --save_total_limit 1 \ 184 | --preprocessing_num_workers 32 \ 185 | --dataset_dir ${IMAGE_FOLDER} 186 | 187 | # Copy configuration files for task evaluation 188 | cp ${BASE_MODEL}/chat_template.json ${OUTPUT_PATH}/chat_template.json -v 189 | cp ${BASE_MODEL}/special_tokens_map.json ${OUTPUT_PATH}/special_tokens_map.json -v 190 | cp ${BASE_MODEL}/tokenizer_config.json ${OUTPUT_PATH}/tokenizer_config.json -v 191 | cp ${BASE_MODEL}/tokenizer.json ${OUTPUT_PATH}/tokenizer.json -v 192 | ``` -------------------------------------------------------------------------------- /docs/Synthesis.md: -------------------------------------------------------------------------------- 1 | # Visual Instruction Synthesis 2 | 3 | ## 1. Fine-Tuning Visual Instruction Synthesizer 4 | We fine-tune a unified visual instruction synthesizer that generates diverse tasks based on image-caption pairs across various domains. 5 | 6 | The following steps reproduce our visual instruction synthesizer. Alternatively, you can skip these steps and download our synthesizer from [AdaptLLM/visual-instruction-synthesizer](https://huggingface.co/AdaptLLM/visual-instruction-synthesizer). 7 | 8 | ### Download Seed Data 9 | 10 | We combine VisionFLAN and ALLaVA into our required format for fine-tuning the synthesizer. 11 | 12 | Download the following data files: 13 | - VisionFLAN: 14 | * [vflan_metadata.json](https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k/blob/main/metadata.json) 15 | * [images_191task_1k](https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k/blob/main/image_191-task_1k.zip) 16 | - ALLaVA: 17 | * [ALLaVA-Instruct-VFLAN-4V.json](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/blob/main/allava_vflan/ALLaVA-Instruct-VFLAN-4V.json) 18 | * [ALLaVA-Caption-VFLAN-4V.json](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/blob/main/allava_vflan/ALLaVA-Caption-VFLAN-4V.json) 19 | 20 | ### Fine-Tune Synthesizer 21 | 22 | Using the seed data, we conduct multitask fine-tuning on an open-source MLLM (e.g., LLaVA-v1.6-8B) to generate task triplets based on the corresponding image-caption pairs, and 10% of the images are replaced with a blank image to enhance generalization. 23 | 24 | ```bash 25 | conda activate adamllm 26 | 27 | CAPTION=PATH_TO/ALLaVA-Caption-VFLAN-4V.json 28 | PRECISE_A=PATH_TO/vflan_metadata.json 29 | INFORMATIVE_A=PATH_TO/ALLaVA-Instruct-VFLAN-4V.json 30 | IAMGE_FOLDER=PATH_TO/images_191task_1k 31 | 32 | bash ./scripts/tune_synthesizer.sh ${CAPTION} ${PRECISE_A} ${INFORMATIVE_A} ${IAMGE_FOLDER} 33 | 34 | conda deactivate 35 | ``` 36 | 37 | The tuned synthesizer is saved as `./exp/synthesizer`. 38 | 39 | ## 2. 
Task Synthesis for Target Domain 40 | We use the synthesizer to generate task triplets from image-caption pairs in the target domain, followed by consistency-based data filtering to enhance data quality. 41 | 42 | The following steps reproduce our data. You can also skip them and download the resulting synthetic data (including `image_caption_and_synthetic_task.json` and `images`) from: 43 | - [biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) 44 | - [food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) 45 | 46 | #### Setup 47 | ```bash 48 | conda activate vllm 49 | cd QA-Synthesizer/vllm_inference 50 | SYNTHESIZER=AdaptLLM/visual-instruction-synthesizer # Path to the synthesizer 51 | CONSISTENCY_CHECKER=meta-llama/Meta-Llama-3-8B # Language model for consistency checks 52 | ``` 53 | 54 | #### **Quick Try with Data Samples** 55 | We have included a few [data samples](../data_samples) in this repository for a quick try: 56 | ```bash 57 | IMAGE_CAPTION='../data_samples/image_caption_pairs.json' # Path to the image-caption pairs 58 | IMAGE_FOLDER='../data_samples/images' # Path to the image folder 59 | OUTPUT_DIR='../data_samples/' # Output directory for synthesized data 60 | 61 | # Run synthesis with data parallelism; adjust CUDA devices as needed: 62 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_synthesis.sh ${SYNTHESIZER} ${CONSISTENCY_CHECKER} ${IMAGE_CAPTION} ${IMAGE_FOLDER} ${OUTPUT_DIR} 63 | ``` 64 | 65 | #### **Biomedicine** 66 | 1. download the `image_caption_pairs.json` file and `images` from [AdaptLLM/biomed-visual-instructions](https://huggingface.co/datasets/AdaptLLM/biomed-visual-instructions) 67 | 68 | 2. Then run 69 | ```bash 70 | IMAGE_CAPTION="./biomed-visual-instructions/image_caption_pairs.json" 71 | IMAGE_FOLDER="./biomed-visual-instructions/images" 72 | OUTPUT_DIR="./biomed-visual-instructions" 73 | 74 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_synthesis.sh ${SYNTHESIZER} ${CONSISTENCY_CHECKER} ${IMAGE_CAPTION} ${IMAGE_FOLDER} ${OUTPUT_DIR} 75 | ``` 76 | 77 | #### **Food** 78 | 1. download the `image_caption_pairs.json` file and `images` from [AdaptLLM/food-visual-instructions](https://huggingface.co/datasets/AdaptLLM/food-visual-instructions) 79 | 80 | 2. Then run 81 | ```bash 82 | IMAGE_CAPTION="./food-visual-instructions/image_caption_pairs.json" 83 | IMAGE_FOLDER="./food-visual-instructions/images" 84 | OUTPUT_DIR="./food-visual-instructions" 85 | 86 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' bash run_synthesis.sh ${SYNTHESIZER} ${CONSISTENCY_CHECKER} ${IMAGE_CAPTION} ${IMAGE_FOLDER} ${OUTPUT_DIR} 87 | ``` 88 | 89 | The synthesized output for single-stage post-training will be saved at: `${OUTPUT_DIR}/image_caption_and_synthetic_task.json` -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from llava.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches.
87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | 
new_images.append(image) 178 | else: 179 | return image_processor(images, return_tensors='pt')['pixel_values'] 180 | if all(x.shape == new_images[0].shape for x in new_images): 181 | new_images = torch.stack(new_images, dim=0) 182 | return new_images 183 | 184 | 185 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 186 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')] 187 | 188 | def insert_separator(X, sep): 189 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 190 | 191 | input_ids = [] 192 | offset = 0 193 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 194 | offset = 1 195 | input_ids.append(prompt_chunks[0][0]) 196 | 197 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 198 | input_ids.extend(x[offset:]) 199 | 200 | if return_tensors is not None: 201 | if return_tensors == 'pt': 202 | return torch.tensor(input_ids, dtype=torch.long) 203 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 204 | return input_ids 205 | 206 | 207 | def get_model_name_from_path(model_path): 208 | model_path = model_path.strip("/") 209 | model_paths = model_path.split("/") 210 | if model_paths[-1].startswith('checkpoint-'): 211 | return model_paths[-2] + "_" + model_paths[-1] 212 | else: 213 | return model_paths[-1] 214 | 215 | class KeywordsStoppingCriteria(StoppingCriteria): 216 | def __init__(self, keywords, tokenizer, input_ids): 217 | self.keywords = keywords 218 | self.keyword_ids = [] 219 | self.max_keyword_len = 0 220 | for keyword in keywords: 221 | cur_keyword_ids = tokenizer(keyword).input_ids 222 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 223 | cur_keyword_ids = cur_keyword_ids[1:] 224 | if len(cur_keyword_ids) > self.max_keyword_len: 225 | self.max_keyword_len = len(cur_keyword_ids) 226 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 227 | self.tokenizer = tokenizer 228 | self.start_len = input_ids.shape[1] 229 | 230 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 231 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 232 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 233 | for keyword_id in self.keyword_ids: 234 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 235 | if torch.equal(truncated_output_ids, keyword_id): 236 | return True 237 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 238 | for keyword in self.keywords: 239 | if keyword in outputs: 240 | return True 241 | return False 242 | 243 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 244 | outputs = [] 245 | for i in range(output_ids.shape[0]): 246 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 247 | return all(outputs) 248 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | except: 6 | pass 7 |
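For orientation, here is a minimal sketch of how the helpers in `llava/mm_utils.py` fit together at inference time. It follows upstream LLaVA conventions and rests on several assumptions: that `load_pretrained_model` returns `(tokenizer, model, image_processor, context_len)`, that the model's `generate` accepts `images`/`image_sizes` keywords, and that the checkpoint, image path, and question are placeholders; real usage should also build the prompt with the templates in `llava/conversation.py`.

```python
# Usage sketch for the preprocessing helpers above (assumptions noted in the text).
import torch
from PIL import Image
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

model_path = "Lin-Chen/open-llava-next-llama3-8b"  # placeholder checkpoint
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path))

image = Image.open("data_samples/images/tuna9-410x274.jpg").convert("RGB")
prompt = f"{DEFAULT_IMAGE_TOKEN}\nWhat dish is shown in this image?"  # simplified prompt

# process_images pads ('pad') or tiles ('anyres') the image into pixel values.
image_tensor = process_images([image], image_processor, model.config)
# tokenizer_image_token splits the prompt on '<image>' and splices in the
# IMAGE_TOKEN_INDEX (-200) sentinel that the model later swaps for visual features.
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors='pt').unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=image_tensor.to(model.device, dtype=torch.float16),
        image_sizes=[image.size],
        max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```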
-------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
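# Model-loading utilities: load_pretrained_model() assembles a LLaVA-family checkpoint for
# inference, covering full models, LoRA weights merged onto a base model, and projector-only
# checkpoints, with optional 8-bit/4-bit quantization and FlashAttention-2.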
14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if use_flash_attn: 46 | kwargs['attn_implementation'] = 'flash_attention_2' 47 | 48 | # get model_name for checking model type 49 | model_name = AutoConfig.from_pretrained(model_path, trust_remote_code=True).model_type 50 | 51 | if 'llava' in model_name.lower(): 52 | # Load LLaVA model 53 | if 'lora' in model_name.lower() and model_base is None: 54 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 55 | if 'lora' in model_name.lower() and model_base is not None: 56 | from llava.model.language_model.llava_llama import LlavaConfig 57 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) 58 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 59 | print('Loading LLaVA from base model...') 60 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 61 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 62 | if model.lm_head.weight.shape[0] != token_num: 63 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 64 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 65 | 66 | print('Loading additional LLaVA weights...') 67 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 68 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 69 | else: 70 | # this is probably from HF Hub 71 | from huggingface_hub import hf_hub_download 72 | def load_from_hf(repo_id, filename, subfolder=None): 73 | cache_file = hf_hub_download( 74 | repo_id=repo_id, 75 | filename=filename, 76 | subfolder=subfolder) 77 | return torch.load(cache_file, map_location='cpu') 78 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 79 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 80 | if any(k.startswith('model.model.') for k in non_lora_trainables): 81 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 82 | model.load_state_dict(non_lora_trainables, strict=False) 83 | 84 | from peft import PeftModel 85 | print('Loading LoRA weights...') 86 | model = 
PeftModel.from_pretrained(model, model_path) 87 | print('Merging LoRA weights...') 88 | model = model.merge_and_unload() 89 | print('Model is loaded...') 90 | elif model_base is not None: 91 | # this may be mm projector only 92 | print('Loading LLaVA from base model...') 93 | if 'mpt' in model_name.lower(): 94 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 95 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 96 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 97 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 98 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 99 | else: 100 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 101 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 102 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 103 | 104 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 105 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 106 | model.load_state_dict(mm_projector_weights, strict=False) 107 | else: 108 | if 'mpt' in model_name.lower(): 109 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 110 | model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 111 | elif 'mistral' in model_name.lower(): 112 | tokenizer = AutoTokenizer.from_pretrained(model_path) 113 | model = LlavaMistralForCausalLM.from_pretrained( 114 | model_path, 115 | low_cpu_mem_usage=True, 116 | **kwargs 117 | ) 118 | else: 119 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 120 | model = LlavaLlamaForCausalLM.from_pretrained( 121 | model_path, 122 | low_cpu_mem_usage=True, 123 | **kwargs 124 | ) 125 | else: 126 | raise Warning(f'model_name: {model_name} not a llava model. 
please ensure this is a llava model') 127 | # Load language model 128 | if model_base is not None: 129 | # PEFT model 130 | from peft import PeftModel 131 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 132 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 133 | print(f"Loading LoRA weights from {model_path}") 134 | model = PeftModel.from_pretrained(model, model_path) 135 | print(f"Merging weights") 136 | model = model.merge_and_unload() 137 | print('Convert to FP16...') 138 | model.to(torch.float16) 139 | else: 140 | use_fast = False 141 | if 'mpt' in model_name.lower(): 142 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 143 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 144 | else: 145 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 146 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 147 | 148 | image_processor = None 149 | 150 | if 'llava' in model_name.lower(): 151 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 152 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 153 | if mm_use_im_patch_token: 154 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 155 | if mm_use_im_start_end: 156 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 157 | model.resize_token_embeddings(len(tokenizer)) 158 | 159 | vision_tower = model.get_vision_tower() 160 | if not vision_tower.is_loaded: 161 | vision_tower.load_model(device_map=device_map) 162 | if device_map != 'auto': 163 | vision_tower.to(device=device_map, dtype=torch.float16) 164 | image_processor = vision_tower.image_processor 165 | 166 | if hasattr(model.config, "max_sequence_length"): 167 | context_len = model.config.max_sequence_length 168 | else: 169 | context_len = 2048 170 | 171 | return tokenizer, model, image_processor, context_len 172 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, 
Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = 
kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_llama", LlavaConfig) 158 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | 
return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, 
inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- 
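The `make_delta`/`apply_delta` pair above implements a simple per-tensor weight delta: `delta = target - base`, with a sliced update for `model.embed_tokens.weight` and `lm_head.weight`, whose first (vocabulary) dimension can grow when extra tokens are added. The snippet below is a self-contained sketch of that arithmetic only; it is not a repository file, and the tensor name and shapes are made up.

# Illustrative sketch of the delta arithmetic (not a repository file); shapes are made up.
import torch

base = {"lm_head.weight": torch.randn(32000, 4096)}
target = {"lm_head.weight": torch.randn(32003, 4096)}  # rows appended for extra tokens

# make_delta: subtract the base from the overlapping slice of the target
delta = {k: v.clone() for k, v in target.items()}
b = base["lm_head.weight"]
delta["lm_head.weight"][:b.shape[0], :b.shape[1]] -= b

# apply_delta: add the base back to recover the target weights
restored = {k: v.clone() for k, v in delta.items()}
restored["lm_head.weight"][:b.shape[0], :b.shape[1]] += b

assert torch.allclose(restored["lm_head.weight"], target["lm_head.weight"])

Weight deltas were originally used so that checkpoints derived from restrictively licensed base models (e.g., the original LLaMA) could be shared without redistributing the base weights; `apply_delta.py` above reconstructs a full checkpoint from a base model plus a delta.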
/llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | use_s2 = getattr(vision_tower_cfg, 's2', False) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 10 | if use_s2: 11 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 12 | else: 13 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 14 | 15 | raise ValueError(f'Unknown vision tower: {vision_tower}') 16 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 20 | self.load_model() 21 | else: 22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | 24 | def load_model(self, device_map=None): 25 | if self.is_loaded: 26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 27 | return 28 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 29 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 30 | self.vision_tower.requires_grad_(False) 31 | 32 | self.is_loaded = True 33 | 34 | def feature_select(self, image_forward_outs): 35 | image_features = image_forward_outs.hidden_states[self.select_layer] 36 | if self.select_feature == 'patch': 37 | image_features = image_features[:, 1:] 38 | elif self.select_feature == 'cls_patch': 39 | image_features = image_features 40 | else: 41 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 42 | return image_features 43 | 44 | # @torch.no_grad() 45 | def forward(self, images): 46 | if type(images) is list: 47 | image_features = [] 48 | for image in images: 49 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 50 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 51 | image_features.append(image_feature) 52 | else: 53 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 54 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 55 | 56 | return image_features 57 | 58 | @property 59 | def dummy_feature(self): 60 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 61 | 62 | @property 63 | def dtype(self): 64 | return self.vision_tower.dtype 65 | 66 | @property 67 | def device(self): 68 | return 
self.vision_tower.device 69 | 70 | @property 71 | def config(self): 72 | if self.is_loaded: 73 | return self.vision_tower.config 74 | else: 75 | return self.cfg_only 76 | 77 | @property 78 | def hidden_size(self): 79 | return self.config.hidden_size 80 | 81 | @property 82 | def num_patches_per_side(self): 83 | return self.config.image_size // self.config.patch_size 84 | 85 | @property 86 | def num_patches(self): 87 | return (self.config.image_size // self.config.patch_size) ** 2 88 | 89 | 90 | 91 | class CLIPVisionTowerS2(CLIPVisionTower): 92 | def __init__(self, vision_tower, args, delay_load=False): 93 | super().__init__(vision_tower, args, delay_load) 94 | 95 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008') 96 | self.s2_scales = list(map(int, self.s2_scales.split(','))) 97 | self.s2_scales.sort() 98 | self.s2_split_size = self.s2_scales[0] 99 | self.s2_image_size = self.s2_scales[-1] 100 | 101 | try: 102 | from s2wrapper import forward as multiscale_forward 103 | except ImportError: 104 | raise ImportError('Package s2wrapper not found! Please install') 105 | self.multiscale_forward = multiscale_forward 106 | 107 | # change resize/crop size in preprocessing to the largest image size in s2_scale 108 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): 109 | self.image_processor.size['shortest_edge'] = self.s2_image_size 110 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 111 | 112 | def load_model(self, device_map=None): 113 | if self.is_loaded: 114 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 115 | return 116 | 117 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 118 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 119 | self.vision_tower.requires_grad_(False) 120 | 121 | self.image_processor.size['shortest_edge'] = self.s2_image_size 122 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 123 | 124 | self.is_loaded = True 125 | 126 | @torch.no_grad() 127 | def forward_feature(self, images): 128 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 129 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 130 | return image_features 131 | 132 | @torch.no_grad() 133 | def forward(self, images): 134 | if type(images) is list: 135 | image_features = [] 136 | for image in images: 137 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 138 | image_features.append(image_feature) 139 | else: 140 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 141 | 142 | return image_features 143 | 144 | @property 145 | def hidden_size(self): 146 | return self.config.hidden_size * len(self.s2_scales) 147 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return 
{"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "mistral" in model_name.lower(): 37 | conv_mode = "mistral_instruct" 38 | elif "v1.6-34b" in model_name.lower(): 39 | conv_mode = "chatml_direct" 40 | elif "v1" in model_name.lower(): 41 | conv_mode = "llava_v1" 42 | elif "mpt" in model_name.lower(): 43 | conv_mode = "mpt" 44 | else: 45 | conv_mode = "llava_v0" 46 | 47 | if args.conv_mode is not None and conv_mode != args.conv_mode: 48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 49 | else: 50 | args.conv_mode = conv_mode 51 | 52 | conv = conv_templates[args.conv_mode].copy() 53 | if "mpt" in model_name.lower(): 54 | roles = ('user', 'assistant') 55 | else: 56 | roles = conv.roles 57 | 58 | image = load_image(args.image_file) 59 | image_size = image.size 60 | # Similar operation in model_worker.py 61 | image_tensor = process_images([image], image_processor, model.config) 62 | if type(image_tensor) is list: 63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 64 | else: 65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 66 | 67 | while True: 68 | try: 69 | inp = input(f"{roles[0]}: ") 70 | except EOFError: 71 | inp = "" 72 | if not inp: 73 | print("exit...") 74 | break 75 | 76 | print(f"{roles[1]}: ", end="") 77 | 78 | if image is not None: 79 | # first message 80 | if 
model.config.mm_use_im_start_end: 81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 82 | else: 83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 84 | image = None 85 | 86 | conv.append_message(conv.roles[0], inp) 87 | conv.append_message(conv.roles[1], None) 88 | prompt = conv.get_prompt() 89 | 90 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 92 | keywords = [stop_str] 93 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 94 | 95 | with torch.inference_mode(): 96 | output_ids = model.generate( 97 | input_ids, 98 | images=image_tensor, 99 | image_sizes=[image_size], 100 | do_sample=True if args.temperature > 0 else False, 101 | temperature=args.temperature, 102 | max_new_tokens=args.max_new_tokens, 103 | streamer=streamer, 104 | use_cache=True) 105 | 106 | outputs = tokenizer.decode(output_ids[0]).strip() 107 | conv.messages[-1][-1] = outputs 108 | 109 | if args.debug: 110 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 116 | parser.add_argument("--model-base", type=str, default=None) 117 | parser.add_argument("--image-file", type=str, required=True) 118 | parser.add_argument("--device", type=str, default="cuda") 119 | parser.add_argument("--conv-mode", type=str, default=None) 120 | parser.add_argument("--temperature", type=float, default=0.2) 121 | parser.add_argument("--max-new-tokens", type=int, default=512) 122 | parser.add_argument("--load-8bit", action="store_true") 123 | parser.add_argument("--load-4bit", action="store_true") 124 | parser.add_argument("--debug", action="store_true") 125 | args = parser.parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /llava/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 
4 | """ 5 | import argparse 6 | import asyncio 7 | import dataclasses 8 | from enum import Enum, auto 9 | import json 10 | import logging 11 | import time 12 | from typing import List, Union 13 | import threading 14 | 15 | from fastapi import FastAPI, Request 16 | from fastapi.responses import StreamingResponse 17 | import numpy as np 18 | import requests 19 | import uvicorn 20 | 21 | from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION 22 | from llava.utils import build_logger, server_error_msg 23 | 24 | 25 | logger = build_logger("controller", "controller.log") 26 | 27 | 28 | class DispatchMethod(Enum): 29 | LOTTERY = auto() 30 | SHORTEST_QUEUE = auto() 31 | 32 | @classmethod 33 | def from_str(cls, name): 34 | if name == "lottery": 35 | return cls.LOTTERY 36 | elif name == "shortest_queue": 37 | return cls.SHORTEST_QUEUE 38 | else: 39 | raise ValueError(f"Invalid dispatch method") 40 | 41 | 42 | @dataclasses.dataclass 43 | class WorkerInfo: 44 | model_names: List[str] 45 | speed: int 46 | queue_length: int 47 | check_heart_beat: bool 48 | last_heart_beat: str 49 | 50 | 51 | def heart_beat_controller(controller): 52 | while True: 53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 54 | controller.remove_stable_workers_by_expiration() 55 | 56 | 57 | class Controller: 58 | def __init__(self, dispatch_method: str): 59 | # Dict[str -> WorkerInfo] 60 | self.worker_info = {} 61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 62 | 63 | self.heart_beat_thread = threading.Thread( 64 | target=heart_beat_controller, args=(self,), daemon=True) 65 | self.heart_beat_thread.start() 66 | 67 | logger.info("Init controller") 68 | 69 | def register_worker(self, worker_name: str, check_heart_beat: bool, 70 | worker_status: dict): 71 | if worker_name not in self.worker_info: 72 | logger.info(f"Register a new worker: {worker_name}") 73 | else: 74 | logger.info(f"Register an existing worker: {worker_name}") 75 | 76 | if not worker_status: 77 | worker_status = self.get_worker_status(worker_name) 78 | if not worker_status: 79 | return False 80 | 81 | self.worker_info[worker_name] = WorkerInfo( 82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], 83 | check_heart_beat, time.time()) 84 | 85 | logger.info(f"Register done: {worker_name}, {worker_status}") 86 | return True 87 | 88 | def get_worker_status(self, worker_name: str): 89 | try: 90 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 91 | except requests.exceptions.RequestException as e: 92 | logger.error(f"Get status fails: {worker_name}, {e}") 93 | return None 94 | 95 | if r.status_code != 200: 96 | logger.error(f"Get status fails: {worker_name}, {r}") 97 | return None 98 | 99 | return r.json() 100 | 101 | def remove_worker(self, worker_name: str): 102 | del self.worker_info[worker_name] 103 | 104 | def refresh_all_workers(self): 105 | old_info = dict(self.worker_info) 106 | self.worker_info = {} 107 | 108 | for w_name, w_info in old_info.items(): 109 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 110 | logger.info(f"Remove stale worker: {w_name}") 111 | 112 | def list_models(self): 113 | model_names = set() 114 | 115 | for w_name, w_info in self.worker_info.items(): 116 | model_names.update(w_info.model_names) 117 | 118 | return list(model_names) 119 | 120 | def get_worker_address(self, model_name: str): 121 | if self.dispatch_method == DispatchMethod.LOTTERY: 122 | worker_names = [] 123 | worker_speeds = [] 124 | for w_name, w_info in 
self.worker_info.items(): 125 | if model_name in w_info.model_names: 126 | worker_names.append(w_name) 127 | worker_speeds.append(w_info.speed) 128 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 129 | norm = np.sum(worker_speeds) 130 | if norm < 1e-4: 131 | return "" 132 | worker_speeds = worker_speeds / norm 133 | if True: # Directly return address 134 | pt = np.random.choice(np.arange(len(worker_names)), 135 | p=worker_speeds) 136 | worker_name = worker_names[pt] 137 | return worker_name 138 | 139 | # Check status before returning 140 | while True: 141 | pt = np.random.choice(np.arange(len(worker_names)), 142 | p=worker_speeds) 143 | worker_name = worker_names[pt] 144 | 145 | if self.get_worker_status(worker_name): 146 | break 147 | else: 148 | self.remove_worker(worker_name) 149 | worker_speeds[pt] = 0 150 | norm = np.sum(worker_speeds) 151 | if norm < 1e-4: 152 | return "" 153 | worker_speeds = worker_speeds / norm 154 | continue 155 | return worker_name 156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 157 | worker_names = [] 158 | worker_qlen = [] 159 | for w_name, w_info in self.worker_info.items(): 160 | if model_name in w_info.model_names: 161 | worker_names.append(w_name) 162 | worker_qlen.append(w_info.queue_length / w_info.speed) 163 | if len(worker_names) == 0: 164 | return "" 165 | min_index = np.argmin(worker_qlen) 166 | w_name = worker_names[min_index] 167 | self.worker_info[w_name].queue_length += 1 168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 169 | return w_name 170 | else: 171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 172 | 173 | def receive_heart_beat(self, worker_name: str, queue_length: int): 174 | if worker_name not in self.worker_info: 175 | logger.info(f"Receive unknown heart beat. {worker_name}") 176 | return False 177 | 178 | self.worker_info[worker_name].queue_length = queue_length 179 | self.worker_info[worker_name].last_heart_beat = time.time() 180 | logger.info(f"Receive heart beat. {worker_name}") 181 | return True 182 | 183 | def remove_stable_workers_by_expiration(self): 184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 185 | to_delete = [] 186 | for worker_name, w_info in self.worker_info.items(): 187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 188 | to_delete.append(worker_name) 189 | 190 | for worker_name in to_delete: 191 | self.remove_worker(worker_name) 192 | 193 | def worker_api_generate_stream(self, params): 194 | worker_addr = self.get_worker_address(params["model"]) 195 | if not worker_addr: 196 | logger.info(f"no worker: {params['model']}") 197 | ret = { 198 | "text": server_error_msg, 199 | "error_code": 2, 200 | } 201 | yield json.dumps(ret).encode() + b"\0" 202 | 203 | try: 204 | response = requests.post(worker_addr + "/worker_generate_stream", 205 | json=params, stream=True, timeout=5) 206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 207 | if chunk: 208 | yield chunk + b"\0" 209 | except requests.exceptions.RequestException as e: 210 | logger.info(f"worker timeout: {worker_addr}") 211 | ret = { 212 | "text": server_error_msg, 213 | "error_code": 3, 214 | } 215 | yield json.dumps(ret).encode() + b"\0" 216 | 217 | 218 | # Let the controller act as a worker to achieve hierarchical 219 | # management. This can be used to connect isolated sub networks. 
220 | def worker_api_get_status(self): 221 | model_names = set() 222 | speed = 0 223 | queue_length = 0 224 | 225 | for w_name in self.worker_info: 226 | worker_status = self.get_worker_status(w_name) 227 | if worker_status is not None: 228 | model_names.update(worker_status["model_names"]) 229 | speed += worker_status["speed"] 230 | queue_length += worker_status["queue_length"] 231 | 232 | return { 233 | "model_names": list(model_names), 234 | "speed": speed, 235 | "queue_length": queue_length, 236 | } 237 | 238 | 239 | app = FastAPI() 240 | 241 | 242 | @app.post("/register_worker") 243 | async def register_worker(request: Request): 244 | data = await request.json() 245 | controller.register_worker( 246 | data["worker_name"], data["check_heart_beat"], 247 | data.get("worker_status", None)) 248 | 249 | 250 | @app.post("/refresh_all_workers") 251 | async def refresh_all_workers(): 252 | models = controller.refresh_all_workers() 253 | 254 | 255 | @app.post("/list_models") 256 | async def list_models(): 257 | models = controller.list_models() 258 | return {"models": models} 259 | 260 | 261 | @app.post("/get_worker_address") 262 | async def get_worker_address(request: Request): 263 | data = await request.json() 264 | addr = controller.get_worker_address(data["model"]) 265 | return {"address": addr} 266 | 267 | 268 | @app.post("/receive_heart_beat") 269 | async def receive_heart_beat(request: Request): 270 | data = await request.json() 271 | exist = controller.receive_heart_beat( 272 | data["worker_name"], data["queue_length"]) 273 | return {"exist": exist} 274 | 275 | 276 | @app.post("/worker_generate_stream") 277 | async def worker_api_generate_stream(request: Request): 278 | params = await request.json() 279 | generator = controller.worker_api_generate_stream(params) 280 | return StreamingResponse(generator) 281 | 282 | 283 | @app.post("/worker_get_status") 284 | async def worker_api_get_status(request: Request): 285 | return controller.worker_api_get_status() 286 | 287 | 288 | if __name__ == "__main__": 289 | parser = argparse.ArgumentParser() 290 | parser.add_argument("--host", type=str, default="localhost") 291 | parser.add_argument("--port", type=int, default=21001) 292 | parser.add_argument("--dispatch-method", type=str, choices=[ 293 | "lottery", "shortest_queue"], default="shortest_queue") 294 | args = parser.parse_args() 295 | logger.info(f"args: {args}") 296 | 297 | controller = Controller(args.dispatch_method) 298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 299 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bigai-ai/QA-Synthesizer/29d62a52d91153734e6e28bab2f70e36dd969aaa/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
3 | """ 4 | import argparse 5 | import asyncio 6 | import json 7 | import time 8 | import threading 9 | import uuid 10 | 11 | from fastapi import FastAPI, Request, BackgroundTasks 12 | from fastapi.responses import StreamingResponse 13 | import requests 14 | import torch 15 | import uvicorn 16 | from functools import partial 17 | 18 | from llava.constants import WORKER_HEART_BEAT_INTERVAL 19 | from llava.utils import (build_logger, server_error_msg, 20 | pretty_print_semaphore) 21 | from llava.model.builder import load_pretrained_model 22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token 23 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | from transformers import TextIteratorStreamer 25 | from threading import Thread 26 | 27 | 28 | GB = 1 << 30 29 | 30 | worker_id = str(uuid.uuid4())[:6] 31 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 32 | global_counter = 0 33 | 34 | model_semaphore = None 35 | 36 | 37 | def heart_beat_worker(controller): 38 | 39 | while True: 40 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 41 | controller.send_heart_beat() 42 | 43 | 44 | class ModelWorker: 45 | def __init__(self, controller_addr, worker_addr, 46 | worker_id, no_register, 47 | model_path, model_base, model_name, 48 | load_8bit, load_4bit, device, use_flash_attn=False): 49 | self.controller_addr = controller_addr 50 | self.worker_addr = worker_addr 51 | self.worker_id = worker_id 52 | if model_path.endswith("/"): 53 | model_path = model_path[:-1] 54 | if model_name is None: 55 | model_paths = model_path.split("/") 56 | if model_paths[-1].startswith('checkpoint-'): 57 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 58 | else: 59 | self.model_name = model_paths[-1] 60 | else: 61 | self.model_name = model_name 62 | 63 | self.device = device 64 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") 65 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( 66 | model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn) 67 | self.is_multimodal = 'llava' in self.model_name.lower() 68 | 69 | if not no_register: 70 | self.register_to_controller() 71 | self.heart_beat_thread = threading.Thread( 72 | target=heart_beat_worker, args=(self,), daemon=True) 73 | self.heart_beat_thread.start() 74 | 75 | def register_to_controller(self): 76 | logger.info("Register to controller") 77 | 78 | url = self.controller_addr + "/register_worker" 79 | data = { 80 | "worker_name": self.worker_addr, 81 | "check_heart_beat": True, 82 | "worker_status": self.get_status() 83 | } 84 | r = requests.post(url, json=data) 85 | assert r.status_code == 200 86 | 87 | def send_heart_beat(self): 88 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " 89 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" 90 | f"global_counter: {global_counter}") 91 | 92 | url = self.controller_addr + "/receive_heart_beat" 93 | 94 | while True: 95 | try: 96 | ret = requests.post(url, json={ 97 | "worker_name": self.worker_addr, 98 | "queue_length": self.get_queue_length()}, timeout=5) 99 | exist = ret.json()["exist"] 100 | break 101 | except requests.exceptions.RequestException as e: 102 | logger.error(f"heart beat error: {e}") 103 | time.sleep(5) 104 | 105 | if not exist: 106 | self.register_to_controller() 107 | 108 | def get_queue_length(self): 109 | if model_semaphore is None: 110 | return 0 111 | else: 112 | return args.limit_model_concurrency - model_semaphore._value + (len( 113 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 114 | 115 | def get_status(self): 116 | return { 117 | "model_names": [self.model_name], 118 | "speed": 1, 119 | "queue_length": self.get_queue_length(), 120 | } 121 | 122 | @torch.inference_mode() 123 | def generate_stream(self, params): 124 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 125 | 126 | prompt = params["prompt"] 127 | ori_prompt = prompt 128 | images = params.get("images", None) 129 | num_image_tokens = 0 130 | if images is not None and len(images) > 0 and self.is_multimodal: 131 | if len(images) > 0: 132 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 133 | raise ValueError("Number of images does not match number of tokens in prompt") 134 | 135 | images = [load_image_from_base64(image) for image in images] 136 | image_sizes = [image.size for image in images] 137 | images = process_images(images, image_processor, model.config) 138 | 139 | if type(images) is list: 140 | images = [image.to(self.model.device, dtype=torch.float16) for image in images] 141 | else: 142 | images = images.to(self.model.device, dtype=torch.float16) 143 | 144 | replace_token = DEFAULT_IMAGE_TOKEN 145 | if getattr(self.model.config, 'mm_use_im_start_end', False): 146 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 147 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 148 | 149 | num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches 150 | else: 151 | images = None 152 | image_sizes = None 153 | image_args = {"images": images, "image_sizes": image_sizes} 154 | else: 155 | images = None 156 | image_args = {} 157 | 158 | temperature = float(params.get("temperature", 1.0)) 159 | top_p = float(params.get("top_p", 1.0)) 160 | max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 161 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 162 | stop_str = params.get("stop", None) 163 | do_sample = True if temperature > 0.001 else False 164 | 165 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) 166 | keywords = [stop_str] 167 | # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 168 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 169 | 170 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) 171 | 172 | if max_new_tokens < 1: 173 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" 174 | return 175 | 176 | thread = Thread(target=model.generate, kwargs=dict( 177 | inputs=input_ids, 178 | do_sample=do_sample, 179 | temperature=temperature, 180 | top_p=top_p, 181 | max_new_tokens=max_new_tokens, 182 | streamer=streamer, 183 | use_cache=True, 184 | **image_args 185 | )) 186 | thread.start() 187 | 188 | generated_text = ori_prompt 189 | for new_text in streamer: 190 | generated_text += new_text 191 | if generated_text.endswith(stop_str): 192 | generated_text = generated_text[:-len(stop_str)] 193 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 194 | 195 | def generate_stream_gate(self, params): 196 | try: 197 | for x in self.generate_stream(params): 198 | yield x 199 | except ValueError as e: 200 | print("Caught ValueError:", e) 201 | ret = { 202 | "text": server_error_msg, 203 | "error_code": 1, 204 | } 205 | yield json.dumps(ret).encode() + b"\0" 206 | except torch.cuda.CudaError as e: 207 | print("Caught torch.cuda.CudaError:", e) 208 | ret = { 209 | "text": server_error_msg, 210 | "error_code": 1, 211 | } 212 | yield json.dumps(ret).encode() + b"\0" 213 | except Exception as e: 214 | print("Caught Unknown Error", e) 215 | ret = { 216 | "text": server_error_msg, 217 | "error_code": 1, 218 | } 219 | yield json.dumps(ret).encode() + b"\0" 220 | 221 | 222 | app = FastAPI() 223 | 224 | 225 | def release_model_semaphore(fn=None): 226 | model_semaphore.release() 227 | if fn is not None: 228 | fn() 229 | 230 | 231 | @app.post("/worker_generate_stream") 232 | async def generate_stream(request: Request): 233 | global model_semaphore, global_counter 234 | global_counter += 1 235 | params = await request.json() 236 | 237 | if model_semaphore is None: 238 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 239 | await model_semaphore.acquire() 240 | worker.send_heart_beat() 241 | generator = worker.generate_stream_gate(params) 242 | background_tasks = BackgroundTasks() 243 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 244 | return StreamingResponse(generator, background=background_tasks) 245 | 246 | 247 | @app.post("/worker_get_status") 248 | async def get_status(request: Request): 249 | return worker.get_status() 250 | 251 | 252 | if __name__ == "__main__": 253 | parser = argparse.ArgumentParser() 254 | parser.add_argument("--host", type=str, default="localhost") 255 | parser.add_argument("--port", type=int, default=21002) 256 | parser.add_argument("--worker-address", type=str, 257 | default="http://localhost:21002") 258 | parser.add_argument("--controller-address", type=str, 259 | default="http://localhost:21001") 260 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 261 | parser.add_argument("--model-base", type=str, default=None) 262 | parser.add_argument("--model-name", type=str) 263 | parser.add_argument("--device", type=str, default="cuda") 264 | parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 265 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 266 | parser.add_argument("--stream-interval", type=int, default=1) 267 | parser.add_argument("--no-register", action="store_true") 268 | parser.add_argument("--load-8bit", action="store_true") 269 | parser.add_argument("--load-4bit", action="store_true") 270 | 
parser.add_argument("--use-flash-attn", action="store_true") 271 | args = parser.parse_args() 272 | logger.info(f"args: {args}") 273 | 274 | if args.multi_modal: 275 | logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 276 | 277 | worker = ModelWorker(args.controller_address, 278 | args.worker_address, 279 | worker_id, 280 | args.no_register, 281 | args.model_path, 282 | args.model_base, 283 | args.model_name, 284 | args.load_8bit, 285 | args.load_4bit, 286 | args.device, 287 | use_flash_attn=args.use_flash_attn) 288 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 289 | -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/sglang_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
3 | """ 4 | import argparse 5 | import asyncio 6 | from concurrent.futures import ThreadPoolExecutor 7 | import json 8 | import time 9 | import threading 10 | import uuid 11 | 12 | from fastapi import FastAPI, Request, BackgroundTasks 13 | from fastapi.responses import StreamingResponse 14 | import requests 15 | import re 16 | import uvicorn 17 | from functools import partial 18 | 19 | from llava.constants import WORKER_HEART_BEAT_INTERVAL 20 | from llava.utils import (build_logger, server_error_msg, 21 | pretty_print_semaphore) 22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square 23 | from llava.constants import DEFAULT_IMAGE_TOKEN 24 | 25 | import sglang as sgl 26 | from sglang.backend.runtime_endpoint import RuntimeEndpoint 27 | 28 | 29 | GB = 1 << 30 30 | 31 | worker_id = str(uuid.uuid4())[:6] 32 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 33 | global_counter = 0 34 | 35 | model_semaphore = None 36 | 37 | 38 | def heart_beat_worker(controller): 39 | while True: 40 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 41 | controller.send_heart_beat() 42 | 43 | 44 | @sgl.function 45 | def pipeline(s, prompt, max_tokens): 46 | for p in prompt: 47 | if type(p) is str: 48 | s += p 49 | else: 50 | s += sgl.image(p) 51 | s += sgl.gen("response", max_tokens=max_tokens) 52 | 53 | 54 | class ModelWorker: 55 | def __init__(self, controller_addr, worker_addr, sgl_endpoint, 56 | worker_id, no_register, model_name): 57 | self.controller_addr = controller_addr 58 | self.worker_addr = worker_addr 59 | self.worker_id = worker_id 60 | 61 | # Select backend 62 | backend = RuntimeEndpoint(sgl_endpoint) 63 | sgl.set_default_backend(backend) 64 | model_path = backend.model_info["model_path"] 65 | 66 | if model_path.endswith("/"): 67 | model_path = model_path[:-1] 68 | if model_name is None: 69 | model_paths = model_path.split("/") 70 | if model_paths[-1].startswith('checkpoint-'): 71 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 72 | else: 73 | self.model_name = model_paths[-1] 74 | else: 75 | self.model_name = model_name 76 | 77 | logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...") 78 | 79 | if not no_register: 80 | self.register_to_controller() 81 | self.heart_beat_thread = threading.Thread( 82 | target=heart_beat_worker, args=(self,), daemon=True) 83 | self.heart_beat_thread.start() 84 | 85 | def register_to_controller(self): 86 | logger.info("Register to controller") 87 | 88 | url = self.controller_addr + "/register_worker" 89 | data = { 90 | "worker_name": self.worker_addr, 91 | "check_heart_beat": True, 92 | "worker_status": self.get_status() 93 | } 94 | r = requests.post(url, json=data) 95 | assert r.status_code == 200 96 | 97 | def send_heart_beat(self): 98 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " 99 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" 100 | f"global_counter: {global_counter}") 101 | 102 | url = self.controller_addr + "/receive_heart_beat" 103 | 104 | while True: 105 | try: 106 | ret = requests.post(url, json={ 107 | "worker_name": self.worker_addr, 108 | "queue_length": self.get_queue_length()}, timeout=5) 109 | exist = ret.json()["exist"] 110 | break 111 | except requests.exceptions.RequestException as e: 112 | logger.error(f"heart beat error: {e}") 113 | time.sleep(5) 114 | 115 | if not exist: 116 | self.register_to_controller() 117 | 118 | def get_queue_length(self): 119 | if model_semaphore is None: 120 | return 0 121 | else: 122 | return args.limit_model_concurrency - model_semaphore._value + (len( 123 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 124 | 125 | def get_status(self): 126 | return { 127 | "model_names": [self.model_name], 128 | "speed": 1, 129 | "queue_length": self.get_queue_length(), 130 | } 131 | 132 | async def generate_stream(self, params): 133 | ori_prompt = prompt = params["prompt"] 134 | images = params.get("images", None) 135 | if images is not None and len(images) > 0: 136 | if len(images) > 0: 137 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 138 | raise ValueError("Number of images does not match number of tokens in prompt") 139 | 140 | images = [load_image_from_base64(image) for image in images] 141 | 142 | # FIXME: for image-start/end token 143 | # replace_token = DEFAULT_IMAGE_TOKEN 144 | # if getattr(self.model.config, 'mm_use_im_start_end', False): 145 | # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 146 | # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 147 | prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN) 148 | prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN) 149 | prompt = [] 150 | for i in range(len(prompt_split)): 151 | prompt.append(prompt_split[i]) 152 | if i < len(images): 153 | prompt.append(images[i]) 154 | else: 155 | prompt = [prompt] 156 | 157 | temperature = float(params.get("temperature", 1.0)) 158 | top_p = float(params.get("top_p", 1.0)) 159 | # max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 160 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 161 | stop_str = params.get("stop", None) 162 | stop_str = [stop_str] if stop_str is not None else None 163 | 164 | print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p}) 165 | state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True) 166 | 167 | generated_text = ori_prompt 168 | async for text_outputs in state.text_async_iter(var_name="response"): 169 | generated_text += text_outputs 170 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 171 | 172 | async def generate_stream_gate(self, params): 173 | try: 174 | async for x in self.generate_stream(params): 175 | yield x 176 | except ValueError as e: 177 | print("Caught ValueError:", e) 178 | ret = { 179 | "text": server_error_msg, 180 | "error_code": 1, 181 | } 182 | yield json.dumps(ret).encode() + b"\0" 183 | except Exception as e: 184 | print("Caught Unknown Error", e) 185 | ret = { 186 | "text": server_error_msg, 187 | "error_code": 1, 188 | } 189 | yield json.dumps(ret).encode() + b"\0" 190 | 191 | 192 | app = FastAPI() 193 | 194 | 195 | def release_model_semaphore(fn=None): 196 | model_semaphore.release() 197 | if fn is not None: 198 | fn() 199 | 200 | 201 | @app.post("/worker_generate_stream") 
202 | async def generate_stream(request: Request): 203 | global model_semaphore, global_counter 204 | global_counter += 1 205 | params = await request.json() 206 | 207 | if model_semaphore is None: 208 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 209 | await model_semaphore.acquire() 210 | worker.send_heart_beat() 211 | generator = worker.generate_stream_gate(params) 212 | background_tasks = BackgroundTasks() 213 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 214 | return StreamingResponse(generator, background=background_tasks) 215 | 216 | 217 | @app.post("/worker_get_status") 218 | async def get_status(request: Request): 219 | return worker.get_status() 220 | 221 | 222 | if __name__ == "__main__": 223 | parser = argparse.ArgumentParser() 224 | parser.add_argument("--host", type=str, default="localhost") 225 | parser.add_argument("--port", type=int, default=21002) 226 | parser.add_argument("--worker-address", type=str, 227 | default="http://localhost:21002") 228 | parser.add_argument("--controller-address", type=str, 229 | default="http://localhost:21001") 230 | parser.add_argument("--model-name", type=str) 231 | parser.add_argument("--sgl-endpoint", type=str) 232 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 233 | parser.add_argument("--stream-interval", type=int, default=1) 234 | parser.add_argument("--no-register", action="store_true") 235 | args = parser.parse_args() 236 | logger.info(f"args: {args}") 237 | 238 | worker = ModelWorker(args.controller_address, 239 | args.worker_address, 240 | args.sgl_endpoint, 241 | worker_id, 242 | args.no_register, 243 | args.model_name) 244 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 245 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | 
parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, 
cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! 
Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
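# In practice the additive mask that transformers builds fills masked positions with the
# dtype's minimum value (roughly -65504 in fp16), so entry [0, 0, 0, 1] sits in the upper
# triangle: it is strongly negative for a causal/padded mask and exactly 0 when the mask is
# all zeros, which is what the branch below tests.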
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("./") 3 | 4 | from llava.train.train import train 5 | 6 | if __name__ == "__main__": 7 | train(attn_implementation="flash_attention_2") 8 | -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
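# Example usage from llava/serve/model_worker.py:
#   logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
# build_logger attaches a single shared TimedRotatingFileHandler (cached in `handler` below)
# to every registered logger and redirects stdout/stderr through StreamToLogger.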
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
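    Requires the OPENAI_API_KEY environment variable; request failures or an unexpected
    response shape are treated as not flagged, so the function returns False in those cases.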
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "adamllm" 7 | version = "0.0.1" 8 | description = "Adapt multimodal large language models to domains" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | ] 14 | dependencies = [ 15 | "torch==2.1.2", "torchvision==0.16.2", 16 | "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid", 17 | "accelerate==0.21.0", "peft", "bitsandbytes", 18 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 19 | "gradio==4.16.0", "gradio_client==0.8.1", 20 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", "prettytable", 21 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", "openpyxl" 22 | ] 23 | 24 | [project.optional-dependencies] 25 | train = ["deepspeed==0.12.6", "ninja", "wandb"] 26 | build = ["build", "twine"] 27 | 28 | [tool.setuptools.packages.find] 29 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 30 | 31 | [tool.wheel] 32 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] -------------------------------------------------------------------------------- /scripts/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "biomed": { 3 | "file_name": "PATH_TO/biomed-visual-instructions/image_caption_and_synthetic_task.json", 4 | "formatting": "sharegpt", 5 | "columns": { 6 | "messages": "messages", 7 | "images": "images" 8 | }, 9 | "tags": { 10 | "role_tag": "role", 11 | "content_tag": "content", 12 | "user_tag": "user", 13 | "assistant_tag": "assistant" 14 | } 15 | }, 16 | "food": { 17 | "file_name": "PATH_TO/food-visual-instructions/image_caption_and_synthetic_task.json", 18 | "formatting": "sharegpt", 19 | "columns": { 20 | "messages": "messages", 21 | "images": "images" 22 | }, 23 | "tags": { 24 | "role_tag": "role", 25 | "content_tag": "content", 26 | "user_tag": "user", 27 | "assistant_tag": "assistant" 28 | } 29 | } 30 | } -------------------------------------------------------------------------------- /scripts/post_train_mllm.sh: -------------------------------------------------------------------------------- 1 | 2 | DOMAIN=$1 3 | DATASET=$2 4 | IMAGE_FOLDER=$3 5 | BATCH_SIZE=4 6 | GRADIENT_ACCU_STEPS=4 7 | SAVE_PATH=./exp/${DOMAIN}-LLaVA-v1.6-8B 8 | 9 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' python -m torch.distributed.launch --use_env --nproc_per_node=8 --master_port=12345 
llava/train/train_mem.py \ 10 | --deepspeed ./scripts/zero3.json \ 11 | --model_name_or_path Lin-Chen/open-llava-next-llama3-8b \ 12 | --version llava_llama_3 \ 13 | --image_folder ${IMAGE_FOLDER} \ 14 | --vision_tower "/openai/clip-vit-large-patch14-336" \ 15 | --mm_projector_type mlp2x_gelu \ 16 | --unfreeze_mm_vision_tower True \ 17 | --mm_vision_tower_lr 2e-6 \ 18 | --image_aspect_ratio anyres \ 19 | --group_by_modality_length True \ 20 | --mm_vision_select_layer -2 \ 21 | --mm_vision_select_feature patch \ 22 | --mm_patch_merge_type spatial_unpad \ 23 | --mm_use_im_start_end False \ 24 | --mm_use_im_patch_token False \ 25 | --bf16 True \ 26 | --output_dir ${SAVE_PATH} \ 27 | --num_train_epochs 1 \ 28 | --per_device_train_batch_size ${BATCH_SIZE} \ 29 | --per_device_eval_batch_size 4 \ 30 | --gradient_accumulation_steps ${GRADIENT_ACCU_STEPS} \ 31 | --evaluation_strategy "no" \ 32 | --save_strategy "steps" \ 33 | --save_steps 500 \ 34 | --save_total_limit 1 \ 35 | --learning_rate 2e-5 \ 36 | --weight_decay 0. \ 37 | --warmup_ratio 0.03 \ 38 | --lr_scheduler_type "cosine" \ 39 | --logging_steps 10 \ 40 | --tf32 True \ 41 | --model_max_length 6144 \ 42 | --gradient_checkpointing True \ 43 | --dataloader_num_workers 16 \ 44 | --lazy_preprocess True \ 45 | --report_to none \ 46 | --run_name ${SAVE_PATH} \ 47 | --data_path ${DATASET} -------------------------------------------------------------------------------- /scripts/tune_synthesizer.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | CAPTION=$1 # PATH_TO/ALLaVA-Caption-VFLAN-4V.json 4 | PRECISE_A=$2 # PATH_TO/vflan_metadata.json 5 | INFORMATIVE_A=$3 #PATH_TO/ALLaVA-Instruct-VFLAN-4V.json 6 | IAMGE_FOLDER=$4 # PATH_TO/images_191task_1k 7 | 8 | BLANK_IMAGE=./Blank.jpg 9 | BATCH_SIZE=4 10 | GRADIENT_ACCU_STEPS=4 11 | SAVE_PATH=./exp/synthesizer 12 | 13 | CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7' python -m torch.distributed.launch --use_env --nproc_per_node=8 --master_port=12345 llava/train/train_mem.py \ 14 | --deepspeed "./scripts/zero3.json" \ 15 | --model_name_or_path Lin-Chen/open-llava-next-llama3-8b \ 16 | --version llava_llama_3 \ 17 | --image_folder ${IAMGE_FOLDER} \ 18 | --vision_tower openai/clip-vit-large-patch14-336 \ 19 | --mm_projector_type mlp2x_gelu \ 20 | --unfreeze_mm_vision_tower True \ 21 | --mm_vision_tower_lr 2e-6 \ 22 | --image_aspect_ratio anyres \ 23 | --group_by_modality_length True \ 24 | --mm_vision_select_layer -2 \ 25 | --mm_vision_select_feature patch \ 26 | --mm_patch_merge_type spatial_unpad \ 27 | --mm_use_im_start_end False \ 28 | --mm_use_im_patch_token False \ 29 | --bf16 True \ 30 | --output_dir ${SAVE_PATH} \ 31 | --num_train_epochs 2 \ 32 | --per_device_train_batch_size ${BATCH_SIZE} \ 33 | --per_device_eval_batch_size 4 \ 34 | --gradient_accumulation_steps ${GRADIENT_ACCU_STEPS} \ 35 | --evaluation_strategy "no" \ 36 | --save_strategy "steps" \ 37 | --save_steps 500 \ 38 | --save_total_limit 1 \ 39 | --learning_rate 2e-5 \ 40 | --weight_decay 0. 
\ 41 | --warmup_ratio 0.03 \ 42 | --lr_scheduler_type "cosine" \ 43 | --logging_steps 10 \ 44 | --tf32 True \ 45 | --model_max_length 6144 \ 46 | --gradient_checkpointing True \ 47 | --dataloader_num_workers 4 \ 48 | --lazy_preprocess True \ 49 | --report_to none \ 50 | --run_name ${SAVE_PATH} \ 51 | --syn_mode 'precise+informative' \ 52 | --replace_with_blank_image_percent 10 \ 53 | --replace_with_blank_image_percent \ 54 | --caption_path ${CAPTION} \ 55 | --precise_qa_path ${PRECISE_A} \ 56 | --informative_qa_path ${INFORMATIVE_A} \ 57 | --blank_image_path ${BLANK_IMAGE} -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | 
"gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /vllm_inference/eval_predictions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from utils.task import task_map 4 | from tqdm import tqdm 5 | import os 6 | import jsonlines 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--output_dir', type=str, help='directory for saving the shards.') 10 | parser.add_argument('--model_path', type=str) 11 | parser.add_argument('--task_name', type=str) 12 | parser.add_argument('--stop',type=int, default=None) 13 | parser.add_argument('--model_type', choices = ['llava', 'qwen2_vl', 'mllama', 'llama']) 14 | parser.add_argument('--eval_results_dir', type=str) 15 | 16 | args = parser.parse_args() 17 | 18 | 19 | if args.stop is not None: 20 | print(f'debug mode...') 21 | args.output_dir='/tmp/test_syn' 22 | 23 | print(f'args.task_name: {args.task_name}') 24 | task_cls = task_map.cls_dic[args.task_name](args.model_type) 25 | 26 | if task_cls.enable_eval: 27 | metadata_list = [] 28 | out_path = f'{args.output_dir}.jsonl' 29 | with open(out_path, 'r', encoding='utf8') as f: 30 | jsonls = f.read().strip().split('\n') 31 | for jsonl in tqdm(jsonls): 32 | metadata_list.append(json.loads(jsonl)) 33 | 34 | print(f'eval: {out_path}') 35 | metadata_list, results = task_cls.evaluate(metadata_list, stop=args.stop) 36 | results_file = os.path.join(args.eval_results_dir, f'{args.task_name}.txt') 37 | os.makedirs(os.path.dirname(results_file), exist_ok=True) 38 | 39 | with open(results_file, "a") as f: 40 | info = {'pred_file': out_path, 41 | 'model': str(args.model_path), 42 | 'scores': results} 43 | f.write(json.dumps(info, indent=2) + "\n") 44 | print(json.dumps(info, indent=2) + "\n") 45 | print(f'write results to: {results_file}') 46 | 47 | # re-save metadatalist to save the score for each individual entry 48 | print(f're-saving scored metadatalist to {out_path}...') 49 | os.remove(out_path) 50 | with jsonlines.open(out_path,mode='a') as writer: 51 | for doc in metadata_list: 52 | writer.write(doc) 53 | print(f'saved jsonl to: {out_path}') 54 | 55 | print('done') -------------------------------------------------------------------------------- /vllm_inference/format_post_train_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | convert the data after the consistency-filter into our single-stage post-train format 3 | """ 4 | import sys 5 | sys.path.append("../") 6 | import process.syn_utils as syn_utils 7 | import argparse 8 | import os 9 | from tqdm import tqdm 10 | from PIL import Image 11 | import json 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--filtered_task_pairs', type=str, default='') 15 | parser.add_argument('--image_folder', type=str, default='') 16 | parser.add_argument('--output_path', type=str) 17 | parser.add_argument('--train_on_syn_only', action='store_true', help='whether to remove the image-captioning task.') 18 | parser.add_argument('--stop',type=int, default=None) 19 | 20 | args = parser.parse_args() 21 | 22 | if args.stop is not None: 23 | print(f'debug mode...') 24 | args.filtered_task_pairs='/tmp/test_syn.jsonl' 25 | args.output_path='/tmp/test_train_data.json' 26 | 27 | ds = [] 
28 | with open(args.filtered_task_pairs, 'r', encoding='utf8') as f: 29 | jsonls = f.read().strip().split('\n') 30 | for jsonl in tqdm(jsonls): 31 | ds.append(json.loads(jsonl)) 32 | 33 | id = 0 34 | clean_ds = [] 35 | for entry in tqdm(ds): 36 | entry = syn_utils.process_entry(id=id, entry=entry, image_token='<image>', train_on_syn_only=args.train_on_syn_only) 37 | if entry is None: 38 | continue 39 | 40 | image_file = entry["images"][0] 41 | try: 42 | image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') 43 | except Exception as e: 44 | print(e) 45 | continue 46 | clean_ds.append(entry) 47 | id += 1 48 | 49 | # reformat to llama-factory's style 50 | """ 51 | factory example: 52 | { 53 | "messages": [ 54 | { 55 | "content": "<image>Who are they?", 56 | "role": "user" 57 | }, 58 | { 59 | "content": "They're Kane and Gretzka from Bayern Munich.", 60 | "role": "assistant" 61 | } 62 | ], 63 | "images": [ 64 | "mllm_demo_data/1.jpg" 65 | ] 66 | } 67 | """ 68 | def reformat(entry): 69 | new_entry = {} 70 | # we place the image token at the beginning of the instruction 71 | assert entry['conversations'][0]['value'].startswith('<image>\n') 72 | first_message = { 73 | "content": entry['conversations'][0]['value'], 74 | "role": "user" 75 | } 76 | new_entry['messages'] = [first_message] 77 | for ex in entry['conversations'][1:]: 78 | assert '<image>' not in ex['value'] 79 | new_m = { 80 | 'content': ex['value'], 81 | 'role': "user" if ex['from'] == "human" else "assistant" 82 | } 83 | new_entry['messages'].append(new_m) 84 | return new_entry 85 | 86 | factory_data = [reformat(entry) for entry in tqdm(clean_ds)] 87 | 88 | syn_utils.save_json(factory_data, args.output_path) -------------------------------------------------------------------------------- /vllm_inference/inference.py: -------------------------------------------------------------------------------- 1 | from vllm import LLM, SamplingParams 2 | import argparse 3 | import os 4 | from utils.cache_util import BufferedJsonWriter 5 | from utils.task import task_map 6 | from tqdm import tqdm 7 | from more_itertools import distribute 8 | from transformers import AutoProcessor 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--model_weight_path', type=str) 12 | parser.add_argument('--data_path', type=str, default=None) 13 | parser.add_argument('--image_folder', type=str, default=None) 14 | parser.add_argument('--output_dir', type=str) 15 | parser.add_argument('--stop',type=int, default=None) 16 | parser.add_argument('--remove_cache', action='store_true', help='remove the cached files') 17 | parser.add_argument('--data_parallel_size', type=int, default=1) 18 | parser.add_argument('--cuda_device', type=int, default=0) 19 | parser.add_argument('--task_name', type=str) 20 | parser.add_argument('--model_type', choices = ['llava', 'qwen2_vl', 'mllama', 'llama']) 21 | 22 | args = parser.parse_args() 23 | task_cls = task_map.cls_dic[args.task_name](args.model_type) 24 | 25 | #==== debug task args === 26 | print(f'### task_name: {args.task_name}') 27 | print(f'task_cls.model_type: {task_cls.model_type}') 28 | print(f'task_cls.stop_tokens: {task_cls.stop_tokens}') 29 | print(f'task_cls.max_tokens: {task_cls.max_tokens}') 30 | print(f'task_cls.skip_special_tokens: {task_cls.skip_special_tokens}') 31 | print(f'task_cls.max_model_len: {task_cls.max_model_len}') 32 | 33 | sampling_params = SamplingParams(temperature=0, max_tokens=task_cls.max_tokens, skip_special_tokens=task_cls.skip_special_tokens, stop=task_cls.stop_tokens) 34 | 35 | ds = 
task_cls.get_dataset(data_path=args.data_path, image_folder=args.image_folder, stop=args.stop) 36 | 37 | processor = None 38 | process_vision_info = None 39 | if args.model_type == 'llava': 40 | llm = LLM(model=args.model_weight_path, max_model_len=task_cls.max_model_len) 41 | elif args.model_type == 'qwen2_vl': 42 | from qwen_vl_utils import process_vision_info 43 | llm = LLM(model=args.model_weight_path, max_model_len=task_cls.max_model_len) 44 | processor = AutoProcessor.from_pretrained(args.model_weight_path) 45 | elif args.model_type == 'mllama': 46 | # add constraints to avoid oom for mllama-11b, ref: https://github.com/vllm-project/vllm/issues/9163#issuecomment-2400778274 47 | llm = LLM(model=args.model_weight_path, max_model_len=task_cls.max_model_len, dtype="bfloat16", gpu_memory_utilization=0.85, enforce_eager=True, max_num_seqs=20) 48 | processor = AutoProcessor.from_pretrained(args.model_weight_path) 49 | elif args.model_type == 'llama': 50 | llm = LLM(model=args.model_weight_path, max_model_len=task_cls.max_model_len) 51 | 52 | # assign identical index 53 | for id, entry in enumerate(ds): 54 | entry['syn_id'] = id 55 | 56 | if args.stop is not None: 57 | print(f'debug mode...') 58 | args.output_dir='/tmp/test_syn' 59 | ds = ds[:args.stop] 60 | 61 | # We pass every 500 entries to llm.generate and write them to the cache 62 | chunk_size = 500 if args.stop is None else 500 63 | 64 | def run_non_ddp_inference_one_model(split, rank=0): 65 | print(f'cur rank: {rank}, infer on {len(split)} prompts') 66 | cached_file_path = f"{args.output_dir}tmp_process-{rank}.bin" 67 | os.makedirs(os.path.dirname(cached_file_path), exist_ok=True) 68 | print(f'cur rank: {rank}, cached_file_path: {cached_file_path}') 69 | 70 | with BufferedJsonWriter(cached_file_path, buffer_size=1) as buffer: 71 | cached_size = 0 72 | if os.path.exists(cached_file_path): 73 | if args.remove_cache: 74 | os.remove(cached_file_path) 75 | print(f"cur rank: {rank}, {cached_file_path} removed successfully") 76 | else: 77 | cached_size = buffer.get_cached_size() 78 | print(f'cur rank: {rank}, continue from {cached_file_path}') 79 | print(f'cur rank: {rank}, cached_size = {cached_size}...') 80 | 81 | assert cached_size % chunk_size == 0, f'cur rank: {rank}, we save the outputs every chunk_size, so the cached_size should be multiple of chunk_size' 82 | 83 | silent = False if rank == 0 else True # we only show progress on rank 0 84 | for start_index in tqdm(range(0, len(split), chunk_size), disable = silent): 85 | if start_index < cached_size: continue 86 | cur_split = split[start_index: start_index + chunk_size] 87 | cur_prompts = [task_cls.get_prompt(line, args.stop, silent=True, processor=processor, process_vision_info=process_vision_info) for line in cur_split] 88 | cur_prompts = [entry for entry in cur_prompts if entry is not None] 89 | try: 90 | outputs = llm.generate(cur_prompts, sampling_params, use_tqdm=False) 91 | except Exception as e: 92 | print(e) 93 | buffer.write([]) 94 | continue 95 | metadata_list = [] 96 | id = 0 97 | for metadata in cur_split: 98 | task_prompt = task_cls.get_prompt(metadata, silent=True, processor=processor, process_vision_info=process_vision_info) 99 | if task_prompt is None: 100 | metadata.update({'pred': None}) 101 | if 'image' in metadata and not isinstance(metadata['image'], str): 102 | # To avoid the `TypeError: Object of type PngImageFile is not JSON serializable` when saving the data 103 | metadata.pop('image') 104 | metadata_list.append(metadata) 105 | continue 106 | output = 
outputs[id] 107 | if metadata['syn_id'] % 1000==0 or args.stop is not None: 108 | # For debugging, print input and output details every 1000 examples or when 'stop' is triggered. 109 | task_cls.get_prompt(metadata, args.stop, silent=False, processor=processor, process_vision_info=process_vision_info) 110 | print(f'pred: {output.outputs[0].text}', flush=True) 111 | id += 1 112 | 113 | if args.model_type in ['llava', 'qwen2_vl', 'mllama'] and output.prompt[-10:] != task_prompt['prompt'][-10:]: 114 | print(f'output.prompt: {output.prompt[-10:]} does not fit for task_prompt: {task_prompt["prompt"][-10:]}') 115 | metadata.update({'pred': None}) 116 | metadata_list.append(metadata) 117 | continue 118 | if args.model_type == 'llama' and output.prompt[-10:] != task_prompt[-10:]: 119 | print(f'output.prompt: {output.prompt[-10:]} does not fit for task_prompt: {task_prompt[-10:]}') 120 | metadata.update({'pred': None}) 121 | metadata_list.append(metadata) 122 | continue 123 | metadata.update({'pred': output.outputs[0].text}) 124 | 125 | if 'image' in metadata and not isinstance(metadata['image'], str): 126 | # To avoid the `TypeError: Object of type PngImageFile is not JSON serializable` when saving the data 127 | metadata.pop('image') 128 | metadata_list.append(metadata) 129 | # we set buffer_size = 1, so BufferedJsonWriter.write_buffer() flushes to the cache file every time we call buffer.write(metadata_list) 130 | # see ./utils/cache_util.py 131 | assert id == len(cur_prompts) and id == len(outputs), f'id: {id} != len(cur_prompts): {len(cur_prompts)} != len(outputs): {len(outputs)}' 132 | buffer.write(metadata_list) 133 | print(f'cur rank: {rank}, saved all the outputs to {cached_file_path}') 134 | 135 | 136 | if args.data_parallel_size > 1: 137 | print('ddp_mode...') 138 | # dispatch requests to all data_parallel_size workers, in interleaved fashion 139 | # interleaving is important to balance context lengths across workers 140 | sharded_splits = [list(x) for x in distribute(args.data_parallel_size, ds)] 141 | run_non_ddp_inference_one_model(sharded_splits[args.cuda_device], rank=args.cuda_device) 142 | else: 143 | print('non_ddp_mode...') 144 | run_non_ddp_inference_one_model(ds) 145 | 146 | print('generating done') -------------------------------------------------------------------------------- /vllm_inference/merge_predictions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from utils.cache_util import BufferedJsonReader 4 | import jsonlines 5 | import glob 6 | 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--output_dir', type=str) 9 | parser.add_argument('--stop',type=int, default=None) 10 | 11 | args = parser.parse_args() 12 | 13 | if args.stop is not None: 14 | print(f'debug mode...') 15 | args.output_dir='/tmp/test_syn' 16 | 17 | # save all the results 18 | data = [] 19 | for path in sorted(glob.glob(f"{args.output_dir}tmp_*.bin")): 20 | with BufferedJsonReader(path) as f: 21 | for x in f.read(): 22 | data.extend(x) 23 | data.sort(key = lambda entry: entry['syn_id']) # sort the data by syn_id 24 | print(f"num of saved preds: {len(data)}") 25 | 26 | out_path = f'{args.output_dir}.jsonl' 27 | print(f'saving synthesized corpora to {out_path}') 28 | if os.path.exists(out_path): 29 | os.remove(out_path) 30 | print(f"cached {out_path} removed successfully") 31 | with jsonlines.open(out_path,mode='a') as writer: 32 | for doc in data: 33 | writer.write(doc) 34 | print(f'saved jsonl to: {out_path}') 35 | 36 | for path in 
glob.glob(f"{args.output_dir}tmp_*.bin"): 37 | os.remove(path) 38 | 39 | print('done') 40 | -------------------------------------------------------------------------------- /vllm_inference/run_inference.sh: -------------------------------------------------------------------------------- 1 | 2 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 3 | IFS=',' read -ra GPULIST <<< "$gpu_list" 4 | CHUNKS=${#GPULIST[@]} 5 | CKPT=$1 6 | DOMAIN=$2 7 | MODEL_TYPE=$3 8 | OUTPUT_DIR=$4 9 | RESULTS_DIR=$5 10 | OTHER_OPT=$6 # For debugging: you can append the '--stop 10' flag to watch all the intermediate input/output of the first 10 data examples 11 | 12 | if [ ${DOMAIN} == 'med' ]; then 13 | TASK_array=( 14 | 'SLAKE' 15 | "PathVQA" 16 | 'VQA_RAD' 17 | "PMC_VQA" 18 | ) 19 | elif [ ${DOMAIN} == 'food' ]; then 20 | TASK_array=( 21 | "Recipe1M" 22 | "Nutrition5K" 23 | "FoodSeg103" 24 | "Food101" 25 | ) 26 | else 27 | TASK_array=( 28 | ${DOMAIN} 29 | ) 30 | fi 31 | 32 | echo "Prepare Code for Domain: ${DOMAIN}, Model type: ${MODEL_TYPE}" 33 | 34 | for j in "${!TASK_array[@]}"; do 35 | TASK=${TASK_array[j]} 36 | echo "TASK: ${TASK}" 37 | echo "OUTPUT_DIR: ${OUTPUT_DIR}/${TASK}" # save outputs for every single task 38 | for IDX in $(seq 0 $((CHUNKS-1))); do 39 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python inference.py \ 40 | --model_weight_path ${CKPT} \ 41 | --task ${TASK} \ 42 | --cuda_device ${GPULIST[$IDX]} \ 43 | --output_dir ${OUTPUT_DIR}/${TASK} \ 44 | --remove_cache \ 45 | --data_parallel_size ${CHUNKS} \ 46 | --model_type ${MODEL_TYPE} ${OTHER_OPT} & 47 | done 48 | 49 | wait 50 | 51 | echo 'inference done' 52 | 53 | python merge_predictions.py \ 54 | --output_dir ${OUTPUT_DIR}/${TASK} \ 55 | ${OTHER_OPT} 56 | 57 | python eval_predictions.py \ 58 | --output_dir ${OUTPUT_DIR}/${TASK} \ 59 | --model_path ${CKPT} \ 60 | --task_name ${TASK} \ 61 | --model_type ${MODEL_TYPE} \ 62 | --eval_results_dir ${RESULTS_DIR} \ 63 | ${OTHER_OPT} 64 | done -------------------------------------------------------------------------------- /vllm_inference/run_synthesis.sh: -------------------------------------------------------------------------------- 1 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 2 | IFS=',' read -ra GPULIST <<< "$gpu_list" 3 | CHUNKS=${#GPULIST[@]} 4 | SYNTHESIZER=$1 # AdaptLLM/visual-instruction-synthesizer 5 | CONSISTENCY_CHECKER=$2 # meta-llama/Meta-Llama-3-8B 6 | IMAGE_CAPTION=$3 # Path to the json file of image_caption pairs (in the ShareGPT format) 7 | IMAGE_FOLDER=$4 # Path to the image folder 8 | OUTPUT_DIR=$5 9 | OTHER_OPT=$6 # you can add the "--stop 10" flag to watch all the intermediate input/output of the first 10 data examples 10 | 11 | # 1. Synthesize `instruction-informative response-precise response` triplets 12 | echo "Synthesizing task triplets..." 13 | echo "Output path: ${OUTPUT_DIR}/syn_task_triplets.jsonl" 14 | for IDX in $(seq 0 $((CHUNKS-1))); do 15 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python inference.py \ 16 | --model_weight_path ${SYNTHESIZER} \ 17 | --task syn_task_triplet \ 18 | --data_path ${IMAGE_CAPTION} \ 19 | --image_folder ${IMAGE_FOLDER} \ 20 | --cuda_device ${GPULIST[$IDX]} \ 21 | --output_dir ${OUTPUT_DIR}/syn_task_triplets \ 22 | --remove_cache \ 23 | --data_parallel_size ${CHUNKS} \ 24 | --model_type 'llava' ${OTHER_OPT} & 25 | done 26 | 27 | wait 28 | 29 | python merge_predictions.py \ 30 | --output_dir ${OUTPUT_DIR}/syn_task_triplets \ 31 | ${OTHER_OPT} 32 | 33 | echo 'Synthesis done' 34 | 35 | # 2. 
Consistency-based filter 36 | echo "Conducting consistency-based filtering..." 37 | echo "Output path: ${OUTPUT_DIR}/filtered_task_pairs.jsonl" 38 | for IDX in $(seq 0 $((CHUNKS-1))); do 39 |     CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python inference.py \ 40 |         --model_weight_path ${CONSISTENCY_CHECKER} \ 41 |         --task consistency_filter \ 42 |         --data_path ${OUTPUT_DIR}/syn_task_triplets.jsonl \ 43 |         --image_folder ${IMAGE_FOLDER} \ 44 |         --cuda_device ${GPULIST[$IDX]} \ 45 |         --output_dir ${OUTPUT_DIR}/filtered_task_pairs \ 46 |         --remove_cache \ 47 |         --data_parallel_size ${CHUNKS} \ 48 |         --model_type 'llama' ${OTHER_OPT} & 49 | done 50 | 51 | wait 52 | 53 | python merge_predictions.py \ 54 |     --output_dir ${OUTPUT_DIR}/filtered_task_pairs \ 55 |     ${OTHER_OPT} 56 | 57 | echo 'Filter done' 58 | 59 | python format_post_train_data.py \ 60 |     --filtered_task_pairs ${OUTPUT_DIR}/filtered_task_pairs.jsonl \ 61 |     --image_folder ${IMAGE_FOLDER} \ 62 |     --output_path ${OUTPUT_DIR}/image_caption_and_synthetic_task.json \ 63 |     ${OTHER_OPT} 64 | 65 | echo "Single-stage training data saved to: ${OUTPUT_DIR}/image_caption_and_synthetic_task.json" -------------------------------------------------------------------------------- /vllm_inference/utils/cache_util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pathlib 3 | 4 | class BufferedJsonWriter(object): 5 |     def __init__(self, file_name, buffer_size=25): 6 |         self.file_path = file_name 7 |         self.buffer = [] 8 |         self.buffer_size = buffer_size 9 | 10 |     def __enter__(self): 11 |         return self 12 | 13 |     def __exit__(self, type, value, traceback): 14 |         if len(self.buffer) > 0: 15 |             # requires > 0 to avoid appending empty lines to the cache file 16 |             self.write_buffer() 17 | 18 |     def write(self, obj=None): 19 |         if obj is not None: 20 |             self.buffer.append(obj) 21 |         if len(self.buffer)>=self.buffer_size: 22 |             self.write_buffer() 23 | 24 |     def write_buffer(self): 25 |         with open(self.file_path, "a") as data_file: 26 |             data_file.write(json.dumps(self.buffer)) # ensure_ascii=False 27 |             data_file.write("\n") 28 |             self.buffer = [] 29 | 30 |     def get_cached_size(self): 31 |         """load from the cached file, and skip inferring on the entries already saved""" 32 |         cached_size = 0 33 |         with open(self.file_path, "r") as data_file: 34 |             for line in data_file: 35 |                 l = json.loads(line) 36 |                 assert len(l) == 1 and isinstance(l[0], list), 'each line of the cache file should be like [[{"text": xxxx, "xxx": xxx}...]]' 37 |                 cached_size += len(l[0]) 38 |         return cached_size 39 | 40 | 41 | class BufferedJsonReader(object): 42 |     def __init__(self, file_name): 43 |         self.file_path = file_name 44 | 45 |     def __enter__(self): 46 |         return self 47 | 48 |     def __exit__(self, type, value, traceback): 49 |         pass 50 | 51 |     def __itr__(self): 52 |         with open(self.file_path, "r") as data_file: 53 |             for line in data_file: 54 |                 try: 55 |                     yield from json.loads(line) 56 |                 except Exception as e: 57 |                     # to skip broken json dict 58 |                     print(e) 59 |                     print(f'broken line:\n{line}') 60 |                     pass 61 | 62 |     def read(self): 63 |         return list(self.__itr__()) 64 | -------------------------------------------------------------------------------- /vllm_inference/utils/consistency_filter_prompt.txt: -------------------------------------------------------------------------------- 1 | I have an image-related question for which I have received two answers: one is informative, and the other is precise. Please help me assess whether the precise answer is consistent with the informative answer.
2 | 3 | Below are some examples. For each example, write your judgement as follows: 4 | - Yes: Consistent, meaning the precise answer can be inferred from the informative answer. 5 | - No: Inconsistent, meaning the precise answer cannot be inferred from the informative answer. 6 | - Open: Indicates that the question invites multiple interpretations or that answers can be expressed through open-ended phrasing, acknowledging complexity, or encouraging personal perspectives. It is also applicable to questions that require image description, captions, or background knowledge. 7 | 8 | ## Question: Given an image of something from the kitchen, identify what it could be. The image could be of cooking tools or items that are used for eating. It could also be used for serving food or storing it.\nOptions: (a) This image contains a cup (b) This image contains a bottle (c) This image contains a knife (d) This image contains a wine glass (e) This image contains a spoon (f) This image contains a bowl (g) This image contains a fork 9 | ## Informative Answer: The options provided relate to identifying kitchenware or items related to food service. From the given list, the most clearly identifiable object in the image is a plastic bottle with a blue label, which is visible in the background behind the dishes. No cups, knives, wine glasses, spoons, bowls (in the sense of a single, empty bowl), or forks are distinctly visible. Therefore, the correct identification from the provided options is:\n(b) This image contains a bottle. 10 | ## Precise Answer: (a) This image contains a cup 11 | ## Consistent: No 12 | 13 | ## Question: You are provided with an image which contains two pictures side by side. Your task is to identify the differences between the two pictures. Separate the differences with a comma. 14 | ## Informative Answer: Upon close inspection, the differences between the two images are subtle but distinct. The shadows cast by the trees and lamp posts differ in length and direction, suggesting a slight variation in the time or position of the sun when the photos were taken. The arrangement and presence of cars vary\u2014some are present in one image but not in the other. Architectural details such as windows, roof lines, and the design on the building exteriors also show discrepancies. The color of the roof on the small central structure is different in each picture, with one appearing darker than the other. Lastly, the patterns on the sidewalks, especially the crosswalk lines, are not identical, indicating a variation in either the perspective or actual changes in the environment. 15 | ## Precise Answer: the redish brown car moves slightly. 16 | ## Consistent: Open 17 | 18 | ## Question: Given an image of an animal, identify the kind of animal in the image. The picture could be of more popular animals that are visible around zoos or are sometimes domesticated at home. They could also sometimes be found in the wild.\nOptions: (a) This image contains a zebra (b) This image contains a cow (c) This image contains a horse (d) This image contains an elephant (e) This image contains a giraffe (f) This image contains a sheep (g) This image contains a cat (h) This image contains a bear (i) This image contains a bird (j) This image contains a dog 19 | ## Informative Answer: The animal in the image can be identified as a dog, specifically a golden retriever. 
This determination is based on several visual cues such as the golden coat, the size and shape of the ears, the facial structure, and the overall body size, which are characteristic of this breed. Golden retrievers are a popular domesticated breed known for their friendly and intelligent demeanor. They are commonly kept as pets and are not typically found in zoos or in the wild. The image does not contain any of the other animals listed in the options (a) through (i). 20 | ## Precise Answer: (j) This image contains a dog 21 | ## Consistent: Yes 22 | 23 | ## Question: Question: Is the picture taken outdoors?\nPlease answer the question by analyzing the scene in the provided image. 24 | ## Informative Answer: Analyzing the elements within the image, there are several indicators confirming that the picture is taken outdoors. The presence of natural elements like grass, trees, and the open sky, along with architectural features that are not enclosed, such as the adobe-style building with exterior walls and windows, support this observation. The lighting and shadows cast by the trees are consistent with natural sunlight, further corroborating the outdoor setting. Moreover, the materials used in the path and bench are typical of outdoor furnishings, designed to withstand weather conditions. All these aspects contribute to the conclusion that the photograph was taken in an exterior space, likely a garden or park associated with the building in the background. 25 | ## Precise Answer: yes 26 | ## Consistent: Yes 27 | 28 | ## Question: This task tests your ability to count number of objects. Here is the question \"How many chairs are in the photo?\" 29 | ## Informative Answer: Upon careful observation of the photo, each individual is seated in a distinct chair. The positioning of the subjects and their engagement with the camera suggests they are occupying all the chairs present at their table. There are no other chairs discernible in the immediate vicinity of the table or in the background. Therefore, it can be concluded that there are exactly four chairs in the photo, corresponding to the four people seen around the table. 30 | ## Precise Answer: two 31 | ## Consistent: No 32 | 33 | ## Question: Can you give me a caption and some background knowledge about this image? 34 | ## Informative Answer: A detailed caption for this image could read: \"A moment of reflection and accomplishment, as a retired NBA All-Star sits with the Championship Trophy, surrounded by peers and the spirit of the game.\" The image is probably taken at an NBA event dedicated to honoring the achievements of past basketball legends. The trophy next to the man symbolizes the pinnacle of success in the league, suggesting that he may be one of the accomplished players who has reached this level of success. The attire of the individuals indicates a formal celebration of basketball history, and the relaxed atmosphere hints at a session of storytelling or interviews about their experiences in the sport.\" 35 | ## Precise Answer: Moses Malone was selected by the New Orleans Jazz with the first pick. On December 9, 1975, the NBA planned to host a supplementary draft to settle negotiating rights to five ABA players who had never been eligible for the NBA draft because their college classes had not graduated and they had not apply for hardship. The teams selected in reverse order of their winloss record in the previous season. The team that made a selection must withdraw their equivalent selection in the 1976 Draft. 
The teams were allowed to not exercise their rights on this hardship draft and thus retained their full selection in the 1976 Draft. The draft itself attracted strong opposition from the ABA who accuse the NBA trying to reduce confidence in the stability of their league. Despite the initial postponement of the draft, the draft was finally held on December 30, 1975. 36 | ## Consistent: Open 37 | 38 | ## Question: {Q} 39 | ## Informative Answer: {informative_A} 40 | ## Precise Answer: {precise_A} 41 | ## Consistent: -------------------------------------------------------------------------------- /vllm_inference/utils/food101_name_to_label_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "apple_pie": 0, 3 | "baby_back_ribs": 1, 4 | "baklava": 2, 5 | "beef_carpaccio": 3, 6 | "beef_tartare": 4, 7 | "beet_salad": 5, 8 | "beignets": 6, 9 | "bibimbap": 7, 10 | "bread_pudding": 8, 11 | "breakfast_burrito": 9, 12 | "bruschetta": 10, 13 | "caesar_salad": 11, 14 | "cannoli": 12, 15 | "caprese_salad": 13, 16 | "carrot_cake": 14, 17 | "ceviche": 15, 18 | "cheesecake": 16, 19 | "cheese_plate": 17, 20 | "chicken_curry": 18, 21 | "chicken_quesadilla": 19, 22 | "chicken_wings": 20, 23 | "chocolate_cake": 21, 24 | "chocolate_mousse": 22, 25 | "churros": 23, 26 | "clam_chowder": 24, 27 | "club_sandwich": 25, 28 | "crab_cakes": 26, 29 | "creme_brulee": 27, 30 | "croque_madame": 28, 31 | "cup_cakes": 29, 32 | "deviled_eggs": 30, 33 | "donuts": 31, 34 | "dumplings": 32, 35 | "edamame": 33, 36 | "eggs_benedict": 34, 37 | "escargots": 35, 38 | "falafel": 36, 39 | "filet_mignon": 37, 40 | "fish_and_chips": 38, 41 | "foie_gras": 39, 42 | "french_fries": 40, 43 | "french_onion_soup": 41, 44 | "french_toast": 42, 45 | "fried_calamari": 43, 46 | "fried_rice": 44, 47 | "frozen_yogurt": 45, 48 | "garlic_bread": 46, 49 | "gnocchi": 47, 50 | "greek_salad": 48, 51 | "grilled_cheese_sandwich": 49, 52 | "grilled_salmon": 50, 53 | "guacamole": 51, 54 | "gyoza": 52, 55 | "hamburger": 53, 56 | "hot_and_sour_soup": 54, 57 | "hot_dog": 55, 58 | "huevos_rancheros": 56, 59 | "hummus": 57, 60 | "ice_cream": 58, 61 | "lasagna": 59, 62 | "lobster_bisque": 60, 63 | "lobster_roll_sandwich": 61, 64 | "macaroni_and_cheese": 62, 65 | "macarons": 63, 66 | "miso_soup": 64, 67 | "mussels": 65, 68 | "nachos": 66, 69 | "omelette": 67, 70 | "onion_rings": 68, 71 | "oysters": 69, 72 | "pad_thai": 70, 73 | "paella": 71, 74 | "pancakes": 72, 75 | "panna_cotta": 73, 76 | "peking_duck": 74, 77 | "pho": 75, 78 | "pizza": 76, 79 | "pork_chop": 77, 80 | "poutine": 78, 81 | "prime_rib": 79, 82 | "pulled_pork_sandwich": 80, 83 | "ramen": 81, 84 | "ravioli": 82, 85 | "red_velvet_cake": 83, 86 | "risotto": 84, 87 | "samosa": 85, 88 | "sashimi": 86, 89 | "scallops": 87, 90 | "seaweed_salad": 88, 91 | "shrimp_and_grits": 89, 92 | "spaghetti_bolognese": 90, 93 | "spaghetti_carbonara": 91, 94 | "spring_rolls": 92, 95 | "steak": 93, 96 | "strawberry_shortcake": 94, 97 | "sushi": 95, 98 | "tacos": 96, 99 | "takoyaki": 97, 100 | "tiramisu": 98, 101 | "tuna_tartare": 99, 102 | "waffles": 100 103 | } 104 | -------------------------------------------------------------------------------- /vllm_inference/utils/foodSeg103_id2label.json: -------------------------------------------------------------------------------- 1 | {"0": "background", "1": "candy", "2": "egg tart", "3": "french fries", "4": "chocolate", "5": "biscuit", "6": "popcorn", "7": "pudding", "8": "ice cream", "9": "cheese butter", "10": "cake", "11": 
"wine", "12": "milkshake", "13": "coffee", "14": "juice", "15": "milk", "16": "tea", "17": "almond", "18": "red beans", "19": "cashew", "20": "dried cranberries", "21": "soy", "22": "walnut", "23": "peanut", "24": "egg", "25": "apple", "26": "date", "27": "apricot", "28": "avocado", "29": "banana", "30": "strawberry", "31": "cherry", "32": "blueberry", "33": "raspberry", "34": "mango", "35": "olives", "36": "peach", "37": "lemon", "38": "pear", "39": "fig", "40": "pineapple", "41": "grape", "42": "kiwi", "43": "melon", "44": "orange", "45": "watermelon", "46": "steak", "47": "pork", "48": "chicken duck", "49": "sausage", "50": "fried meat", "51": "lamb", "52": "sauce", "53": "crab", "54": "fish", "55": "shellfish", "56": "shrimp", "57": "soup", "58": "bread", "59": "corn", "60": "hamburg", "61": "pizza", "62": "hanamaki baozi", "63": "wonton dumplings", "64": "pasta", "65": "noodles", "66": "rice", "67": "pie", "68": "tofu", "69": "eggplant", "70": "potato", "71": "garlic", "72": "cauliflower", "73": "tomato", "74": "kelp", "75": "seaweed", "76": "spring onion", "77": "rape", "78": "ginger", "79": "okra", "80": "lettuce", "81": "pumpkin", "82": "cucumber", "83": "white radish", "84": "carrot", "85": "asparagus", "86": "bamboo shoots", "87": "broccoli", "88": "celery stick", "89": "cilantro mint", "90": "snow peas", "91": "cabbage", "92": "bean sprouts", "93": "onion", "94": "pepper", "95": "green beans", "96": "French beans", "97": "king oyster mushroom", "98": "shiitake", "99": "enoki mushroom", "100": "oyster mushroom", "101": "white button mushroom", "102": "salad", "103": "other ingredients"} -------------------------------------------------------------------------------- /vllm_inference/utils/llava_med/evaluate_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from 'https://github.com/microsoft/LLaVA-Med/blob/v1.0.0/llava/eval/eval_metrics/evaluate_metrics.py' 3 | """ 4 | 5 | import math 6 | from .utils import * 7 | from .glossary import * 8 | 9 | def bleu(candidate, references, n, weights): 10 | 11 | pn = [] 12 | bp = brevity_penalty(candidate, references) 13 | for i in range(n): 14 | pn.append(modified_precision(candidate, references, i + 1)) 15 | if len(weights) > len(pn): 16 | tmp_weights = [] 17 | for i in range(len(pn)): 18 | tmp_weights.append(weights[i]) 19 | bleu_result = calculate_bleu(tmp_weights, pn, n, bp) 20 | return str(bleu_result) + " (warning: the length of weights is bigger than n)" 21 | elif len(weights) < len(pn): 22 | tmp_weights = [] 23 | for i in range(len(pn)): 24 | tmp_weights.append(0) 25 | for i in range(len(weights)): 26 | tmp_weights[i] = weights[i] 27 | bleu_result = calculate_bleu(tmp_weights, pn, n, bp) 28 | return str(bleu_result) + " (warning: the length of weights is smaller than n)" 29 | else: 30 | bleu_result = calculate_bleu(weights, pn, n, bp) 31 | return str(bleu_result) 32 | 33 | #BLEU 34 | def calculate_bleu(weights, pn, n, bp): 35 | sum_wlogp = 0 36 | for i in range(n): 37 | if pn[i] != 0: 38 | sum_wlogp += float(weights[i]) * math.log(pn[i]) 39 | bleu_result = bp * math.exp(sum_wlogp) 40 | return bleu_result 41 | 42 | #Exact match 43 | def calculate_exactmatch(candidate, reference): 44 | 45 | candidate = normalize_word(candidate) 46 | reference = normalize_word(reference) 47 | 48 | candidate_words = split_sentence(candidate, 1) 49 | reference_words = split_sentence(reference, 1) 50 | count = 0 51 | total = 0 52 | for word in reference_words: 53 | if word in candidate_words: 54 | 
count += 1 55 | for word in candidate_words: 56 | total += candidate_words[word] 57 | 58 | if total == 0: 59 | return 0 # "0 (warning: length of candidate's words is 0)" 60 | else: 61 | return count / total 62 | 63 | #Exact match with normalization 64 | 65 | def similarity_candidate_prediction(candidate_answer, prediction): 66 | 67 | candidate_answer = split_sentence(candidate_answer, 1) 68 | 69 | count = 0 70 | total = 0 71 | for word in prediction: 72 | if word in candidate_answer: 73 | count += 1 74 | 75 | total = len(candidate_answer) 76 | 77 | if total == 0: 78 | return 0.0 # "0 (warning: length of candidate's words is 0)" 79 | else: 80 | return count / total 81 | 82 | def argmax(lst): 83 | return lst.index(max(lst)) 84 | 85 | def calculate_appearance_with_normalization(prediction, reference, candidate_set): 86 | 87 | prediction = normalize_word(prediction) 88 | reference = normalize_word(reference) 89 | prediction_words = split_sentence(prediction, 1) 90 | reference_words = split_sentence(reference, 1) 91 | 92 | candidate_set = candidate_set['0'] 93 | 94 | similarity_list = [] 95 | candidate_answer_normalized_list = [] 96 | for candidate_answer in candidate_set: 97 | 98 | if isinstance(candidate_answer, int): 99 | candidate_answer = str(candidate_answer) 100 | 101 | candidate_answer = normalize_word(candidate_answer) 102 | candidate_answer_normalized_list.append(candidate_answer) 103 | similarity_list.append(similarity_candidate_prediction(candidate_answer, prediction_words)) 104 | 105 | final_prediction = candidate_answer_normalized_list[argmax(similarity_list)] 106 | 107 | # import pdb; pdb.set_trace() 108 | 109 | if final_prediction == reference: 110 | return 1.0 # 111 | else: 112 | return 0.0 113 | 114 | 115 | 116 | 117 | #F1 118 | def calculate_f1score(candidate, reference): 119 | 120 | candidate = normalize_word(candidate) 121 | reference = normalize_word(reference) 122 | 123 | candidate_words = split_sentence(candidate, 1) 124 | reference_words = split_sentence(reference, 1) 125 | word_set = set() 126 | for word in candidate_words: 127 | word_set.add(word) 128 | for word in reference_words: 129 | word_set.add(word) 130 | 131 | tp = 0 132 | fp = 0 133 | fn = 0 134 | for word in word_set: 135 | if word in candidate_words and word in reference_words: 136 | tp += candidate_words[word] 137 | elif word in candidate_words and word not in reference_words: 138 | fp += candidate_words[word] 139 | elif word not in candidate_words and word in reference_words: 140 | fn += reference_words[word] 141 | 142 | if len(candidate_words) == 0: 143 | return 0, 0, 0 # "0 (warning: length of candidate's words is 0)" 144 | elif len(reference_words) == 0: 145 | return 0, 0, 0 146 | else: 147 | precision = tp / (tp + fp) 148 | recall = tp / (tp + fn) 149 | if tp == 0: 150 | return 0, 0, 0 151 | else: 152 | return 2 * precision * recall / (precision + recall), precision, recall 153 | -------------------------------------------------------------------------------- /vllm_inference/utils/llava_med/glossary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/microsoft/LLaVA-Med/blob/v1.0.0/llava/eval/eval_metrics/glossary.py 3 | """ 4 | import re 5 | 6 | contractions = { 7 | "aint": "ain't", 8 | "arent": "aren't", 9 | "cant": "can't", 10 | "couldve": "could've", 11 | "couldnt": "couldn't", 12 | "couldn'tve": "couldn't've", 13 | "couldnt've": "couldn't've", 14 | "didnt": "didn't", 15 | "doesnt": "doesn't", 16 | "dont": "don't", 17 | 
"hadnt": "hadn't", 18 | "hadnt've": "hadn't've", 19 | "hadn'tve": "hadn't've", 20 | "hasnt": "hasn't", 21 | "havent": "haven't", 22 | "hed": "he'd", 23 | "hed've": "he'd've", 24 | "he'dve": "he'd've", 25 | "hes": "he's", 26 | "howd": "how'd", 27 | "howll": "how'll", 28 | "hows": "how's", 29 | "Id've": "I'd've", 30 | "I'dve": "I'd've", 31 | "Im": "I'm", 32 | "Ive": "I've", 33 | "isnt": "isn't", 34 | "itd": "it'd", 35 | "itd've": "it'd've", 36 | "it'dve": "it'd've", 37 | "itll": "it'll", 38 | "let's": "let's", 39 | "maam": "ma'am", 40 | "mightnt": "mightn't", 41 | "mightnt've": "mightn't've", 42 | "mightn'tve": "mightn't've", 43 | "mightve": "might've", 44 | "mustnt": "mustn't", 45 | "mustve": "must've", 46 | "neednt": "needn't", 47 | "notve": "not've", 48 | "oclock": "o'clock", 49 | "oughtnt": "oughtn't", 50 | "ow's'at": "'ow's'at", 51 | "'ows'at": "'ow's'at", 52 | "'ow'sat": "'ow's'at", 53 | "shant": "shan't", 54 | "shed've": "she'd've", 55 | "she'dve": "she'd've", 56 | "she's": "she's", 57 | "shouldve": "should've", 58 | "shouldnt": "shouldn't", 59 | "shouldnt've": "shouldn't've", 60 | "shouldn'tve": "shouldn't've", 61 | "somebody'd": "somebodyd", 62 | "somebodyd've": "somebody'd've", 63 | "somebody'dve": "somebody'd've", 64 | "somebodyll": "somebody'll", 65 | "somebodys": "somebody's", 66 | "someoned": "someone'd", 67 | "someoned've": "someone'd've", 68 | "someone'dve": "someone'd've", 69 | "someonell": "someone'll", 70 | "someones": "someone's", 71 | "somethingd": "something'd", 72 | "somethingd've": "something'd've", 73 | "something'dve": "something'd've", 74 | "somethingll": "something'll", 75 | "thats": "that's", 76 | "thered": "there'd", 77 | "thered've": "there'd've", 78 | "there'dve": "there'd've", 79 | "therere": "there're", 80 | "theres": "there's", 81 | "theyd": "they'd", 82 | "theyd've": "they'd've", 83 | "they'dve": "they'd've", 84 | "theyll": "they'll", 85 | "theyre": "they're", 86 | "theyve": "they've", 87 | "twas": "'twas", 88 | "wasnt": "wasn't", 89 | "wed've": "we'd've", 90 | "we'dve": "we'd've", 91 | "weve": "we've", 92 | "werent": "weren't", 93 | "whatll": "what'll", 94 | "whatre": "what're", 95 | "whats": "what's", 96 | "whatve": "what've", 97 | "whens": "when's", 98 | "whered": "where'd", 99 | "wheres": "where's", 100 | "whereve": "where've", 101 | "whod": "who'd", 102 | "whod've": "who'd've", 103 | "who'dve": "who'd've", 104 | "wholl": "who'll", 105 | "whos": "who's", 106 | "whove": "who've", 107 | "whyll": "why'll", 108 | "whyre": "why're", 109 | "whys": "why's", 110 | "wont": "won't", 111 | "wouldve": "would've", 112 | "wouldnt": "wouldn't", 113 | "wouldnt've": "wouldn't've", 114 | "wouldn'tve": "wouldn't've", 115 | "yall": "y'all", 116 | "yall'll": "y'all'll", 117 | "y'allll": "y'all'll", 118 | "yall'd've": "y'all'd've", 119 | "y'alld've": "y'all'd've", 120 | "y'all'dve": "y'all'd've", 121 | "youd": "you'd", 122 | "youd've": "you'd've", 123 | "you'dve": "you'd've", 124 | "youll": "you'll", 125 | "youre": "you're", 126 | "youve": "you've", 127 | } 128 | 129 | manual_map = { 130 | "none": "0", 131 | "zero": "0", 132 | "one": "1", 133 | "two": "2", 134 | "three": "3", 135 | "four": "4", 136 | "five": "5", 137 | "six": "6", 138 | "seven": "7", 139 | "eight": "8", 140 | "nine": "9", 141 | "ten": "10", 142 | } 143 | articles = ["a", "an", "the"] 144 | period_strip = re.compile("(?!<=\d)(\.)(?!\d)") 145 | comma_strip = re.compile("(\d)(\,)(\d)") 146 | punct = [ 147 | ";", 148 | r"/", 149 | "[", 150 | "]", 151 | '"', 152 | "{", 153 | "}", 154 | "(", 155 | ")", 156 | "=", 
157 | "+", 158 | "\\", 159 | "_", 160 | "-", 161 | ">", 162 | "<", 163 | "@", 164 | "`", 165 | ",", 166 | "?", 167 | "!", 168 | ] 169 | 170 | 171 | def normalize_word(token): 172 | _token = token 173 | for p in punct: 174 | if (p + " " in token or " " + p in token) or ( 175 | re.search(comma_strip, token) != None 176 | ): 177 | _token = _token.replace(p, "") 178 | else: 179 | _token = _token.replace(p, " ") 180 | token = period_strip.sub("", _token, re.UNICODE) 181 | 182 | _token = [] 183 | temp = token.lower().split() 184 | for word in temp: 185 | word = manual_map.setdefault(word, word) 186 | if word not in articles: 187 | _token.append(word) 188 | for i, word in enumerate(_token): 189 | if word in contractions: 190 | _token[i] = contractions[word] 191 | token = " ".join(_token) 192 | token = token.replace(",", "") 193 | return token -------------------------------------------------------------------------------- /vllm_inference/utils/llava_med/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copied from https://github.com/microsoft/LLaVA-Med/blob/v1.0.0/llava/eval/eval_metrics/utils.py 3 | """ 4 | from collections import defaultdict 5 | import re 6 | import math 7 | 8 | def brevity_penalty(candidate, references): 9 | c = len(candidate) 10 | ref_lens = (len(reference) for reference in references) 11 | r = min(ref_lens, key=lambda ref_len: (abs(ref_len - c), ref_len)) 12 | 13 | if c > r: 14 | return 1 15 | else: 16 | return math.exp(1 - r / c) 17 | 18 | def modified_precision(candidate, references, n): 19 | max_frequency = defaultdict(int) 20 | min_frequency = defaultdict(int) 21 | 22 | candidate_words = split_sentence(candidate, n) 23 | 24 | for reference in references: 25 | reference_words = split_sentence(reference, n) 26 | for word in candidate_words: 27 | max_frequency[word] = max(max_frequency[word], reference_words[word]) 28 | for word in candidate_words: 29 | min_frequency[word] = min(max_frequency[word], candidate_words[word]) 30 | P = sum(min_frequency.values()) / sum(candidate_words.values()) 31 | return P 32 | 33 | def split_sentence(sentence, n): 34 | words = defaultdict(int) 35 | # tmp_sentence = re.sub("[^a-zA-Z ]", "", sentence) 36 | tmp_sentence = sentence 37 | tmp_sentence = tmp_sentence.lower() 38 | tmp_sentence = tmp_sentence.strip().split() 39 | length = len(tmp_sentence) 40 | for i in range(length - n + 1): 41 | tmp_words = " ".join(tmp_sentence[i: i + n]) 42 | if tmp_words: 43 | words[tmp_words] += 1 44 | return words -------------------------------------------------------------------------------- /vllm_inference/utils/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rouge import Rouge 3 | import random 4 | 5 | # copied from https://github.com/microsoft/LMOps/blob/main/uprise/src/utils/metric.py 6 | def rouge(labels, preds, return_list=False): 7 | r1s, r2s, rls = [], [], [] 8 | r = Rouge() 9 | for i in range(len(labels)): 10 | try: 11 | scores = r.get_scores(preds[i], labels[i])[0] 12 | r1s.append(scores["rouge-1"]["f"]) 13 | r2s.append(scores["rouge-2"]["f"]) 14 | rls.append(scores["rouge-l"]["f"]) 15 | except Exception as e: 16 | r1s.append(0) 17 | r2s.append(0) 18 | rls.append(0) 19 | if return_list: 20 | return rls 21 | r1 = sum(r1s) / len(r1s) 22 | r2 = sum(r2s) / len(r2s) 23 | rl = sum(rls) / len(rls) 24 | return r1, r2, rl 25 | 26 | # modified based on https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/eval_utils.py 27 | # 
----------- Process Multi-choice ------------- 28 | def parse_multi_choice_response(response, all_choices, index2ans, random_seed): 29 | """ 30 | Parse the prediction from the generated response. 31 | Return the predicted index e.g., A, B, C, D. 32 | """ 33 | random_flag = False # whether is random selected answer 34 | for char in [',', '.', '!', '?', ';', ':', "'"]: 35 | response = response.strip(char) 36 | response = " " + response + " " # add space to avoid partial match 37 | 38 | index_ans = True 39 | ans_with_brack = False 40 | candidates = [] 41 | for choice in all_choices: # e.g., (A) (B) (C) (D) 42 | if f'({choice})' in response: 43 | candidates.append(choice) 44 | ans_with_brack = True 45 | 46 | if len(candidates) == 0: 47 | for choice in all_choices: # e.g., A B C D 48 | if f' {choice} ' in response: 49 | candidates.append(choice) 50 | # if all above doesn't get candidates, check if the content is larger than 5 tokens and try to parse the example 51 | if len(candidates) == 0 and len(response) > 5: 52 | for index, ans in index2ans.items(): 53 | if ans.lower() in response.lower(): 54 | candidates.append(index) 55 | index_ans = False # it's content ans. 56 | 57 | if len(candidates) == 0: # still not get answer, randomly choose one. 58 | pred_index = random.Random(random_seed).choice(all_choices) 59 | random_flag = True 60 | elif len(candidates) > 1: 61 | start_indexes = [] 62 | if index_ans: 63 | if ans_with_brack: 64 | for can in candidates: 65 | index = response.rfind(f'({can})') 66 | start_indexes.append(index) # -1 will be ignored anyway 67 | else: 68 | for can in candidates: 69 | index = response.rfind(f" {can} ") 70 | start_indexes.append(index) 71 | else: 72 | for can in candidates: 73 | index = response.lower().rfind(index2ans[can].lower()) 74 | start_indexes.append(index) 75 | # get the last one 76 | pred_index = candidates[np.argmax(start_indexes)] 77 | else: # if only one candidate, use it. 78 | pred_index = candidates[0] 79 | 80 | return pred_index, random_flag 81 | 82 | def parse_multi_label_response(response: str, index2ans: dict): 83 | """ 84 | Parse the prediction from the generated response. 85 | """ 86 | # to avoid partial match 87 | # for example, a sentence having egg tart would match both egg and egg tart, however, only egg tart should be matched. 
88 | # so we sort the index2ans by length and match the longer first 89 | # if one phrase is matched, we replace the matched phrase with '' in the response 90 | index2ans = dict(sorted(index2ans.items(), key=lambda item: len(item[1]), reverse=True)) 91 | candidates = [] 92 | response = response.lower() 93 | for index, ans in index2ans.items(): 94 | if ans.lower() in response: 95 | candidates.append(index) 96 | response = response.replace(ans.lower(), '') 97 | return candidates 98 | 99 | 100 | from sklearn.metrics import recall_score, f1_score 101 | def compute_multi_label_scores(label_indices: list, pred_indices: list, category_index_start: int, category_index_end: int, metric_name: str): 102 | 103 | # Convert to binary vectors 104 | y_true = [1 if i in label_indices else 0 for i in range(category_index_start, category_index_end + 1)] 105 | y_pred = [1 if i in pred_indices else 0 for i in range(category_index_start, category_index_end + 1)] 106 | 107 | # Calculate all metrics 108 | if metric_name == 'recall': 109 | recall = recall_score(y_true, y_pred, zero_division=1) 110 | return recall * 100 111 | elif metric_name == 'f1': 112 | f1 = f1_score(y_true, y_pred, zero_division=1) 113 | return f1 * 100 -------------------------------------------------------------------------------- /vllm_inference/utils/nutrition5k_ingredients.py: -------------------------------------------------------------------------------- 1 | all_ingredients = ['cottage cheese', 'strawberries', 'garden salad', 'bacon', 'potatoes', 'caesar salad', 'cauliflower', 'scrambled eggs', 'wild rice', 'steak', 'cheese pizza', 'olives', 'berries', 'asparagus', 'hash browns', 'brussels sprouts', 'pasta salad', 'turkey', 'bread', 'duck', 'squash', 'guacamole', 'brown rice', 'artic char', 'beef', 'white rice', 'broccoli', 'chicken', 'mixed greens', 'lettuce', 'cucumbers', 'tomatoes', 'bell peppers', 'celery', 'blue cheese', 'spinach (raw)', 'cantaloupe', 'pineapple', 'sausage', 'raspberries', 'blackberries', 'avocado', 'green beans', 'bean sprouts', 'carrot', 'mushroom', 'corn', 'ham', 'fish', 'tofu', 'shrimp', 'cheese', 'nuts', 'apple', 'banana', 'kiwi', 'lemon', 'orange', 'oatmeal', 'tortilla', 'potato chips', 'noodles', 'bean(seed)', 'alcohol', 'grape juice', 'ground beef', 'syrup', 'grapefruits', 'fruit punch', 'figs', 'macaroni and cheese', 'kale', 'radishes', 'pasta', 'refried beans', 'brisket', 'almonds', 'protein powder', 'okra', 'turkey bacon', 'pickles', 'pecans', 'muenster cheese', 'brie cheese', 'apple juice', 'raisin bran', 'coleslaw', 'margarine', 'spaghetti', 'ground pork', 'herring', 'wine', 'alfalfa', 'onions', 'artichokes', 'mussels', 'falafel', 'multigrain bread', 'marshmallows', 'oats', 'leeks', 'snow peas', 'pretzels', 'strudels', 'grilled chicken', 'half and half', 'mixed nuts', 'spring rolls', 'french toast', 'barbecue sauce', 'clams', 'honeydew melons', 'ketchup', 'polenta', 'lemonade', 'mayonnaise', 'tacos', 'eggnog', 'tabouli', 'gelatin', 'watermelon', 'garlic', 'crawfish', 'walnuts', 'octopus', 'mustard', 'pastries', 'egg whites', 'raisins', 'rye bread', 'dark chocolate', 'croissants', 'shallots', 'biscuits', 'tilapia', 'poached eggs', 'succotash', 'seafood', 'egg rolls', 'caesar dressing', 'tuna salad', 'india pale ale beer', 'veal', 'sorbet', 'bison', 'scallops', 'turkey breast', 'parmesan cheese', 'sushi', 'swordfish', 'agave nectar', 'cabbage', 'bulgur', 'brown sugar', 'chicken thighs', 'paella', 'colby cheese', 'gumbo', 'apple cider', 'chow mein', 'olive oil', 'parsnips', 'kidney beans', 'chowders', 
'skim milk', 'grapefruit juice', 'tamales', 'lasagna', 'whole wheat bread', 'gravy', 'almond butter', 'cherry pie', 'chestnuts', 'bread crumbs', 'apple pie', 'french dressing', 'tempeh', 'pies', 'eggs', 'rice noodles', 'milk shakes', 'soy yogurt', 'chips', 'ravioli', 'collards', 'green peas', 'chickpeas', 'carrot cake', 'vinegar', 'hamburgers', 'cheesecake', 'yam', 'hot dogs', 'fried rice', 'wafers', 'cod', 'salads', 'breadsticks', 'swiss cheese', 'buttermilk', 'blueberries', 'enchiladas', 'chili', 'pork chops', 'corn on the cob', 'muffins', 'fried chicken', 'ice cream cones', 'pastrami', 'chocolate milk', 'plums', 'burgers', 'pears', 'seeds', 'squid', 'tostadas', 'lamb', 'brazil nuts', 'macaroni', 'pita bread', 'vegetable oil', 'pine nuts', 'macadamia nuts', 'potato salad', 'veggie burgers', 'pasta sauce', 'calamari', 'cornmeal', 'bologna', 'frostings', 'deprecated', 'oysters', 'vodka', 'smoothies', 'pecan pie', 'roast beef', 'hot chocolate', 'hominy', 'custard', 'smoked salmon', 'scones', 'tortellini', 'mozzarella cheese', 'chicken breast', 'wheat bread', 'baked beans', 'blueberry pie', 'hard boiled eggs', 'cashews', 'rolls', 'tortilla chips', 'broth', 'tahini', 'mashed potatoes', 'nougat', 'fish oil', 'brownies', 'popcorn', 'corned beef', 'iced tea', 'pumpkins', 'yogurt', 'mousse', 'bagels', 'sardines', 'pot pies', 'coconuts', 'soft serve ice creams', 'onion rings', 'cocoa', 'peaches', 'cupcakes', 'sandwiches', 'trail mix', 'frozen yogurt', 'chocolate cake', 'chimichangas', 'beer', 'omelets', 'souffle', 'edamame', 'energy drinks', 'salad dressing', 'ribs', 'trout', 'romano cheese', 'toast', 'empanadas', 'ranch dressing', 'cheddar cheese', 'salt', 'spreads', 'sandwich cookies', 'granola', 'provolone cheese', 'sundaes', 'pepperoni', 'focaccia', 'soy milk', 'egg yolks', 'english muffins', 'pie crust', 'chocolate chip cookies', 'puddings', 'fudge', 'croutons', 'stuffing', 'lamb chops', 'rum', 'corn chips', 'granola bars', 'pepper', 'peanut butter', 'sourdough bread', 'corn dogs', 'roast chicken', 'flounder', 'flour', 'feta cheese', 'black beans', 'water', 'garlic bread', 'salami', 'chocolate', 'nectarines', 'chicken soup', 'root beer', 'sugar', 'pate nutrition', 'haddock', 'pancakes', 'salmon', 'potato bread', 'white bread', 'eggplant', 'peanuts', 'waffles', 'vinaigrette', 'grits', 'lo mein', 'cream', 'turnips', 'cranberry juice', 'peach', 'roast pork', 'turnover', 'baby carrots', 'egg salad', 'mangos', 'miso soup', 'mackerel', 'calzones', 'burritos', 'soy nuts', 'crackers', 'french fries', 'lentils', 'mixed vegetables', 'honey', 'risotto', 'roast turkey', 'ice creams', 'meatloaf', 'sunflower seeds', 'quiche', 'sauerkraut', 'lime', 'donuts', 'fritters', 'mahi mahi', 'dumplings', 'chilaquiles', 'snapper', 'peas', 'fruit salad', 'naan', 'quinoa', 'catfish', 'lima beans', 'pizza', 'gouda cheese', 'cheeseburgers', 'italian dressing', 'coffee', 'curries', 'eel', 'candies', 'fried eggs', 'chicken drumsticks', 'margarita', 'cobbler', 'pepperoni pizza', 'teriyaki sauce', 'gorgonzola cheese', 'string cheese', 'wheat beer', 'gyros', 'champagne', 'raisin bread', 'cornbread', 'wraps', 'tuna', 'fruit cocktail', 'potato skins', 'seaweed', 'sherbet', 'horseradish', 'jambalaya', 'greek salad', 'goulash', 'stews', 'havarti cheese', 'pilaf', 'chutney', 'cereal', 'ground chicken', 'fajitas', 'chicken wings', 'hard cider', 'ice pop', 'ground turkey', 'crab', 'plate only', 'prunes', 'goat cheese', 'tomato soup', 'muesli', 'white wine', 'gnocchi', 'apricots', 'licorice', 'red potatoes', 'roasted potatoes', 'sour 
cream', 'jerky', 'coffee creamer', 'sweet potato', 'chicken nuggets', 'anchovies', 'grapes', 'butter', 'hummus', 'crispbread', 'dates', 'papayas', 'stout beer', 'nutrition bars', 'jalapenos', 'cappuccino', 'buns', 'taco shells', 'maple syrup', 'deprecated', 'rice cakes', 'cookies', 'barley', 'sun dried tomatoes', 'deprecated', 'cereal bars', 'nectar', 'samosas', 'pork', 'beets', 'baked potatoes', 'american cheese', 'pistachios', 'relish', 'camembert cheese', 'carp', 'deprecated', 'cherry tomatoes', 'couscous', 'cranberries', 'quesadillas', 'salsa', 'almond milk', 'light beer', 'sweet rolls', 'challah', 'ice cream soda', 'flatbread', 'rice', 'water chestnuts', 'chicken salad', 'zucchini', 'cakes', 'crepes', 'mexican cheese', 'lobster', 'tea', 'nachos', 'taco salad', 'lager', 'jams', 'cream cheese', 'capers', 'halibut', 'orange juice', 'latte', 'taquitos', 'cherries', 'milk', 'applesauce', 'vegetable juice', 'meatballs', 'pinto beans', 'mandarin oranges', 'soy sauce', 'tempura', 'chicken apple sausage', 'spinach (cooked)', 'kimchi', 'millet', 'wheat berry', 'arugula', 'rosemary', 'chard', 'thyme', 'oregano', 'lemon juice', 'basil', 'tatsoi', 'cilantro', 'parsley', 'bok choy', 'mustard greens', 'chive', 'celery root', 'chayote squash', 'endive', 'pumpkin seeds', 'pesto', 'orzo', 'country rice', 'green onions', 'banana with peel', 'orange with peel', 'jicama', 'nopales', 'ginger', 'tomatillo', 'white beans', 'chia seeds', 'corn starch', 'greek yogurt', 'balsamic vinegar', 'corn nuts', 'coconut milk', 'pizza dough', 'fennel', 'pizza sauce', 'pomegranate', 'nut cheese', 'soy sausage', 'oil'] --------------------------------------------------------------------------------
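
A minimal usage sketch of the caching utilities above, mirroring how inference.py writes per-rank tmp_*.bin shards through BufferedJsonWriter and how merge_predictions.py reads them back with BufferedJsonReader, flattens the batches, sorts them by syn_id, and writes a single .jsonl file. The /tmp paths and toy records are hypothetical, and the snippet assumes it is run from the vllm_inference/ directory so that utils.cache_util is importable:

import glob
import os

import jsonlines

from utils.cache_util import BufferedJsonReader, BufferedJsonWriter

shard_path = "/tmp/demo_run/tmp_rank0.bin"   # hypothetical per-rank cache shard
merged_path = "/tmp/demo_run/merged.jsonl"   # hypothetical merged output file
os.makedirs("/tmp/demo_run", exist_ok=True)
if os.path.exists(shard_path):
    os.remove(shard_path)                    # start from a clean cache, like the --remove_cache flag

# 1) Caching: with buffer_size=1, every write() call immediately appends one
#    JSON line of the form [[{...}, ...]] to the shard file.
with BufferedJsonWriter(shard_path, buffer_size=1) as writer:
    writer.write([{"syn_id": 1, "pred": "answer B"}])
    writer.write([{"syn_id": 0, "pred": "answer A"}])

# 2) Merging: read every shard, flatten the per-batch lists, sort by syn_id,
#    and dump one record per line.
data = []
for path in sorted(glob.glob("/tmp/demo_run/tmp_*.bin")):
    with BufferedJsonReader(path) as reader:
        for batch in reader.read():          # each element is one cached batch (a list of dicts)
            data.extend(batch)
data.sort(key=lambda entry: entry["syn_id"])

with jsonlines.open(merged_path, mode="w") as out:
    for record in data:
        out.write(record)
print(f"merged {len(data)} predictions into {merged_path}")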
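
Likewise, a short sketch of the answer-parsing helpers in utils/metric.py, again assuming it is run from the vllm_inference/ directory (so utils.metric and its rouge/numpy/sklearn dependencies resolve). The choices and label ids are toy values (the ids happen to match entries in foodSeg103_id2label.json), and the 1-103 category index range is used purely for illustration:

from utils.metric import (
    compute_multi_label_scores,
    parse_multi_choice_response,
    parse_multi_label_response,
)

# Multiple-choice parsing: recover the option letter from free-form model output.
all_choices = ["A", "B", "C", "D"]
index2ans = {"A": "apple pie", "B": "baklava", "C": "ramen", "D": "sushi"}
pred, was_random = parse_multi_choice_response(
    "The dish shown is most likely (C) ramen.", all_choices, index2ans, random_seed=0
)
print(pred, was_random)  # -> C False (the random fallback fires only when nothing matches)

# Multi-label parsing: longer names are matched first and removed from the
# response, so "egg tart" does not also trigger the shorter label "egg".
seg_index2ans = {"2": "egg tart", "24": "egg", "58": "bread"}
pred_ids = [int(i) for i in parse_multi_label_response("a plate with an egg tart and some bread", seg_index2ans)]
label_ids = [2, 58]

# F1 over binary indicator vectors spanning the chosen category index range.
f1 = compute_multi_label_scores(label_ids, pred_ids, category_index_start=1, category_index_end=103, metric_name="f1")
print(f1)  # -> 100.0 for a perfect prediction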