├── LICENSE ├── README.md ├── allava ├── constants.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_phi.py │ │ ├── llava_stablelm_1_6b.py │ │ ├── llava_vicuna.py │ │ └── phi │ │ │ ├── configuration_phi.py │ │ │ └── modeling_phi.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py └── serve │ ├── cli.py │ ├── run_inference_allava-phi2.py │ ├── run_inference_allava-phi3.py │ └── run_inference_allava-stablelm2.py ├── assets ├── llavas.png ├── pipeline.jpg ├── pipeline.pdf └── training_datasets_by_stage.jpg ├── download ├── download_laion.sh ├── download_text.sh ├── download_vflan.sh └── legacy │ └── laion │ ├── download_images_from_url.py │ └── download_laion_from_url.sh ├── prompts ├── instructions_for_captions.txt ├── prompt_for_laion.txt └── prompt_for_vflan.txt ├── requirements.txt └── scripts └── zip_images.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ALLaVA: Harnessing GPT4V-synthesized Data for A Lite Vision-Language Model 2 | 3 | 4 | 7 | 8 |

9 | ⚡ALLaVA is a project that provides a large-scale GPT4V-synthesized dataset for training LVLMs.⚡ 10 |

11 | 12 | 16 | 17 |

18 | Python Version 19 | PyTorch Version 20 | Transformers Version 21 |

22 | 23 |

24 | 📃 Paper • 🌐 Demo 25 |

26 |

27 | 🤗 ALLaVA-4V Dataset 28 |

29 | 30 |

31 | 🤗 ALLaVA-Phi3-mini-128k 32 | • 🤗 ALLaVA-StableLM2-1_6B 33 | • 🤗 ALLaVA-Phi2-2_7B 34 |

35 | 36 | 40 | 41 | ## ✨ Updates 42 | - [06/25/2024]: We release [ALLaVA-Phi3-mini-128k](https://huggingface.co/FreedomIntelligence/ALLaVA-Phi3-mini-128k), [ALLaVA-StableLM2-1_6B](https://huggingface.co/FreedomIntelligence/ALLaVA-StableLM2-1_6B), and [ALLaVA-Phi2-2_7B](https://huggingface.co/FreedomIntelligence/ALLaVA-Phi2-2_7B), all of which support loading from their 🤗 repos. 43 | - [03/01/2024]: The Hugging Face repos of **ALLaVA-3B-Longer (recommended)** and ALLaVA-3B are updated and now support loading models with the `from_pretrained` method. 44 | - [02/29/2024]: The Hugging Face repo of the ALLaVA-4V dataset and the [download scripts](#data-preparation) are updated. 45 | - [02/21/2024]: We are thrilled to release 1) **1.4M** samples for training LVLMs, 2) two versions of our ALLaVA-3B models, 3) inference code, and 4) a tech report. 46 | 47 | 48 | 49 | ## 📚 ALLaVA-4V Data 50 | 51 | ### Generation Pipeline 52 | 53 | 54 |
55 | pipeline 56 |
57 | 58 | 59 | * LAION 60 | 61 | We leverage the superb GPT-4V to generate captions and complex reasoning QA pairs. The prompt is [here](prompts/prompt_for_laion.txt). 62 | 63 | * Vision-FLAN 64 | 65 | We leverage the superb GPT-4V to generate captions and detailed answers to the original instructions. The prompt is [here]( 66 | prompts/prompt_for_vflan.txt). 67 | 68 | * Wizard 69 | 70 | We regenerate the answers to Wizard_evol_instruct with GPT-4-Turbo. 71 | 72 | ### Dataset Cards 73 | All datasets can be found [here](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V). 74 | The naming structure is shown below: 75 | ```bash 76 | ALLaVA-4V 77 | ├── ALLaVA-Caption-4V 78 | │ ├── ALLaVA-Caption-LAION-4V 79 | │ └── ALLaVA-Caption-VFLAN-4V 80 | ├── ALLaVA-Instruct-4V 81 | │ ├── ALLaVA-Instruct-LAION-4V 82 | │ └── ALLaVA-Instruct-VFLAN-4V 83 | └── Evol-Instruct-GPT4-Turbo-143K 84 | ``` 85 | 86 | The folder structure of the Hugging Face dataset repository is: 87 | ```bash 88 | ALLaVA-4V 89 | ├── allava_laion/ 90 | │ ├── ALLaVA-Caption-LAION-4V.json 91 | │ ├── ALLaVA-Instruct-LAION-4V.json 92 | │ └── images.zip 93 | ├── allava_vflan/ 94 | │ ├── ALLaVA-Caption-VFLAN-4V.json 95 | │ └── ALLaVA-Instruct-VFLAN-4V.json 96 | ├── allava_text/ 97 | │ └── Evol-Instruct-GPT4-Turbo-143K.json 98 | ``` 99 | **We do NOT own the rights to any image contained in the "images.zip" file. We collated the images and uploaded this file at the request of the community to facilitate the data preparation process.** 100 | 101 | Here we provide detailed information on each subset.
102 | 103 | | Name | #Samples | Image Source | Instruction Source | Answer Source | 104 | | --- | ---: | ---: | ---: | ---: | 105 | |ALLaVA-Caption-LAION-4V\* | 505,588 | LAION (web) | [Handcrafted](prompts/instructions_for_captions.txt) | GPT-4V 106 | |ALLaVA-Caption-VFLAN-4V\*\*| 202,552 | [Vision FLAN](https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k/tree/main) | [Handcrafted](prompts/instructions_for_captions.txt) | GPT-4V 107 | |ALLaVA-Instruct-LAION-4V\* | 505,588 | LAION (web) | GPT-4V | GPT-4V 108 | |ALLaVA-Instruct-VFLAN-4V\*\*| 203,065 | [Vision FLAN](https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k/tree/main) | [Vision FLAN](https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k/tree/main) \*\*\* | GPT-4V 109 | |Evol-Instruct-GPT4-Turbo-143K |143,000 | / | [Wizard](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k) | GPT-4-Turbo 110 | 111 | \**ALLaVA-Caption-LAION-4V* and *ALLaVA-Instruct-LAION-4V* share the same set of images. 112 | 113 | \*\**ALLaVA-Caption-VFLAN-4V* and *ALLaVA-Instruct-VFLAN-4V* share the same set of images. 114 | 115 | \*\*\*In *ALLaVA-Caption-VFLAN-4V* and *ALLaVA-Instruct-VFLAN-4V*, we additionally include 30K samples distilled in the same manner as the LAION data (i.e., distilling captions and complex reasoning QA pairs). We mix them with the remaining data, whose instructions come from the original VFLAN.
116 | 117 | 118 | ### Data Preparation 119 | 120 | #### LAION 121 | 122 | Specify the `laion_root` in `./download/download_laion.sh` and run 123 | ```shell 124 | sh ./download/download_laion.sh 125 | ``` 126 | 137 | 138 | 139 | #### VFLAN 140 | Specify the `vflan_root` in `./download/download_vflan.sh` and run 141 | ```shell 142 | sh ./download/download_vflan.sh 143 | ``` 144 | 145 | 146 | #### Evol-Instruct-GPT4-Turbo-143K 147 | Specify the `text_root` in `./download/download_text.sh` and run 148 | ```shell 149 | sh ./download/download_text.sh 150 | ``` 151 | 152 | 153 | ### Sample Demonstration 154 |
ALLaVA-Caption-LAION-4V 155 | 156 | ``` 157 | { 158 | "id": "allava_laion_cap_100277305", 159 | "image": "allava_laion/images/100277305.jpeg", 160 | "conversations": [ 161 | { 162 | "from": "human", 163 | "value": "\nPlease depict the image in words." 164 | }, 165 | { 166 | "from": "gpt", 167 | "value": "The image displays a silver ring with a design that resembles a stylized leaf. The leaf design is split into two symmetrical halves, separated by a gap that runs along the ring's band. Each half of the leaf pattern is textured with detailed, vein-like striations that give it a realistic look. The edges of the leaf design are embellished with small, clear gemstones that could possibly be diamonds, adding a sparkle to the ring. The band of the ring is wide, supporting the large leaf motif, and it tapers slightly as it moves away from the leaf design. The silver metal has a high polish, reflecting light brightly and providing a contrast to the textured leaf surface. The craftsmanship suggests a mix of organic inspiration with a touch of elegance due to the gems 168 | tone accents." 169 | } 170 | ], 171 | "original_caption": "190952cz Pandora Tropical Palm Leaf Ring", 172 | "url": "https://theartofpandora.com/wp-content/uploads/2016/10/190952CZ-1-600x600.jpg" 173 | } 174 | ``` 175 |
176 | 177 | 178 |
ALLaVA-Caption-VFLAN-4V 179 | 180 | ``` 181 | { 182 | "id": "allava_vflan_cap_100000", 183 | "image": "allava_vflan/images/images_191task_1k/PACS+guitar_image_category_classification_830_n03272010_10795-5.png", 184 | "conversations": [ 185 | { 186 | "from": "human", 187 | "value": "\nHow would you describe the image to someone who cannot see it?" 188 | }, 189 | { 190 | "from": "gpt", 191 | "value": "The image shows a simplistic line drawing of an electric guitar. The guitar is oriented diagonally from the top left corner to the bottom right corner of the image. It has a long, narrow neck with tuning pegs at the top, resembling a stylized \"T\" shape. The body of the guitar is irregular, with two cutaways on either side of the neck, giving it an asymmetric appearance. The body is also characterized by sharp angles rather than curves, which is not typical for most guitars. There are four horizontal lines across the body, likely representing strings. The drawing is black and white, with no shading or color, and appears to be hand-drawn with a single continuous line without much detail. The background is plain and white." 192 | } 193 | ] 194 | } 195 | ``` 196 |
197 | 198 |
ALLaVA-Instruct-LAION-4V 199 | 200 | ``` 201 | { 202 | "id": "allava_laion_inst_100277305", 203 | "image": "allava_laion/images/100277305.jpeg", 204 | "conversations": [ 205 | { 206 | "from": "human", 207 | "value": "\nHow does the inclusion of gemstones affect the overall aesthetic and possible value of the ring?" 208 | }, 209 | { 210 | "from": "gpt", 211 | "value": "The inclusion of gemstones along the edges of the leaf design enhances the ring's aesthetic by adding a touch of brilliance and luxury. The small, clear gemstones catch the light, creating a sparkling effect that contrasts with the silver's polished surface and the leaf's textured details. This sparkling accent likely increases the ring's visual appeal, making it stand out more as an accessory. In terms of value, the presence of gemstones, which could be diamonds, would likely raise the ring's monetary worth. Gemstones, especially diamonds, are prized for their beauty and durability, and their addition to jewelry pieces often signifies a higher quality and class of the item, thereby increasing its market value." 212 | } 213 | ], 214 | "original_caption": "190952cz Pandora Tropical Palm Leaf Ring", 215 | "url": "https://theartofpandora.com/wp-content/uploads/2016/10/190952CZ-1-600x600.jpg" 216 | } 217 | ``` 218 |
219 | 220 | 221 |
ALLaVA-Instruct-VFLAN-4V 222 | 223 | ``` 224 | { 225 | "id": "allava_vflan_inst_9", 226 | "image": "allava_vflan/images/images_191task_1k/VQA-E+visual_question_answering_574_COCO_train2014_000000119456.jpg", 227 | "conversations": [ 228 | { 229 | "from": "human", 230 | "value": "\nYou are provided with an image and a question related to the image. Answer the question based on the information given in the image.\n\nIs bigfoot in the picture?" 231 | }, 232 | { 233 | "from": "gpt", 234 | "value": "The creature in the image is not Bigfoot, also known as Sasquatch. Bigfoot is typically described in folklore and cryptozoology as a large, bipedal humanoid creature with ape-like features, standing upwards of 6 to 9 feet tall. The animal present in the photograph is quadrupedal, possesses a body structure similar to that of a pig, and has features such as long hair and tusks, which are characteristic of a wild boar. Additionally, there is no evidence of bipedal creatures or any human-like figures that would suggest the presence of Bigfoot. The setting and visual clues do not align with the commonly reported sightings of Bigfoot, which often describe it as a solitary figure standing upright among trees. Therefore, based on the visual information available in the image, the conclusion is that Bigfoot is not depicted here." 235 | } 236 | ] 237 | } 238 | ``` 239 |
240 | 241 |
Evol-Instruct-GPT4-Turbo-143K 242 | 243 | ``` 244 | { 245 | "idx": "heR0vZB", 246 | "conversations": [ 247 | { 248 | "from": "human", 249 | "value": "As an online platform teacher named Aimee, you possess impeccable credentials which include a Bachelor of Science degree in Industrial and Labor Relations from Cornell University, expertise in the English language, and intermediate proficiency in both Chinese and Spanish. Additionally, your professional experience as a STEAM teacher at UN Women in Singapore has honed your skills in teaching children from the ages of 6-11 and working with students from all levels of education. Your exceptional teaching abilities in spoken English and pronunciation paired with your personal strengths of being informed, patient, and engaging make you an ideal teacher for students seeking to improve their English language skills. Can you provide a short, concise, and unique English self-introduction in bullet point form that would attract students to enroll in your course?" 250 | }, 251 | { 252 | "from": "gpt", 253 | "value": "Sure, here are some bullet points for your English self-introduction:\n\n- Hi there! My name is Aimee and I'm thrilled to be your online English teacher!\n- With a Bachelor's degree from Cornell University and experience teaching STEAM at UN Women in Singapore, I know how to help students learn effectively.\n- I'm passionate about language and have intermediate proficiency in both Chinese and Spanish, so I understand the challenges of learning a new language.\n- In my classes, I focus on improving spoken English and pronunciation, as well as building vocabulary and grammar skills.\n- But learning English doesn't have to be boring! I'm an engaging and patient teacher who loves to have fun and make learning enjoyable.\n- So, whether you're a beginner or an advanced student, I'm confident that I can help you improve your English skills and achieve your goals. Let's get started!" 254 | } 255 | ] 256 | } 257 | ``` 258 |
259 | 260 | ## Benchmark Result 261 | 262 | Our models [**ALLaVA-Phi3-mini-128k**](https://huggingface.co/FreedomIntelligence/ALLaVA-Phi3-mini-128k), 263 | [**ALLaVA-StableLM2-1_6B**](https://huggingface.co/FreedomIntelligence/ALLaVA-StableLM2-1_6B) 264 | and [**ALLaVA-Phi2-2_7B**](https://huggingface.co/FreedomIntelligence/ALLaVA-Phi2-2_7B) 265 | achieve competitive results on 17 benchmarks. 266 | 267 | 268 | | Models | Vicuna-80 | GQA | HallusionBench | MME-P | MMVP | TouchStone | TextVQA | MME-C | MathVista | MM-Vet | MMMU-val | SQA (img) | LLaVA (In-the-Wild) | MLLM-Bench | MMB-en | MMB-cn | SEEDBench (img, v1) | 269 | |---------------------------|-----------|-----|-------|-------|------|----|---------|-------|----|--------|-----------------|---------|---------------|----|--------|--------|--------------------| 270 | | **Large VLMs** | | | | | | | | | | | | | | | | | | 271 | | BLIP-2 | - | - | - | - | - | - | - | - | - | 22.4 | 34.4 | - | - | 3.0*| - | - | 49.7 | 272 | | InstructBLIP | - | 49.5| - | - | - | - | - | - | - | 25.6 | - | - | 58.2 | - | 44.0 | - | - | 273 | | Qwen-VL-Chat | - | 57.5| - | 1487.6| - | - | 61.5 | 360.7 | - | 31.1 | - | 68.2 | - | - | 60.6 | 56.7 | 65.4 | 274 | | LLaVA-1.5-7B | 13.8* | 62.0| 36.6* | 1504.4*| 24.7*| 594.9*| 58.2| 324.6*| 25.0*| 31.1| 35.1*| 66.8| 65.4| 23.0*| 64.3| 58.3| 66.1| 275 | | LLaVA-1.5-13B | 22.5 | 63.3| 36.5* | 1531.3 | 38.0*| 617.7*| 61.3| 295.4| 28.3*| 35.4| 34.4*| 71.6| 72.5| -| 67.7| 63.6| 68.2| 276 | | LVIS-7B | - | 62.6| - | - | - | - | 58.7 | - | - | 31.5 | - | - | 67.0 | 29.0*| 66.2 | - | - | 277 | | LVIS-13B | - | 63.6*| - | - | - | - | 62.5* | - | - | 37.4* | - | - | 71.3* | - | 68.0* | - | - | 278 | | ShareGPT4V-7B | 13.8* | 63.3| 36.0* | 1540.1*| 34.0*| 637.2*| 60.4| 346.1*| 24.7*| 37.6| 35.4*| 68.4*| 72.6| 30.2*| 68.8| 61.0*| 69.7| 279 | | ShareGPT4V-13B | 17.5* | 64.8| 39.0* | 1576.1*| 35.3*| 648.7*| 62.2| 309.3*| 28.8*| 43.1| 35.6*| 70.0*| 79.9| 35.5*| 71.2| 61.7*| 70.8| 280 | | **4B-scale Lite 
VLMs** | | | | | | | | | | | | | | | | | | 281 | | MobileVLM-v2 | 5.0* | 61.1| 30.8* | 1440.5 | 18.7*| 541.0*| 57.5| 261.8*| 28.3*| 26.1*| 30.8*| 70.0| 53.2*| 15.7*| 63.2| 43.2*| 64.5*| 282 | | Mipha-3B | 16.2* | **63.9**| 34.3*| **1488.9**| 32.0*| 619.0*| 56.6| 285.0*| 27.8*| 33.5*| 35.8*| 70.9| 64.7*| 23.1*| **69.7**| 42.9*| **71.2***| 283 | | TinyLLaVA | 15.6* | 62.1| 37.2* | 1465.5*| 33.3*| 663.5*| **60.3**| 281.1*| 30.3*| 37.5| 38.4| **73.0**| 70.8*| 29.8*| **69.7***| 42.8*| 70.4*| 284 | | **Ours** | | | | | | | | | | | | | | | | | | 285 | | **ALLaVA-Phi2** | 49.4 | 48.8| 24.8 | 1316.2| **36.0**| 632.0| 49.5| 301.8| 27.4| 32.2| 35.3| 67.6| 69.4| 43.6| 64.0| 40.8| 65.2| 286 | | **ALLaVA-StableLM2** | 38.8 | 49.8| 25.3 | 1311.7| 34.0 | 655.2| 51.7| 257.9| 27.7| 31.7| 33.3| 64.7| **72.0**| 39.3| 64.6| 49.8| 65.7| 287 | | **ALLaVA-Phi3** | **56.9**| 52.2| **48.1**| 1382.3| 32.7| **667.8**| 53.0| **347.1**| **32.9**| **37.8**| **41.1**| 64.0| 68.5| **54.8**| 68.1| **55.3**| 69.0| 288 | 289 | 290 | > \* denotes the results of our evaluation. **Bold numbers** are the best results among all 4B-scale LVLMs. Detailed information on each benchmark is given in Table 4 of our [technical report](https://arxiv.org/pdf/2402.11684.pdf). 291 | 292 | ## 🏭 Inference 293 | All models can be loaded from Hugging Face using the `.from_pretrained()` method. 294 | Check out the [example scripts](https://github.com/FreedomIntelligence/ALLaVA/tree/main/allava/serve) for sample inputs and outputs. 295 | 296 | 297 | 299 | 300 | 309 | 310 | 321 | 322 | 323 | 324 | 340 | 341 | 342 | 343 | ## 🏋️‍♂️ Training 344 | 345 | ### Data 346 |
347 | training_datasets 348 |
349 | 350 | ALLaVA uses 1.0M and 1.5M samples for pretraining (PT) and fine-tuning (FT), respectively. 351 | 352 | ### Code 353 | The training code is largely based on [LLaVA](https://github.com/haotian-liu/LLaVA). 354 | We wholeheartedly express our gratitude for their invaluable contributions to open-sourcing LVLMs. 355 | 356 | 361 | 362 | 363 | ### Hyperparameters 364 | 365 | | Global Batch Size| ZeRO Stage| Optimizer | Max LR| Min LR | Scheduler | Weight decay | 366 | | ---: | ---: |--:| ---: | ---: | ---: | ---: | 367 | | 256 (PT) / 128 (FT) | 1| AdamW | 2e-5 | 2e-6 | CosineAnnealingWarmRestarts | 0 | 368 | 369 | The LM backbone and the projector are trainable, while the vision encoder is kept frozen. 370 | **Each module's trainability is the same in both stages.** 371 | 372 | 373 | ## 🙌 Contributors 374 | Project Leader: [Guiming Hardy Chen](https://g-h-chen.github.io/) 375 | 376 | Data: Shunian Chen, [Junying Chen](https://jymchen.github.io/), Xiangbo Wu 377 | 378 | Evaluation: [Ruifei Zhang](https://scholar.google.com/citations?user=W4zOhmEAAAAJ&hl=zh-CN) 379 | 380 | Deployment: Xiangbo Wu, Zhiyi Zhang 381 | 382 | Advising: [Zhihong Chen](https://zhjohnchan.github.io/), [Benyou Wang](https://wabyking.github.io/old.html) 383 | 384 | Others: Jianquan Li, [Xiang Wan](https://scholar.google.com/citations?user=e3_kWigAAAAJ&hl=zh-CN) 385 | 386 | 387 | 388 | 389 | ## 📝 Citation 390 | If you find our data useful, please consider citing our work!
We are FreedomIntelligence from [Shenzhen Research Institute of Big Data](http://sribd.cn/en) and [The Chinese University of Hong Kong, Shenzhen](https://sds.cuhk.edu.cn/en) 391 | ``` 392 | @misc{chen2024allava, 393 | title={ALLaVA: Harnessing GPT4V-synthesized Data for A Lite Vision-Language Model}, 394 | author={Guiming Hardy Chen and Shunian Chen and Ruifei Zhang and Junying Chen and Xiangbo Wu and Zhiyi Zhang and Zhihong Chen and Jianquan Li and Xiang Wan and Benyou Wang}, 395 | year={2024}, 396 | eprint={2402.11684}, 397 | archivePrefix={arXiv}, 398 | primaryClass={cs.CL} 399 | } 400 | ``` 401 | 402 | ## Star History 403 | 404 | [![Star History Chart](https://api.star-history.com/svg?repos=FreedomIntelligence/ALLaVA&type=Date)](https://star-history.com/#FreedomIntelligence/ALLaVA&Date) 405 | -------------------------------------------------------------------------------- /allava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /allava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | # from .language_model.llava_stablelm_1_6b import LlavaStableLM_1_6bForCausalLM, LlavaStableLM_1_6bConfig 3 | from .language_model.llava_phi import LlavaPhiForCausalLM, LlavaPhiConfig 4 | 5 | import transformers # should be >= 4.37 6 | -------------------------------------------------------------------------------- /allava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | 
continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /allava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | import pdb 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | if 'llava' in model_name.lower(): 45 | # Load LLaVA model 46 | if 'lora' in model_name.lower() and model_base is None: 47 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 48 | if 'lora' in model_name.lower() and model_base is not None: 49 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 50 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 51 | print('Loading LLaVA from base model...') 52 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 53 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 54 | if model.lm_head.weight.shape[0] != token_num: 55 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 56 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 57 | 58 | print('Loading additional LLaVA weights...') 59 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 60 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 61 | else: 62 | # this is probably from HF Hub 63 | from huggingface_hub import hf_hub_download 64 | def load_from_hf(repo_id, filename, subfolder=None): 65 | cache_file = hf_hub_download( 66 | repo_id=repo_id, 67 | filename=filename, 68 | subfolder=subfolder) 69 | return torch.load(cache_file, map_location='cpu') 70 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 71 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 72 | if any(k.startswith('model.model.') for k in non_lora_trainables): 73 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 74 | model.load_state_dict(non_lora_trainables, strict=False) 75 | 76 | from peft import PeftModel 77 | print('Loading LoRA weights...') 78 | model = PeftModel.from_pretrained(model, 
model_path) 79 | print('Merging LoRA weights...') 80 | model = model.merge_and_unload() 81 | print('Model is loaded...') 82 | elif model_base is not None: 83 | # this may be mm projector only 84 | print('Loading LLaVA from base model...') 85 | if 'mpt' in model_name.lower(): 86 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 87 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 88 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 89 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 90 | model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 91 | else: 92 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 93 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 94 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 95 | 96 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 97 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 98 | model.load_state_dict(mm_projector_weights, strict=False) 99 | else: 100 | if 'mpt' in model_name.lower(): 101 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 102 | model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 103 | else: 104 | try: 105 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 106 | except: 107 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 108 | model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 109 | else: 110 | # Load language model 111 | if model_base is not None: 112 | # PEFT model 113 | from peft import PeftModel 114 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 115 | 
model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 116 | print(f"Loading LoRA weights from {model_path}") 117 | model = PeftModel.from_pretrained(model, model_path) 118 | print(f"Merging weights") 119 | model = model.merge_and_unload() 120 | print('Convert to FP16...') 121 | model.to(torch.float16) 122 | else: 123 | use_fast = False 124 | if 'mpt' in model_name.lower(): 125 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 126 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 127 | else: 128 | try: 129 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 130 | except: 131 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 132 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 133 | 134 | image_processor = None 135 | 136 | if 'llava' in model_name.lower(): 137 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 138 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 139 | if mm_use_im_patch_token: 140 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 141 | if mm_use_im_start_end: 142 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 143 | model.resize_token_embeddings(len(tokenizer)) 144 | 145 | vision_tower = model.get_vision_tower() 146 | if not vision_tower.is_loaded: 147 | vision_tower.load_model() 148 | vision_tower.to(device=device, dtype=torch.float16) 149 | image_processor = vision_tower.image_processor 150 | 151 | if hasattr(model.config, "max_sequence_length"): 152 | context_len = model.config.max_sequence_length 153 | else: 154 | context_len = 2048 155 | 156 | return tokenizer, model, image_processor, context_len 157 | -------------------------------------------------------------------------------- /allava/model/consolidate.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /allava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | from transformers.cache_utils import Cache, DynamicCache 30 | # LLaVA variants of the HF Llama config / base model / causal LM 31 | class LlavaConfig(LlamaConfig): # Llama config under model_type "llava_llama" (registered with AutoConfig at the bottom of the file) 32 | model_type = "llava_llama" 33 | 34 | 35 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): # LlamaModel combined with the LlavaMetaModel mixin from llava_arch 36 | config_class = LlavaConfig 37 | 38 | def __init__(self, config: LlamaConfig): 39 | super(LlavaLlamaModel, self).__init__(config) 40 | 41 | 42 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): # Llama causal LM that also accepts `images`; see forward() 43 | config_class = LlavaConfig 44 | 45 | # def __init__(self, config): 46 | # config._flash_attn_2_enabled = True 47 | # super(LlamaForCausalLM, self).__init__(config) 48 | # self.model = LlavaLlamaModel(config) 49 | # self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path) 50 | # self.pretraining_tp = config.pretraining_tp 51 | # self.vocab_size = config.vocab_size 52 | # self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 53 | 54 | # # Initialize weights and apply final processing 55 | # self.post_init() 56 | def __init__(self, config, init_vision_encoder_from_ckpt=False): # forces FlashAttention-2, builds LlavaLlamaModel + lm_head, loads a tokenizer from config._name_or_path, optionally loads the vision tower 57 | config._flash_attn_2_enabled = True # presumably the legacy flag for older transformers; 
_attn_implementation is the newer one 58 | config._attn_implementation = 'flash_attention_2' 59 | super(LlamaForCausalLM, self).__init__(config) # skip LlamaForCausalLM.__init__ (it would build a plain LlamaModel); self.model is set on the next line 60 | self.model = LlavaLlamaModel(config) 61 | self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path) # NOTE(review): I/O at construction time; config._name_or_path must point to a loadable tokenizer 62 | self.pretraining_tp = config.pretraining_tp 63 | self.vocab_size = config.vocab_size 64 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 65 | 66 | # if getattr(config, 'init_vision_encoder_from_ckpt', True): 67 | if init_vision_encoder_from_ckpt: 68 |
vision_tower = self.get_vision_tower() 69 | print(f'loading from CLIP first. This should only be used at inference!!!') 70 | vision_tower.load_model() # load the vision encoder weights now 71 | 72 | # Initialize weights and apply final processing 73 | self.post_init() 74 | 75 | def get_model(self): # accessor for the LlavaLlamaModel backbone 76 | return self.model 77 | 78 | def get_tokenizer(self): # tokenizer loaded in __init__ 79 | return self.tokenizer 80 | 81 | def forward( 82 | self, 83 | input_ids: torch.LongTensor = None, 84 | attention_mask: Optional[torch.Tensor] = None, 85 | position_ids: Optional[torch.LongTensor] = None, 86 | past_key_values: Optional[List[torch.FloatTensor]] = None, 87 | inputs_embeds: Optional[torch.FloatTensor] = None, 88 | labels: Optional[torch.LongTensor] = None, 89 | use_cache: Optional[bool] = None, 90 | output_attentions: Optional[bool] = None, 91 | output_hidden_states: Optional[bool] = None, 92 | images: Optional[torch.FloatTensor] = None, 93 | return_dict: Optional[bool] = None, 94 | ) -> Union[Tuple, CausalLMOutputWithPast]: 95 | # Without precomputed embeddings, derive them (with adjusted ids/positions/mask/cache/labels) from input_ids and images 96 | if inputs_embeds is None: 97 | ( 98 | input_ids, 99 | position_ids, 100 | attention_mask, 101 | past_key_values, 102 | inputs_embeds, 103 | labels 104 | # ) = self.prepare_inputs_labels_for_multimodal( 105 | ) = self.prepare_inputs_labels_for_multimodal_new( 106 | input_ids, 107 | position_ids, 108 | attention_mask, 109 | past_key_values, 110 | labels, 111 | images 112 | ) 113 | # Delegate to LlamaForCausalLM.forward with the prepared inputs 114 | return super().forward( 115 | input_ids=input_ids, 116 | attention_mask=attention_mask, 117 | position_ids=position_ids, 118 | past_key_values=past_key_values, 119 | inputs_embeds=inputs_embeds, 120 | labels=labels, 121 | use_cache=use_cache, 122 | output_attentions=output_attentions, 123 | output_hidden_states=output_hidden_states, 124 | 
return_dict=return_dict 125 | ) 126 | 127 | # def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 128 | # images = kwargs.pop("images", None) 129 | # _inputs = super().prepare_inputs_for_generation( 130 | # input_ids,
past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 131 | # ) 132 | # if images is not None: 133 | # _inputs['images'] = images 134 | # return _inputs 135 | 136 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs): 137 | ''' 138 | Called for each decoding step at inference: keeps only the input_ids not yet in the KV cache, builds position_ids from attention_mask, and passes `images` through to forward(). 139 | ''' 140 | # pdb.set_trace() 141 | images = kwargs.pop("images", None) 142 | 143 | #################################################### 144 | # lines from modeling_phi.py 145 | #################################################### 146 | 147 | if past_key_values is not None: 148 | if isinstance(past_key_values, Cache): 149 | cache_length = past_key_values.get_seq_length() 150 | past_length = past_key_values.seen_tokens 151 | max_cache_length = past_key_values.get_max_length() 152 | else: 153 | cache_length = past_length = past_key_values[0][0].shape[2] 154 | max_cache_length = None 155 | 156 | # Keep only the unprocessed tokens: 157 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 158 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 159 | # input) 160 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 161 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 162 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 163 | # input_ids based on the past_length. 164 | elif past_length < input_ids.shape[1]: 165 | input_ids = input_ids[:, past_length:] 166 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 167 | elif past_length >= input_ids.shape[1]: 168 | input_ids = input_ids[:, [-1]] # only keep the last one! 
NOTE(review): unlike the assumption stated above, this drops all but the last token — presumably because image features make the cache longer than input_ids; confirm 169 | 170 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
171 | if ( 172 | max_cache_length is not None 173 | and attention_mask is not None 174 | and cache_length + input_ids.shape[1] > max_cache_length 175 | ): 176 | attention_mask = attention_mask[:, -max_cache_length:] 177 | 178 | position_ids = kwargs.get("position_ids", None) 179 | if attention_mask is not None and position_ids is None: 180 | # create position_ids on the fly for batch generation 181 | position_ids = attention_mask.long().cumsum(-1) - 1 182 | position_ids.masked_fill_(attention_mask == 0, 1) 183 | if past_key_values: 184 | position_ids = position_ids[:, -input_ids.shape[1] :] 185 | 186 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 187 | if inputs_embeds is not None and past_key_values is None: 188 | model_inputs = {"inputs_embeds": inputs_embeds} 189 | else: 190 | model_inputs = {"input_ids": input_ids} 191 | 192 | model_inputs.update( 193 | { 194 | "position_ids": position_ids, 195 | "past_key_values": past_key_values, 196 | "use_cache": kwargs.get("use_cache"), 197 | "attention_mask": attention_mask, 198 | } 199 | ) 200 | #################################################### 201 | # end of lines from modeling_phi.py 202 | #################################################### 203 | 204 | # Hand `images` to forward() on every step 205 | if images is not None: 206 | model_inputs['images'] = images 207 | return model_inputs 208 | 209 | # Let AutoConfig / AutoModelForCausalLM resolve model_type "llava_llama" to these classes 210 | AutoConfig.register("llava_llama", LlavaConfig) 211 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 212 | -------------------------------------------------------------------------------- /allava/model/language_model/llava_phi.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import math 6 | import pdb 7 | from typing import Dict, 
Any 8 | 9 | from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel 10 | 11 | 12 | from
transformers.modeling_outputs import CausalLMOutputWithPast 13 | 14 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 15 | 16 | from transformers.cache_utils import Cache, DynamicCache 17 | 18 | 19 | import sys 20 | from allava.model.language_model.phi.modeling_phi import PhiForCausalLM, PhiModel, PhiConfig 21 | 22 | 23 | 24 | 25 | ################ Phi ############################### 26 | 27 | class LlavaPhiConfig(PhiConfig): # Phi config under model_type "llava_phi" (registered with AutoConfig at the bottom of the file) 28 | model_type = "llava_phi" 29 | 30 | class LlavaPhiModel(LlavaMetaModel, PhiModel): # PhiModel combined with the LlavaMetaModel mixin from llava_arch 31 | config_class = LlavaPhiConfig 32 | 33 | def __init__(self, config: PhiConfig): 34 | super(LlavaPhiModel, self).__init__(config) 35 | 36 | 37 | 38 | class LlavaPhiForCausalLM(PhiForCausalLM, LlavaMetaForCausalLM): # Phi causal LM that also accepts `images`; see forward() 39 | config_class = LlavaPhiConfig 40 | 41 | def __init__(self, config, init_vision_encoder_from_ckpt=False): # forces FlashAttention-2, builds LlavaPhiModel + lm_head, optionally loads the vision tower 42 | config._attn_implementation = "flash_attention_2" 43 | 44 | super(PhiForCausalLM, self).__init__(config) # skip PhiForCausalLM.__init__; the LLaVA backbone is assigned to self.model below 45 | # self.model is used in LlavaMetaForCausalLM.get_model(); self.transformer is used in PhiForCausalLM.forward() 46 | self.model = LlavaPhiModel(config) 47 | if hasattr(self.model, '_use_flash_attention_2'): # NOTE(review): the assert below is skipped under python -O 48 | assert self.model._use_flash_attention_2, 'flash attn is not enabled. check it out!' 49 | self.vocab_size = config.vocab_size 50 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 51 | 52 | if init_vision_encoder_from_ckpt: 53 | vision_tower = self.get_vision_tower() 54 | print(f'loading from CLIP first.
This should only be used at inference!!!') 55 | vision_tower.load_model() # load the vision encoder weights now 56 | 57 | # Initialize weights and apply final processing 58 | self.post_init() 59 | 60 | def get_model(self): # accessor for the LlavaPhiModel backbone 61 | return self.model 62 | 63 | def get_tokenizer(self): # NOTE(review): __init__ never assigns self.tokenizer, so this raises AttributeError unless a caller sets it — confirm 64 | return self.tokenizer 65 | 66 | def forward( 67 | self, 68 | input_ids: torch.LongTensor = None, 69 | attention_mask: Optional[torch.Tensor] = None, 70 | position_ids: Optional[torch.LongTensor] = None, 71 | past_key_values: Optional[List[torch.FloatTensor]] = None, 72 | inputs_embeds: Optional[torch.FloatTensor] = None, 73 | labels: Optional[torch.LongTensor] = None, 74 | use_cache: Optional[bool] = None, 75 | output_attentions: Optional[bool] = None, 76 | output_hidden_states: Optional[bool] = None, 77 | images: Optional[torch.FloatTensor] = None, 78 | return_dict: Optional[bool] = None, 79 | ) -> Union[Tuple, CausalLMOutputWithPast]: 80 | 81 | # Without precomputed embeddings, derive them (with adjusted ids/positions/mask/cache/labels) from input_ids and images 82 | if inputs_embeds is None: 83 | ( 84 | input_ids, 85 | position_ids, 86 | attention_mask, 87 | past_key_values, 88 | inputs_embeds, 89 | labels 90 | # ) = self.prepare_inputs_labels_for_multimodal( 91 | ) = self.prepare_inputs_labels_for_multimodal_new( 92 | input_ids, 93 | position_ids, 94 | attention_mask, 95 | past_key_values, 96 | labels, 97 | images 98 | ) 99 | 100 | # pdb.set_trace() 101 | return super().forward( 102 | input_ids=input_ids, 103 | attention_mask=attention_mask, 104 | position_ids=position_ids, 105 | past_key_values=past_key_values, 106 | inputs_embeds=inputs_embeds, 107 | labels=labels, 108 | use_cache=use_cache, 109 | output_attentions=output_attentions, 110 | output_hidden_states=output_hidden_states, 111 | return_dict=return_dict 112 | ) 113 | 114 | def 
prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs): 115 | ''' 116 | Called for each decoding step at inference: keeps only the input_ids not yet in the KV cache, builds position_ids from attention_mask, and passes `images` through to forward(). 117 | ''' 118 | # pdb.set_trace() 119 | images = kwargs.pop("images", None) 120 | 121 |
#################################################### 122 | # lines from modeling_phi.py 123 | #################################################### 124 | 125 | if past_key_values is not None: 126 | if isinstance(past_key_values, Cache): 127 | cache_length = past_key_values.get_seq_length() 128 | past_length = past_key_values.seen_tokens 129 | max_cache_length = past_key_values.get_max_length() 130 | else: 131 | cache_length = past_length = past_key_values[0][0].shape[2] 132 | max_cache_length = None 133 | 134 | # Keep only the unprocessed tokens: 135 | # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where 136 | # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as 137 | # input) 138 | if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: 139 | input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] 140 | # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard 141 | # input_ids based on the past_length. 142 | elif past_length < input_ids.shape[1]: 143 | input_ids = input_ids[:, past_length:] 144 | # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. 145 | elif past_length >= input_ids.shape[1]: 146 | input_ids = input_ids[:, [-1]] # only keep the last one! NOTE(review): unlike the assumption stated above, this drops all but the last token — presumably because image features make the cache longer than input_ids; confirm 147 | 148 | # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
149 | if ( 150 | max_cache_length is not None 151 | and attention_mask is not None 152 | and cache_length + input_ids.shape[1] > max_cache_length 153 | ): 154 | attention_mask = attention_mask[:, -max_cache_length:] 155 | 156 | position_ids = kwargs.get("position_ids", None) 157 | if attention_mask is not None and position_ids is None: 158 | # create position_ids on the fly for batch generation 159 | position_ids = attention_mask.long().cumsum(-1) - 1 160 | position_ids.masked_fill_(attention_mask == 0, 1) 161 | if past_key_values: 162 | position_ids = position_ids[:, -input_ids.shape[1] :] 163 | 164 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 165 | if inputs_embeds is not None and past_key_values is None: 166 | model_inputs = {"inputs_embeds": inputs_embeds} 167 | else: 168 | model_inputs = {"input_ids": input_ids} 169 | 170 | model_inputs.update( 171 | { 172 | "position_ids": position_ids, 173 | "past_key_values": past_key_values, 174 | "use_cache": kwargs.get("use_cache"), 175 | "attention_mask": attention_mask, 176 | } 177 | ) 178 | #################################################### 179 | # end of lines from modeling_phi.py 180 | #################################################### 181 | 182 | # Hand `images` to forward() on every step 183 | if images is not None: 184 | model_inputs['images'] = images 185 | return model_inputs 186 | 187 | 188 | # def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 189 | # ''' 190 | # This function is called for each token at inference 191 | # ''' 192 | # pdb.set_trace() 193 | # images = kwargs.pop("images", None) 194 | 195 | 196 | # _inputs = super().prepare_inputs_for_generation( 197 | # input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 198 | # ) 199 | # if images is not None: 200 | # _inputs['images'] = images 201 | # return _inputs 202 | 203 | # Let AutoConfig / AutoModelForCausalLM resolve model_type 
"llava_phi" to these classes 204 | AutoConfig.register("llava_phi", LlavaPhiConfig) 205 |
AutoModelForCausalLM.register(LlavaPhiConfig, LlavaPhiForCausalLM) -------------------------------------------------------------------------------- /allava/model/language_model/llava_stablelm_1_6b.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | import warnings 18 | 19 | import torch 20 | import torch.nn as nn 21 | 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM, AutoModel, PretrainedConfig 24 | # StableLMEpochConfig, StableLMEpochModel, StableLMEpochForCausalLM 25 | from transformers.modeling_utils import cached_file, CONFIG_NAME, extract_commit_hash, is_peft_available, find_adapter_config_file, json, os 26 | from transformers.models.auto.auto_factory import _BaseAutoModelClass, _get_model_class 27 | from transformers.dynamic_module_utils import resolve_trust_remote_code, get_class_from_dynamic_module 28 | 29 | 30 | from transformers.modeling_outputs import CausalLMOutputWithPast 31 | 32 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 33 | import pdb 34 | 35 | import sys 36 | sys.path.insert(0, '/mntcephfs/data/med/guimingchen/models/general/stablelm-2-1_6b') 37 | from modeling_stablelm_epoch import StableLMEpochForCausalLM, StableLMEpochModel, StableLMEpochConfig 38 | 39 | 40 | ################ stableLM 
############################### 41 | 42 | class LlavaStableLM_1_6bConfig(StableLMEpochConfig): 43 | model_type = "llava_stablelm_1_6b" 44 | 45 | # class LlavaStableLMModel(LlavaMetaModel, AutoModel): 46 | class LlavaStableLMModel(LlavaMetaModel, StableLMEpochModel): 47 | config_class = LlavaStableLM_1_6bConfig 48 | 49 | def __init__(self, config: AutoConfig): 50 | super(LlavaStableLMModel, self).__init__(config) 51 | 52 | 53 | 54 | class LlavaStableLM_1_6bForCausalLM(StableLMEpochForCausalLM, LlavaMetaForCausalLM): 55 | config_class = LlavaStableLM_1_6bConfig 56 | 57 | 58 | def __init__(self, config, init_vision_encoder_from_ckpt=False): 59 | config._attn_implementation = "flash_attention_2" 60 | 61 | super(StableLMEpochForCausalLM, self).__init__(config) 62 | 63 | self.model = LlavaStableLMModel(config) 64 | if hasattr(self.model, '_use_flash_attention_2'): 65 | assert self.model._use_flash_attention_2, 'flash attn is not enabled. check it out!' 66 | # self.pretraining_tp = config.pretraining_tp 67 | self.vocab_size = config.vocab_size 68 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 69 | 70 | if init_vision_encoder_from_ckpt: 71 | vision_tower = self.get_vision_tower() 72 | print(f'loading from CLIP first. 
This should only be used at inference!!!') 73 | vision_tower.load_model() # 74 | 75 | # Initialize weights and apply final processing 76 | self.post_init() 77 | 78 | 79 | def get_model(self): 80 | return self.model 81 | 82 | def get_tokenizer(self): 83 | return self.tokenizer 84 | 85 | def forward( 86 | self, 87 | input_ids: torch.LongTensor = None, 88 | attention_mask: Optional[torch.Tensor] = None, 89 | position_ids: Optional[torch.LongTensor] = None, 90 | past_key_values: Optional[List[torch.FloatTensor]] = None, 91 | inputs_embeds: Optional[torch.FloatTensor] = None, 92 | labels: Optional[torch.LongTensor] = None, 93 | use_cache: Optional[bool] = None, 94 | output_attentions: Optional[bool] = None, 95 | output_hidden_states: Optional[bool] = None, 96 | images: Optional[torch.FloatTensor] = None, 97 | return_dict: Optional[bool] = None, 98 | ) -> Union[Tuple, CausalLMOutputWithPast]: 99 | 100 | if inputs_embeds is None: 101 | ( 102 | input_ids, 103 | position_ids, 104 | attention_mask, 105 | past_key_values, 106 | inputs_embeds, 107 | labels 108 | # ) = self.prepare_inputs_labels_for_multimodal( 109 | ) = self.prepare_inputs_labels_for_multimodal_new( 110 | input_ids, 111 | position_ids, 112 | attention_mask, 113 | past_key_values, 114 | labels, 115 | images 116 | ) 117 | return super().forward( 118 | input_ids=input_ids, 119 | attention_mask=attention_mask, 120 | position_ids=position_ids, 121 | past_key_values=past_key_values, 122 | inputs_embeds=inputs_embeds, 123 | labels=labels, 124 | use_cache=use_cache, 125 | output_attentions=output_attentions, 126 | output_hidden_states=output_hidden_states, 127 | return_dict=return_dict 128 | ) 129 | 130 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 131 | images = kwargs.pop("images", None) 132 | _inputs = super().prepare_inputs_for_generation( 133 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 134 | ) 135 | if images is 
not None: 136 | _inputs['images'] = images 137 | return _inputs 138 | 139 | # class StableLMEpochConfig = AutoConfig.from_pretrained('/wangbenyou/guimingchen/models/stablelm-3b-4e1t', trust_remote_code=True) 140 | 141 | 142 | AutoConfig.register("llava_stablelm_1_6b", LlavaStableLM_1_6bConfig) 143 | # AutoConfig.register("stablelm_epoch", LlavaStableLMConfig) 144 | AutoModelForCausalLM.register(LlavaStableLM_1_6bConfig, LlavaStableLM_1_6bForCausalLM) 145 | -------------------------------------------------------------------------------- /allava/model/language_model/llava_vicuna.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaConfig,
    LlamaForCausalLM,
    LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM


class LlavaVicunaConfig(LlamaConfig):
    """Llama configuration registered under its own ``llava_vicuna`` model type."""

    model_type = "llava_vicuna"


class LlavaVicunaModel(LlavaMetaModel, LlamaModel):
    """Llama backbone combined with the LLaVA multimodal mixin ``LlavaMetaModel``."""

    config_class = LlavaVicunaConfig

    def __init__(self, config: LlamaConfig):
        super().__init__(config)


class LlavaVicunaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
    """Causal-LM wrapper around ``LlavaVicunaModel`` that accepts an extra ``images`` input."""

    config_class = LlavaVicunaConfig

    def __init__(self, config):
        # The flash-attention flag is forced on for every instance of this model.
        config._flash_attn_2_enabled = True
        # Start the __init__ lookup *after* LlamaForCausalLM in the MRO, skipping its own
        # __init__; the backbone and LM head are built explicitly below instead.
        super(LlamaForCausalLM, self).__init__(config)

        self.model = LlavaVicunaModel(config)
        # NOTE(review): loads a tokenizer from `config._name_or_path` at construction time.
        self.tokenizer = AutoTokenizer.from_pretrained(config._name_or_path)

        hidden_size, vocab_size = config.hidden_size, config.vocab_size
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = vocab_size
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

        # Weight initialization and other final processing from the HF base class.
        self.post_init()

    def get_model(self):
        """Return the multimodal backbone."""
        return self.model

    def get_tokenizer(self):
        """Return the tokenizer loaded in ``__init__``."""
        return self.tokenizer

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the Llama LM, first merging ``images`` into the inputs when no embeddings are given.

        When ``inputs_embeds`` is ``None``, the multimodal helper from ``LlavaMetaForCausalLM``
        rebuilds all six text-side inputs (including ``inputs_embeds``) from the token ids,
        labels and images. The result is then handed to ``LlamaForCausalLM.forward``.
        """
        if inputs_embeds is None:
            merged = self.prepare_inputs_labels_for_multimodal_new(
                input_ids, position_ids, attention_mask, past_key_values, labels, images
            )
            input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels = merged

        lm_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return super().forward(**lm_kwargs)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
        """Build generation inputs via the base class, re-attaching ``images`` when supplied."""
        images = kwargs.pop("images", None)
        model_inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if images is None:
            return model_inputs
        model_inputs['images'] = images
        return model_inputs


AutoConfig.register("llava_vicuna", LlavaVicunaConfig)
AutoModelForCausalLM.register(LlavaVicunaConfig, LlavaVicunaForCausalLM)
# --------------------------------------------------------------------------------
# /allava/model/language_model/phi/configuration_phi.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Phi model configuration"""


from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging


logger = logging.get_logger(__name__)

PHI_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/phi-2": "https://huggingface.co/microsoft/phi-2/resolve/main/config.json",
}


class PhiConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`PhiModel`]. It is used to instantiate a Phi
    model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
    defaults will yield a similar configuration to that of the Phi
    [microsoft/phi-1](https://huggingface.co/microsoft/phi-1).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 51200):
            Vocabulary size of the Phi model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`PhiModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 8192):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 24):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA), otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
            `num_attention_heads`.
        resid_pdrop (`float`, *optional*, defaults to 0.0):
            Dropout probability for mlp outputs.
        embd_pdrop (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the embeddings.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio after computing the attention scores.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu_new"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            The maximum sequence length that this model might ever be used with. Phi-1 and Phi-1.5 support up to 2048
            tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings or not.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
            strategies: linear and dynamic. Their scaling factor must be a number (int or float) greater than 1. The
            expected format is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
            `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
            these scaling strategies behave:
            https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This
            is an experimental feature, subject to breaking API changes in future versions.
        partial_rotary_factor (`float`, *optional*, defaults to 0.5):
            Fraction of each attention head's query/key dimensions to which the rotary embedding is applied.
        qk_layernorm (`bool`, *optional*, defaults to `False`):
            Whether or not to normalize the Queries and Keys after projecting the hidden states.
        bos_token_id (`int`, *optional*, defaults to 1):
            Denotes beginning of sequences token id.
        eos_token_id (`int`, *optional*, defaults to 2):
            Denotes end of sequences token id.

    Example:

    ```python
    >>> from transformers import PhiModel, PhiConfig

    >>> # Initializing a Phi-1 style configuration
    >>> configuration = PhiConfig.from_pretrained("microsoft/phi-1")

    >>> # Initializing a model from the configuration
    >>> model = PhiModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "phi"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=51200,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=None,
        resid_pdrop=0.0,
        embd_pdrop=0.0,
        attention_dropout=0.0,
        hidden_act="gelu_new",
        max_position_embeddings=2048,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        partial_rotary_factor=0.5,
        qk_layernorm=False,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Unspecified KV heads means plain multi-head attention (one KV head per query head).
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.partial_rotary_factor = partial_rotary_factor
        self.qk_layernorm = qk_layernorm
        # Fail fast on a malformed `rope_scaling` before the base class finishes initialization.
        self._rope_scaling_validation()

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    # Adapted from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.

        `None` disables RoPE scaling and is always accepted. Otherwise `rope_scaling` must be a dict with exactly
        two keys: `type` (`"linear"` or `"dynamic"`) and `factor` (an int or float strictly greater than 1).

        Raises:
            ValueError: if `rope_scaling` does not satisfy the constraints above.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
                f"got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_factor = self.rope_scaling.get("factor", None)
        if rope_scaling_type not in ("linear", "dynamic"):
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
            )
        # Integer factors (e.g. `2` read from a JSON config) are valid; bool is excluded even though it
        # subclasses int. `not factor > 1` (rather than `factor <= 1`) also rejects NaN.
        is_number = isinstance(rope_scaling_factor, (int, float)) and not isinstance(rope_scaling_factor, bool)
        if not is_number or not rope_scaling_factor > 1:
            raise ValueError(f"`rope_scaling`'s factor field must be a number > 1, got {rope_scaling_factor}")
# --------------------------------------------------------------------------------
# /allava/model/language_model/phi/modeling_phi.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2023 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" PyTorch Phi model."""


import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
# NOTE: requires a transformers release that ships `transformers.cache_utils`.
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

# Prefer a top-level `configuration_phi` module (when this file is used outside the package) and fall
# back to the sibling module. Only a failed import triggers the fallback; other errors propagate.
try:
    from configuration_phi import PhiConfig
except ImportError:
    from .configuration_phi import PhiConfig

# flash-attn is optional: any failure to import it (missing package, broken CUDA build, ...) just
# leaves these names undefined. `Exception` (not a bare `except:`) keeps Ctrl-C / SystemExit working.
try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
except Exception:
    pass

# NOTE(review): debugging import; kept in case code outside this chunk uses it — confirm before removing.
import pdb

logger = logging.get_logger(__name__)
| 68 | _CHECKPOINT_FOR_DOC = "microsoft/phi-2" 69 | _CONFIG_FOR_DOC = "PhiConfig" 70 | 71 | PHI_PRETRAINED_MODEL_ARCHIVE_LIST = [ 72 | "microsoft/phi-2", 73 | # See all Phi models at https://huggingface.co/models?filter=phi 74 | ] 75 | 76 | 77 | # Copied from transformers.models.llama.modeling_llama._get_unpad_data 78 | def _get_unpad_data(attention_mask): 79 | seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) 80 | indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() 81 | max_seqlen_in_batch = seqlens_in_batch.max().item() 82 | cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) 83 | return ( 84 | indices, 85 | cu_seqlens, 86 | max_seqlen_in_batch, 87 | ) 88 | 89 | 90 | # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi 91 | class PhiRotaryEmbedding(nn.Module): 92 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): 93 | super().__init__() 94 | 95 | self.dim = dim 96 | self.max_position_embeddings = max_position_embeddings 97 | self.base = base 98 | inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 99 | self.register_buffer("inv_freq", inv_freq, persistent=False) 100 | 101 | # Build here to make `torch.jit.trace` work. 
102 | self._set_cos_sin_cache( 103 | seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() 104 | ) 105 | 106 | def _set_cos_sin_cache(self, seq_len, device, dtype): 107 | self.max_seq_len_cached = seq_len 108 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 109 | 110 | freqs = torch.outer(t, self.inv_freq) 111 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 112 | emb = torch.cat((freqs, freqs), dim=-1) 113 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 114 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 115 | 116 | def forward(self, x, seq_len=None): 117 | # x: [bs, num_attention_heads, seq_len, head_size] 118 | if seq_len > self.max_seq_len_cached: 119 | self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) 120 | 121 | return ( 122 | self.cos_cached[:seq_len].to(dtype=x.dtype), 123 | self.sin_cached[:seq_len].to(dtype=x.dtype), 124 | ) 125 | 126 | 127 | # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi 128 | class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding): 129 | """PhiRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" 130 | 131 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 132 | self.scaling_factor = scaling_factor 133 | super().__init__(dim, max_position_embeddings, base, device) 134 | 135 | def _set_cos_sin_cache(self, seq_len, device, dtype): 136 | self.max_seq_len_cached = seq_len 137 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 138 | t = t / self.scaling_factor 139 | 140 | freqs = torch.outer(t, self.inv_freq) 141 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 142 | emb = torch.cat((freqs, freqs), dim=-1) 143 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 144 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 145 | 146 | 147 | # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi 148 | class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding): 149 | """PhiRotaryEmbedding extended with Dynamic NTK scaling. 
Credits to the Reddit users /u/bloc97 and /u/emozilla""" 150 | 151 | def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): 152 | self.scaling_factor = scaling_factor 153 | super().__init__(dim, max_position_embeddings, base, device) 154 | 155 | def _set_cos_sin_cache(self, seq_len, device, dtype): 156 | self.max_seq_len_cached = seq_len 157 | 158 | if seq_len > self.max_position_embeddings: 159 | base = self.base * ( 160 | (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) 161 | ) ** (self.dim / (self.dim - 2)) 162 | inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 163 | self.register_buffer("inv_freq", inv_freq, persistent=False) 164 | 165 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) 166 | 167 | freqs = torch.outer(t, self.inv_freq) 168 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 169 | emb = torch.cat((freqs, freqs), dim=-1) 170 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 171 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 172 | 173 | 174 | # Copied from transformers.models.llama.modeling_llama.rotate_half 175 | def rotate_half(x): 176 | """Rotates half the hidden dims of the input.""" 177 | x1 = x[..., : x.shape[-1] // 2] 178 | x2 = x[..., x.shape[-1] // 2 :] 179 | return torch.cat((-x2, x1), dim=-1) 180 | 181 | 182 | # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb 183 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): 184 | """Applies Rotary Position Embedding to the query and key tensors. 185 | 186 | Args: 187 | q (`torch.Tensor`): The query tensor. 188 | k (`torch.Tensor`): The key tensor. 189 | cos (`torch.Tensor`): The cosine part of the rotary embedding. 190 | sin (`torch.Tensor`): The sine part of the rotary embedding. 
191 | position_ids (`torch.Tensor`): 192 | The position indices of the tokens corresponding to the query and key tensors. For example, this can be 193 | used to pass offsetted position ids when working with a KV-cache. 194 | unsqueeze_dim (`int`, *optional*, defaults to 1): 195 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and 196 | sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note 197 | that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and 198 | k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes 199 | cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have 200 | the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 201 | Returns: 202 | `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
203 | """ 204 | cos = cos[position_ids].unsqueeze(unsqueeze_dim) 205 | sin = sin[position_ids].unsqueeze(unsqueeze_dim) 206 | q_embed = (q * cos) + (rotate_half(q) * sin) 207 | k_embed = (k * cos) + (rotate_half(k) * sin) 208 | return q_embed, k_embed 209 | 210 | 211 | # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Phi 212 | class PhiMLP(nn.Module): 213 | def __init__(self, config): 214 | super().__init__() 215 | self.config = config 216 | self.activation_fn = ACT2FN[config.hidden_act] 217 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 218 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 219 | 220 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 221 | hidden_states = self.fc1(hidden_states) 222 | hidden_states = self.activation_fn(hidden_states) 223 | hidden_states = self.fc2(hidden_states) 224 | return hidden_states 225 | 226 | 227 | # Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi 228 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: 229 | """ 230 | This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, 231 | num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) 232 | """ 233 | batch, num_key_value_heads, slen, head_dim = hidden_states.shape 234 | if n_rep == 1: 235 | return hidden_states 236 | hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) 237 | return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) 238 | 239 | 240 | class PhiAttention(nn.Module): 241 | """Multi-headed attention from 'Attention Is All You Need' paper""" 242 | 243 | def __init__(self, config: PhiConfig, layer_idx: Optional[int] = None): 244 | super().__init__() 245 | self.config = config 246 | self.layer_idx = layer_idx 247 | if layer_idx is None: 248 | logger.warning_once( 249 | f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " 250 | "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " 251 | "when creating this class." 252 | ) 253 | 254 | self.attention_dropout = config.attention_dropout 255 | self.hidden_size = config.hidden_size 256 | self.num_heads = config.num_attention_heads 257 | self.head_dim = self.hidden_size // self.num_heads 258 | self.num_key_value_heads = config.num_key_value_heads 259 | self.num_key_value_groups = self.num_heads // self.num_key_value_heads 260 | self.max_position_embeddings = config.max_position_embeddings 261 | self.rope_theta = config.rope_theta 262 | self.partial_rotary_factor = config.partial_rotary_factor 263 | self.is_causal = True 264 | 265 | if (self.head_dim * self.num_heads) != self.hidden_size: 266 | raise ValueError( 267 | f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" 268 | f" and `num_heads`: {self.num_heads})." 
269 | ) 270 | 271 | self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True) 272 | self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) 273 | self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) 274 | self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=True) 275 | 276 | self.qk_layernorm = config.qk_layernorm 277 | if self.qk_layernorm: 278 | self.q_layernorm = nn.LayerNorm( 279 | config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True 280 | ) 281 | self.k_layernorm = nn.LayerNorm( 282 | config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True 283 | ) 284 | 285 | self._init_rope() 286 | 287 | def _init_rope(self): 288 | if self.config.rope_scaling is None: 289 | self.rotary_emb = PhiRotaryEmbedding( 290 | int(self.partial_rotary_factor * self.head_dim), 291 | max_position_embeddings=self.max_position_embeddings, 292 | base=self.rope_theta, 293 | ) 294 | else: 295 | scaling_type = self.config.rope_scaling["type"] 296 | scaling_factor = self.config.rope_scaling["factor"] 297 | if scaling_type == "linear": 298 | self.rotary_emb = PhiLinearScalingRotaryEmbedding( 299 | int(self.partial_rotary_factor * self.head_dim), 300 | max_position_embeddings=self.max_position_embeddings, 301 | scaling_factor=scaling_factor, 302 | base=self.rope_theta, 303 | ) 304 | elif scaling_type == "dynamic": 305 | self.rotary_emb = PhiDynamicNTKScalingRotaryEmbedding( 306 | int(self.partial_rotary_factor * self.head_dim), 307 | max_position_embeddings=self.max_position_embeddings, 308 | scaling_factor=scaling_factor, 309 | base=self.rope_theta, 310 | ) 311 | else: 312 | raise ValueError(f"Unknown RoPE scaling type {scaling_type}") 313 | 314 | # Phi-2 has an attention overflow issue (with FP16) and requires autocast to be disabled 315 | @torch.autocast("cpu", enabled=False) 316 | 
@torch.autocast("cuda", enabled=False) 317 | def forward( 318 | self, 319 | hidden_states: torch.Tensor, 320 | attention_mask: Optional[torch.Tensor] = None, 321 | position_ids: Optional[torch.LongTensor] = None, 322 | past_key_value: Optional[Cache] = None, 323 | output_attentions: bool = False, 324 | use_cache: bool = False, 325 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 326 | bsz, q_len, _ = hidden_states.size() 327 | 328 | query_states = self.q_proj(hidden_states) 329 | key_states = self.k_proj(hidden_states) 330 | value_states = self.v_proj(hidden_states) 331 | 332 | if self.qk_layernorm: 333 | query_states = self.q_layernorm(query_states) 334 | key_states = self.k_layernorm(key_states) 335 | 336 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 337 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 338 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 339 | 340 | kv_seq_len = key_states.shape[-2] 341 | if past_key_value is not None: 342 | if self.layer_idx is None: 343 | raise ValueError( 344 | f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " 345 | "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " 346 | "with a layer index." 
347 | ) 348 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 349 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 350 | 351 | # Partial rotary embedding 352 | query_rot, query_pass = ( 353 | query_states[..., : self.rotary_emb.dim], 354 | query_states[..., self.rotary_emb.dim :], 355 | ) 356 | key_rot, key_pass = ( 357 | key_states[..., : self.rotary_emb.dim], 358 | key_states[..., self.rotary_emb.dim :], 359 | ) 360 | # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] 361 | query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) 362 | 363 | # [batch_size, seq_length, num_heads, head_dim] 364 | query_states = torch.cat((query_rot, query_pass), dim=-1) 365 | key_states = torch.cat((key_rot, key_pass), dim=-1) 366 | 367 | if past_key_value is not None: 368 | cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} 369 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 370 | 371 | key_states = repeat_kv(key_states, self.num_key_value_groups) 372 | value_states = repeat_kv(value_states, self.num_key_value_groups) 373 | 374 | # Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow 375 | attn_weights = torch.matmul( 376 | query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) 377 | ) / math.sqrt(self.head_dim) 378 | 379 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 380 | raise ValueError( 381 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 382 | f" {attn_weights.size()}" 383 | ) 384 | 385 | if attention_mask is not None: 386 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 387 | raise ValueError( 388 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 389 | ) 390 | attn_weights = attn_weights + attention_mask 391 | 392 | # 
upcast attention to fp32 393 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) 394 | attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) 395 | 396 | attn_output = torch.matmul(attn_weights, value_states) 397 | 398 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 399 | raise ValueError( 400 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 401 | f" {attn_output.size()}" 402 | ) 403 | 404 | attn_output = attn_output.transpose(1, 2).contiguous() 405 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 406 | 407 | attn_output = self.dense(attn_output) 408 | 409 | if not output_attentions: 410 | attn_weights = None 411 | 412 | return attn_output, attn_weights, past_key_value 413 | 414 | 415 | class PhiFlashAttention2(PhiAttention): 416 | """ 417 | Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays 418 | untouched. The only required change would be on the forward pass where it needs to correctly call the public API of 419 | flash attention and deal with padding tokens in case the input contains any of them. 420 | """ 421 | 422 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ 423 | def __init__(self, *args, **kwargs): 424 | super().__init__(*args, **kwargs) 425 | 426 | # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 427 | # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 428 | # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
429 | self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() 430 | 431 | def forward( 432 | self, 433 | hidden_states: torch.Tensor, 434 | attention_mask: Optional[torch.LongTensor] = None, 435 | position_ids: Optional[torch.LongTensor] = None, 436 | past_key_value: Optional[Cache] = None, 437 | output_attentions: bool = False, 438 | use_cache: bool = False, 439 | **kwargs, 440 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 441 | # PhiFlashAttention2 attention does not support output_attentions 442 | 443 | output_attentions = False 444 | 445 | bsz, q_len, _ = hidden_states.size() 446 | 447 | query_states = self.q_proj(hidden_states) 448 | key_states = self.k_proj(hidden_states) 449 | value_states = self.v_proj(hidden_states) 450 | 451 | if self.qk_layernorm: 452 | query_states = self.q_layernorm(query_states) 453 | key_states = self.k_layernorm(key_states) 454 | 455 | # Flash attention requires the input to have the shape 456 | # batch_size x seq_length x head_dim x hidden_dim 457 | # therefore we just need to keep the original shape 458 | query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 459 | key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 460 | value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) 461 | 462 | kv_seq_len = key_states.shape[-2] 463 | if past_key_value is not None: 464 | kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) 465 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 466 | 467 | # Partial rotary embedding 468 | query_rot, query_pass = ( 469 | query_states[..., : self.rotary_emb.dim], 470 | query_states[..., self.rotary_emb.dim :], 471 | ) 472 | key_rot, key_pass = ( 473 | key_states[..., : self.rotary_emb.dim], 474 | key_states[..., self.rotary_emb.dim :], 475 | ) 476 | # [batch_size, seq_length, num_heads, 
head_dim // config.partial_rotary_factor] 477 | query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) 478 | 479 | # [batch_size, seq_length, num_heads, head_dim] 480 | query_states = torch.cat((query_rot, query_pass), dim=-1) 481 | key_states = torch.cat((key_rot, key_pass), dim=-1) 482 | 483 | if past_key_value is not None: 484 | cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} 485 | key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 486 | 487 | # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache 488 | # to be able to avoid many of these transpose/reshape/view. 489 | query_states = query_states.transpose(1, 2) 490 | key_states = key_states.transpose(1, 2) 491 | value_states = value_states.transpose(1, 2) 492 | 493 | attn_dropout = self.attention_dropout if self.training else 0.0 494 | 495 | # In PEFT, usually we cast the layer norms in float32 for training stability reasons 496 | # therefore the input hidden states gets silently casted in float32. Hence, we need 497 | # cast them back in the correct dtype just to be sure everything works as expected. 498 | # This might slowdown training & inference so it is recommended to not cast the LayerNorms 499 | # in fp32. 
500 | 501 | if query_states.dtype == torch.float32: 502 | if torch.is_autocast_enabled(): 503 | target_dtype = torch.get_autocast_gpu_dtype() 504 | # Handle the case where the model is quantized 505 | elif hasattr(self.config, "_pre_quantization_dtype"): 506 | target_dtype = self.config._pre_quantization_dtype 507 | else: 508 | target_dtype = self.q_proj.weight.dtype 509 | 510 | logger.warning_once( 511 | f"The input hidden states seems to be silently casted in float32, this might be related to" 512 | f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" 513 | f" {target_dtype}." 514 | ) 515 | 516 | query_states = query_states.to(target_dtype) 517 | key_states = key_states.to(target_dtype) 518 | value_states = value_states.to(target_dtype) 519 | 520 | attn_output = self._flash_attention_forward( 521 | query_states, key_states, value_states, attention_mask, q_len, dropout=attn_dropout, softmax_scale=None 522 | ) 523 | 524 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() 525 | attn_output = self.dense(attn_output) 526 | 527 | if not output_attentions: 528 | attn_weights = None 529 | 530 | return attn_output, attn_weights, past_key_value 531 | 532 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward 533 | def _flash_attention_forward( 534 | self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None 535 | ): 536 | """ 537 | Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token 538 | first unpad the input, then computes the attention scores and pad the final attention scores. 
539 | 540 | Args: 541 | query_states (`torch.Tensor`): 542 | Input query states to be passed to Flash Attention API 543 | key_states (`torch.Tensor`): 544 | Input key states to be passed to Flash Attention API 545 | value_states (`torch.Tensor`): 546 | Input value states to be passed to Flash Attention API 547 | attention_mask (`torch.Tensor`): 548 | The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the 549 | position of padding tokens and 1 for the position of non-padding tokens. 550 | dropout (`int`, *optional*): 551 | Attention dropout 552 | softmax_scale (`float`, *optional*): 553 | The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) 554 | """ 555 | if not self._flash_attn_uses_top_left_mask: 556 | causal = self.is_causal 557 | else: 558 | # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. 559 | causal = self.is_causal and query_length != 1 560 | 561 | # Contains at least one padding token in the sequence 562 | if attention_mask is not None: 563 | batch_size = query_states.shape[0] 564 | query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( 565 | query_states, key_states, value_states, attention_mask, query_length 566 | ) 567 | 568 | cu_seqlens_q, cu_seqlens_k = cu_seq_lens 569 | max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens 570 | 571 | attn_output_unpad = flash_attn_varlen_func( 572 | query_states, 573 | key_states, 574 | value_states, 575 | cu_seqlens_q=cu_seqlens_q, 576 | cu_seqlens_k=cu_seqlens_k, 577 | max_seqlen_q=max_seqlen_in_batch_q, 578 | max_seqlen_k=max_seqlen_in_batch_k, 579 | dropout_p=dropout, 580 | softmax_scale=softmax_scale, 581 | causal=causal, 582 | ) 583 | 584 | attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) 585 | else: 586 | attn_output = flash_attn_func( 587 | query_states, 
key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal 588 | ) 589 | 590 | return attn_output 591 | 592 | # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input 593 | def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): 594 | indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) 595 | batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape 596 | 597 | key_layer = index_first_axis( 598 | key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 599 | ) 600 | value_layer = index_first_axis( 601 | value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k 602 | ) 603 | if query_length == kv_seq_len: 604 | query_layer = index_first_axis( 605 | query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k 606 | ) 607 | cu_seqlens_q = cu_seqlens_k 608 | max_seqlen_in_batch_q = max_seqlen_in_batch_k 609 | indices_q = indices_k 610 | elif query_length == 1: 611 | max_seqlen_in_batch_q = 1 612 | cu_seqlens_q = torch.arange( 613 | batch_size + 1, dtype=torch.int32, device=query_layer.device 614 | ) # There is a memcpy here, that is very bad. 615 | indices_q = cu_seqlens_q[:-1] 616 | query_layer = query_layer.squeeze(1) 617 | else: 618 | # The -q_len: slice assumes left padding. 
619 | attention_mask = attention_mask[:, -query_length:] 620 | query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) 621 | 622 | return ( 623 | query_layer, 624 | key_layer, 625 | value_layer, 626 | indices_q, 627 | (cu_seqlens_q, cu_seqlens_k), 628 | (max_seqlen_in_batch_q, max_seqlen_in_batch_k), 629 | ) 630 | 631 | 632 | PHI_ATTENTION_CLASSES = { 633 | "eager": PhiAttention, 634 | "flash_attention_2": PhiFlashAttention2, 635 | } 636 | 637 | 638 | class PhiDecoderLayer(nn.Module): 639 | def __init__(self, config: PhiConfig, layer_idx: int): 640 | super().__init__() 641 | self.self_attn = PHI_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) 642 | self.mlp = PhiMLP(config) 643 | self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 644 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 645 | 646 | def forward( 647 | self, 648 | hidden_states: torch.Tensor, 649 | attention_mask: Optional[torch.Tensor] = None, 650 | position_ids: Optional[torch.LongTensor] = None, 651 | output_attentions: Optional[bool] = False, 652 | use_cache: Optional[bool] = False, 653 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 654 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 655 | """ 656 | Args: 657 | hidden_states (`torch.FloatTensor`): 658 | input to the layer of shape `(batch, seq_len, embed_dim)` 659 | attention_mask (`torch.FloatTensor`, *optional*): attention mask of size 660 | `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 661 | position_ids (`torch.LongTensor` of shape `({0})`, *optional*): 662 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range 663 | `[0, config.n_positions - 1]`. 
[What are position IDs?](../glossary#position-ids) 664 | output_attentions (`bool`, *optional*): 665 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 666 | returned tensors for more detail. 667 | use_cache (`bool`, *optional*): 668 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding 669 | (see `past_key_values`). 670 | past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states 671 | """ 672 | 673 | residual = hidden_states 674 | 675 | hidden_states = self.input_layernorm(hidden_states) 676 | 677 | # Self Attention 678 | attn_outputs, self_attn_weights, present_key_value = self.self_attn( 679 | hidden_states=hidden_states, 680 | attention_mask=attention_mask, 681 | position_ids=position_ids, 682 | past_key_value=past_key_value, 683 | output_attentions=output_attentions, 684 | use_cache=use_cache, 685 | ) 686 | attn_outputs = self.resid_dropout(attn_outputs) 687 | 688 | feed_forward_hidden_states = self.resid_dropout(self.mlp(hidden_states)) 689 | hidden_states = attn_outputs + feed_forward_hidden_states + residual 690 | outputs = (hidden_states,) 691 | 692 | if output_attentions: 693 | outputs += (self_attn_weights,) 694 | 695 | if use_cache: 696 | outputs += (present_key_value,) 697 | 698 | return outputs 699 | 700 | 701 | PHI_START_DOCSTRING = r""" 702 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 703 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 704 | etc.) 705 | 706 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 707 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 708 | and behavior. 
709 | 710 | Parameters: 711 | config ([`PhiConfig`]): 712 | Model configuration class with all the parameters of the model. Initializing with a config file does not 713 | load the weights associated with the model, only the configuration. Check out the 714 | [`~PreTrainedModel.from_pretrained`] method to load the model weights. 715 | """ 716 | 717 | 718 | @add_start_docstrings( 719 | "The bare Phi Model outputting raw hidden-states without any specific head on top.", 720 | PHI_START_DOCSTRING, 721 | ) 722 | class PhiPreTrainedModel(PreTrainedModel): 723 | config_class = PhiConfig 724 | base_model_prefix = "model" 725 | supports_gradient_checkpointing = True 726 | _no_split_modules = ["PhiDecoderLayer"] 727 | _skip_keys_device_placement = "past_key_values" 728 | _supports_flash_attn_2 = True 729 | _supports_cache_class = True 730 | 731 | def _init_weights(self, module): 732 | std = self.config.initializer_range 733 | if isinstance(module, nn.Linear): 734 | module.weight.data.normal_(mean=0.0, std=std) 735 | if module.bias is not None: 736 | module.bias.data.zero_() 737 | elif isinstance(module, nn.Embedding): 738 | module.weight.data.normal_(mean=0.0, std=std) 739 | if module.padding_idx is not None: 740 | module.weight.data[module.padding_idx].zero_() 741 | 742 | 743 | PHI_INPUTS_DOCSTRING = r""" 744 | Args: 745 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): 746 | Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide 747 | it. 748 | 749 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 750 | [`PreTrainedTokenizer.__call__`] for details. 751 | 752 | [What are input IDs?](../glossary#input-ids) 753 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 754 | Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: 755 | 756 | - 1 for tokens that are **not masked**, 757 | - 0 for tokens that are **masked**. 758 | 759 | [What are attention masks?](../glossary#attention-mask) 760 | 761 | Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and 762 | [`PreTrainedTokenizer.__call__`] for details. 763 | 764 | If `past_key_values` is used, optionally only the last `input_ids` have to be input (see 765 | `past_key_values`). 766 | 767 | If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] 768 | and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more 769 | information on the default strategy. 770 | 771 | - 1 indicates the head is **not masked**, 772 | - 0 indicates the head is **masked**. 773 | position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 774 | Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, 775 | config.n_positions - 1]`. 776 | 777 | [What are position IDs?](../glossary#position-ids) 778 | past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): 779 | Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention 780 | blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` 781 | returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 782 | 783 | Two formats are allowed: 784 | - a [`~cache_utils.Cache`] instance; 785 | - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of 786 | shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy 787 | cache format. 788 | 789 | The model will output the same cache format that is fed as input. 
If no `past_key_values` are passed, the 790 | legacy cache format will be returned. 791 | 792 | If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't 793 | have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` 794 | of shape `(batch_size, sequence_length)`. 795 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): 796 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This 797 | is useful if you want more control over how to convert `input_ids` indices into associated vectors than the 798 | model's internal embedding lookup matrix. 799 | use_cache (`bool`, *optional*): 800 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 801 | `past_key_values`). 802 | output_attentions (`bool`, *optional*): 803 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 804 | tensors for more detail. 805 | output_hidden_states (`bool`, *optional*): 806 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 807 | more detail. 808 | return_dict (`bool`, *optional*): 809 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 810 | """ 811 | 812 | 813 | @add_start_docstrings( 814 | "The bare Phi Model outputting raw hidden-states without any specific head on top.", 815 | PHI_START_DOCSTRING, 816 | ) 817 | class PhiModel(PhiPreTrainedModel): 818 | """ 819 | Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`PhiDecoderLayer`] 820 | 821 | Args: 822 | config: PhiConfig 823 | """ 824 | 825 | def __init__(self, config: PhiConfig): 826 | super().__init__(config) 827 | self.padding_idx = config.pad_token_id 828 | self.vocab_size = config.vocab_size 829 | 830 | self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) 831 | self.embed_dropout = nn.Dropout(config.embd_pdrop) 832 | self.layers = nn.ModuleList( 833 | [PhiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] 834 | ) 835 | self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 836 | self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" 837 | 838 | self.gradient_checkpointing = False 839 | # Initialize weights and apply final processing 840 | self.post_init() 841 | 842 | def get_input_embeddings(self): 843 | return self.embed_tokens 844 | 845 | def set_input_embeddings(self, value): 846 | self.embed_tokens = value 847 | 848 | @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING) 849 | def forward( 850 | self, 851 | input_ids: torch.LongTensor = None, 852 | attention_mask: Optional[torch.Tensor] = None, 853 | position_ids: Optional[torch.LongTensor] = None, 854 | past_key_values: Optional[List[torch.FloatTensor]] = None, 855 | inputs_embeds: Optional[torch.FloatTensor] = None, 856 | use_cache: Optional[bool] = None, 857 | output_attentions: Optional[bool] = None, 858 | output_hidden_states: Optional[bool] = None, 859 | return_dict: Optional[bool] = None, 860 | ) -> Union[Tuple, BaseModelOutputWithPast]: 861 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 862 | output_hidden_states = ( 863 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 864 | ) 865 | use_cache = use_cache if use_cache is not None else self.config.use_cache 866 | 867 | return_dict = return_dict if 
return_dict is not None else self.config.use_return_dict 868 | 869 | # retrieve input_ids and inputs_embeds 870 | if input_ids is not None and inputs_embeds is not None: 871 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 872 | elif input_ids is not None: 873 | batch_size, seq_length = input_ids.shape[:2] 874 | elif inputs_embeds is not None: 875 | batch_size, seq_length = inputs_embeds.shape[:2] 876 | else: 877 | raise ValueError("You have to specify either input_ids or inputs_embeds") 878 | 879 | past_key_values_length = 0 880 | 881 | if self.gradient_checkpointing and self.training: 882 | if use_cache: 883 | logger.warning_once( 884 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 885 | ) 886 | use_cache = False 887 | 888 | if use_cache: 889 | # dbg: uncomment is original 890 | use_legacy_cache = not isinstance(past_key_values, Cache) 891 | if use_legacy_cache: 892 | past_key_values = DynamicCache.from_legacy_cache(past_key_values) 893 | 894 | past_key_values_length = past_key_values.get_usable_length(seq_length) 895 | 896 | if position_ids is None: 897 | device = input_ids.device if input_ids is not None else inputs_embeds.device 898 | position_ids = torch.arange( 899 | past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device 900 | ) 901 | position_ids = position_ids.unsqueeze(0) 902 | 903 | if inputs_embeds is None: 904 | inputs_embeds = self.embed_tokens(input_ids) 905 | 906 | inputs_embeds = self.embed_dropout(inputs_embeds) 907 | 908 | # Attention mask. 
909 | if self._use_flash_attention_2: 910 | # 2d mask is passed through the layers 911 | attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 912 | else: 913 | # 4d mask is passed through the layers 914 | attention_mask = _prepare_4d_causal_attention_mask( 915 | attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 916 | ) 917 | 918 | hidden_states = inputs_embeds 919 | 920 | # decoder layers 921 | all_hidden_states = () if output_hidden_states else None 922 | all_self_attns = () if output_attentions else None 923 | next_decoder_cache = None 924 | 925 | for decoder_layer in self.layers: 926 | if output_hidden_states: 927 | all_hidden_states += (hidden_states,) 928 | 929 | if self.gradient_checkpointing and self.training: 930 | layer_outputs = self._gradient_checkpointing_func( 931 | decoder_layer.__call__, 932 | hidden_states, 933 | attention_mask, 934 | position_ids, 935 | past_key_values, 936 | output_attentions, 937 | ) 938 | else: 939 | layer_outputs = decoder_layer( 940 | hidden_states, 941 | attention_mask=attention_mask, 942 | position_ids=position_ids, 943 | past_key_value=past_key_values, 944 | output_attentions=output_attentions, 945 | use_cache=use_cache, 946 | ) 947 | 948 | hidden_states = layer_outputs[0] 949 | 950 | if use_cache: 951 | next_decoder_cache = layer_outputs[2 if output_attentions else 1] 952 | 953 | if output_attentions: 954 | all_self_attns += (layer_outputs[1],) 955 | 956 | hidden_states = self.final_layernorm(hidden_states) 957 | 958 | # add hidden states from the last decoder layer 959 | if output_hidden_states: 960 | all_hidden_states += (hidden_states,) 961 | 962 | next_cache = None 963 | if use_cache: 964 | next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache 965 | if not return_dict: 966 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 967 | return 
BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class PhiForCausalLM(PhiPreTrainedModel):
    # Phi decoder (`PhiModel`) with a linear language-modeling head (with bias) on top.
    # `lm_head.weight` is declared as a tied-weight key for the pretrained-model machinery.
    _tied_weights_keys = ["lm_head.weight"]

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi,bias=False->bias=True
    def __init__(self, config):
        super().__init__(config)
        self.model = PhiModel(config)
        self.vocab_size = config.vocab_size
        # Unlike the Llama original this was copied from, Phi's LM head carries a bias term.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings
    def get_input_embeddings(self):
        # Token embedding table of the wrapped decoder.
        return self.model.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings
    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings
    def get_output_embeddings(self):
        # The LM head doubles as the "output embeddings" for resizing / tying.
        return self.lm_head

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings
    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder
    def set_decoder(self, decoder):
        self.model = decoder

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder
    def get_decoder(self):
        return self.model

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, PhiForCausalLM

        >>> model = PhiForCausalLM.from_pretrained("microsoft/phi-1")
        >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")

        >>> prompt = "This is an example script ."
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        'This is an example script .\n\n\n\nfrom typing import List\n\ndef find_most_common_letter(words: List[str'
        ```"""

        # Fall back to the config defaults for any output flag the caller left as None.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # NOTE(review): the original author left the question "concat the feature back?" here; nothing is
        # concatenated -- the decoder's last hidden state is projected to logits as-is.

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        # Logits are upcast to float32 before the loss and before being returned.
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            # Tuple form: ([loss,] logits, *the decoder's remaining outputs).
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        # Build the kwargs for one `forward` call during generation: trim `input_ids` to the tokens the
        # cache has not seen yet, derive `position_ids` from the attention mask, and only feed
        # `inputs_embeds` on the first step (before any cache exists).
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                # Legacy tuple cache: sequence length is dim 2 of the first layer's key tensor.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Padded slots get a dummy position of 1; they are masked out anyway.
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM._reorder_cache
    def _reorder_cache(past_key_values, beam_idx):
        # Reorder every cached tensor of every layer along dim 0 so it follows the beams chosen by
        # `beam_idx` (moved to each tensor's device first). Returns a new tuple-of-tuples cache.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


@add_start_docstrings(
    """
    The PhiModel with a sequence classification head on top (linear layer).

    [`PhiForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    PHI_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->PHI,Llama->Phi with self.transformer->self.model, transformer_outputs->model_outputs
class PhiForSequenceClassification(PhiPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = PhiModel(config)
        # Bias-free linear scoring head applied to every position; one position per row is pooled in forward().
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        model_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = model_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # argmax finds the FIRST pad token per row; minus one is the last real token, which assumes
                # right padding. With no pad token argmax is 0, and (-1 % seq_len) wraps to the final index.
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        # One logit vector per row, taken at the selected position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Infer the problem type once (and store it on the config) from num_labels and the label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + model_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=model_outputs.past_key_values,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )


@add_start_docstrings(
    """
    PhiModel with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    PHI_START_DOCSTRING,
)
# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with MPT->PHI,Mpt->Phi,self.transformer->self.model,transformer_outputs->model_outputs
class PhiForTokenClassification(PhiPreTrainedModel):
    def __init__(self, config: PhiConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.model = PhiModel(config)
        # Dropout rate: config.classifier_dropout, else config.hidden_dropout, else 0.1.
        if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
            classifier_dropout = config.classifier_dropout
        elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
            classifier_dropout = config.hidden_dropout
        else:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(PHI_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **deprecated_arguments,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss, one label per token. Indices should be in `[0, ...,
            config.num_labels - 1]`. The loss is always a cross-entropy over every token position.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # `deprecated_arguments` is accepted for signature compatibility but not used.
        model_outputs = self.model(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = model_outputs[0]
        hidden_states = self.dropout(hidden_states)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            batch_size, seq_length = labels.shape
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
            )

        if not return_dict:
            # NOTE(review): `[2:]` assumes model_outputs[1] is the cache; if the decoder omits None entries from
            # its tuple output (e.g. use_cache=False) this offset may skip a different field -- verify.
            output = (logits,) + model_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=model_outputs.hidden_states,
            attentions=model_outputs.attentions,
        )
# --------------------------------------------------------------------------------
# /allava/model/llava_arch.py:
# --------------------------------------------------------------------------------
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from .multimodal_encoder.builder import build_vision_tower
from .multimodal_projector.builder import build_vision_projector

from allava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
import pdb


class LlavaMetaModel:
    """Mixin that attaches a vision tower and a multimodal projector to a language-model backbone."""

    def __init__(self, config):
        super(LlavaMetaModel, self).__init__(config)

        # Checkpoints saved with a vision tower configured rebuild it lazily (weights loaded later).
        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = build_vision_projector(config)

    def get_vision_tower(self):
        """Return the vision tower, unwrapping the one-element list used under FSDP (or None)."""
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        """Build/load the vision tower and projector from ``model_args`` and record their settings on the config.

        If ``model_args.pretrain_mm_mlp_adapter`` is set, projector weights are loaded from that file
        (keys containing ``mm_projector`` with the ``mm_projector.`` prefix stripped).
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)

            # Under FSDP the tower is hidden in a list so it is not registered as a submodule.
            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
            else:
                self.vision_tower = vision_tower
        else:
            if fsdp is not None and len(fsdp) > 0:
                vision_tower = self.vision_tower[0]
            else:
                vision_tower = self.vision_tower
            vision_tower.load_model()

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))


class LlavaMetaForCausalLM(ABC):
    """Mixin adding image encoding and image-token splicing to a causal LM."""

    @abstractmethod
    def get_model(self):
        pass

    @abstractmethod
    def get_tokenizer(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        """Run ``images`` through the vision tower, then the multimodal projector."""
        image_features = self.get_model().get_vision_tower()(images)
        image_features = self.get_model().mm_projector(image_features)
        return image_features

    def prepare_inputs_labels_for_multimodal_new(
        self, input_ids: list[torch.Tensor], position_ids, attention_mask: list[torch.Tensor], past_key_values, labels, images
    ):
        """Splice projected image features into the token embeddings at ``IMAGE_TOKEN_INDEX`` positions.

        Returns ``(input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels)``.

        At inference, when there is nothing to splice (no vision tower, no images, or a one-token step),
        the inputs are passed through with ``inputs_embeds=None``. For a cached decode step the attention
        mask and position ids are first extended to cover the cache.

        Otherwise padding is stripped with ``attention_mask``, image features replace each image token,
        sequences are truncated to ``config.tokenizer_model_max_length`` (if set), and the batch is
        re-padded on the side given by ``config.tokenizer_padding_side``. The result has
        ``input_ids=None`` and carries the padded embeddings.

        In eval mode encoded image features are cached on ``self.cached_image_features``. Later calls
        reuse them until that attribute is cleared. In training mode features are always recomputed and
        never cached, so every batch uses its own images.
        """
        vision_tower = self.get_vision_tower()
        if not self.training:  # TODO: check this out!!
            if vision_tower is None or images is None or input_ids.shape[1] == 1:
                if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1:
                    # Incremental decode step: the image was consumed in the prompt pass, so only the
                    # attention mask / position ids must grow to cover the cached keys plus this token.
                    if attention_mask is None:
                        # only happen for qwen at inference
                        return input_ids, None, attention_mask, past_key_values, None, labels

                    target_shape = past_key_values[-1][-1].shape[-2] + 1
                    attention_mask = torch.cat((attention_mask, torch.ones(
                        (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device
                    )), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    return input_ids, position_ids, attention_mask, past_key_values, None, labels
                # No vision tower / no images / single token without a cache: there is nothing to splice.
                # Falling through would call encode_images(None) and crash, so pass the inputs through.
                return input_ids, position_ids, attention_mask, past_key_values, None, labels

        # Reuse cached features only at inference; training must always encode the current batch.
        image_features = None if self.training else getattr(self, 'cached_image_features', None)
        if image_features is None:
            if type(images) is list or images.ndim == 5:
                # Several images per sample: encode them as one flat batch, then regroup per sample.
                concat_images = torch.cat(list(images), dim=0)
                image_features = self.encode_images(concat_images)
                split_sizes = [image.shape[0] for image in images]
                image_features = torch.split(image_features, split_sizes, dim=0)
                image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
            else:
                image_features = self.encode_images(images).to(self.device)
            if not self.training:
                # this attribute should be cleared in bot.clear_history()
                self.cached_image_features = image_features

        # TODO: image start / end is not implemented here to support pretraining.
        if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False):
            raise NotImplementedError

        # Remember which optional inputs the caller actually supplied; absent ones are returned as None.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = [torch.ones(len(ids), dtype=torch.bool, device=ids.device) for ids in input_ids]
        else:
            attention_mask = [att.bool() for att in attention_mask]

        if labels is None:
            labels = [torch.full((len(ids),), IGNORE_INDEX, dtype=ids.dtype, device=ids.device) for ids in input_ids]
        else:
            labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
        # Strip padding (attention_mask == 0) from every sequence; it is re-added after splicing.
        input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]

        new_input_embeds = []
        new_labels = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
            if num_images == 0:
                # FIXME: clamp instead of failing when there are fewer feature sets than samples.
                if cur_image_idx > len(image_features) - 1:
                    cur_image_idx = len(image_features) - 1
                    print(f'warning: {input_ids}')

                cur_image_features = image_features[cur_image_idx]
                cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
                # Appending a zero-length slice leaves the values unchanged; presumably it keeps the
                # projector in the autograd graph for text-only samples.
                cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
                new_input_embeds.append(cur_input_embeds)
                new_labels.append(labels[batch_idx])
                cur_image_idx += 1
                continue

            # Image-token positions, bracketed by -1 and len, delimit the N + 1 text segments.
            image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
                cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text segments in one call, then split back into segments.
            cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
            cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
            cur_new_input_embeds = []
            cur_new_labels = []

            # N images split the text into N + 1 segments: interleave segment, image, segment, ...
            for i in range(num_images + 1):
                cur_new_input_embeds.append(cur_input_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                if i < num_images:
                    # FIXME: clamp instead of failing when there are more image tokens than feature sets.
                    if cur_image_idx > len(image_features) - 1:
                        cur_image_idx = len(image_features) - 1
                        print(f'warning: {input_ids}')

                    cur_image_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_image_features)
                    # Image positions are never supervised.
                    cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

            cur_new_input_embeds = torch.cat(cur_new_input_embeds)
            cur_new_labels = torch.cat(cur_new_labels)

            new_input_embeds.append(cur_new_input_embeds)
            new_labels.append(cur_new_labels)

        # Truncate sequences to max length as image embeddings can make the sequence longer
        tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
        if tokenizer_model_max_length is not None:
            new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
            new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

        # Re-pad the variable-length sequences into one batch.
        max_len = max(x.shape[0] for x in new_input_embeds)
        batch_size = len(new_input_embeds)

        new_input_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
        mask_device = attention_mask[0].device
        attention_mask = torch.zeros((batch_size, max_len), dtype=torch.bool, device=mask_device)
        position_ids = torch.zeros((batch_size, max_len), dtype=torch.long, device=mask_device)
        left_padding = getattr(self.config, 'tokenizer_padding_side', 'right') == "left"

        for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
            cur_len = cur_new_embed.shape[0]
            padding = torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            if left_padding:
                new_input_embeds_padded.append(torch.cat((padding, cur_new_embed), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, -cur_len:] = cur_new_labels
                    attention_mask[i, -cur_len:] = True
                    position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
            else:
                new_input_embeds_padded.append(torch.cat((cur_new_embed, padding), dim=0))
                if cur_len > 0:
                    new_labels_padded[i, :cur_len] = cur_new_labels
                    attention_mask[i, :cur_len] = True
                    position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

        new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

        new_labels = None if _labels is None else new_labels_padded
        attention_mask = None if _attention_mask is None else attention_mask.to(dtype=torch.bool)
        if _position_ids is None:
            position_ids = None

        return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels

    def initialize_vision_tokenizer(self, model_args, tokenizer):
        """Add image-related special tokens to ``tokenizer`` and resize / initialise the embedding tables.

        New start/end-token rows are set to the mean of the existing rows. If a pretrained adapter is given,
        they are then overwritten from its ``model.embed_tokens.weight``.
        """
        if model_args.mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

        if model_args.mm_use_im_start_end:
            num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
            self.resize_token_embeddings(len(tokenizer))

            if num_new_tokens > 0:
                input_embeddings = self.get_input_embeddings().weight.data
                output_embeddings = self.get_output_embeddings().weight.data

                input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                    dim=0, keepdim=True)

                input_embeddings[-num_new_tokens:] = input_embeddings_avg
                output_embeddings[-num_new_tokens:] = output_embeddings_avg

            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = True
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False

            if model_args.pretrain_mm_mlp_adapter:
                # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.
                mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu')
                embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight']
                assert num_new_tokens == 2
                if input_embeddings.shape == embed_tokens_weight.shape:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
                elif embed_tokens_weight.shape[0] == num_new_tokens:
                    input_embeddings[-num_new_tokens:] = embed_tokens_weight
                else:
                    raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}.")
        elif model_args.mm_use_im_patch_token:
            if model_args.tune_mm_mlp_adapter:
                for p in self.get_input_embeddings().parameters():
                    p.requires_grad = False
                for p in self.get_output_embeddings().parameters():
                    p.requires_grad = False
# --------------------------------------------------------------------------------
# /allava/model/make_delta.py:
# --------------------------------------------------------------------------------
"""
Usage:
python3 -m allava.model.make_delta --base-model-path ~/model_weights/llama-7b --target-model-path ~/model_weights/llava-7b --delta-path ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
"""
import argparse

import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from allava.model.utils import auto_upgrade


def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
    """Save ``target - base`` weights (plus the target tokenizer) to ``delta_path``.

    Parameters only in the target must be the multimodal projector. Shape mismatches are only allowed
    for the (resized) token embeddings / LM head; for those only the overlapping rows/cols are subtracted.
    If ``hub_repo_id`` is given the delta is also pushed to that Hub repository.
    """
    print("Loading base model")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Loading target model")
    auto_upgrade(target_model_path)
    target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)

    print("Calculating delta")
    # state_dict() builds a fresh mapping on every call; build it once rather than per parameter.
    base_state = base.state_dict()
    for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
        if name not in base_state:
            assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
            continue
        bparam = base_state[name]
        if param.data.shape == bparam.shape:
            param.data -= bparam
        else:
            assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {bparam.shape}'
            param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam

    print("Saving delta")
    if hub_repo_id:
        kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
    else:
        kwargs = {}
    target.save_pretrained(delta_path, **kwargs)
    target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
    target_tokenizer.save_pretrained(delta_path, **kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-model-path", type=str, required=True)
    parser.add_argument("--target-model-path", type=str, required=True)
    parser.add_argument("--delta-path", type=str, required=True)
    parser.add_argument("--hub-repo-id", type=str, default=None)
    args = parser.parse_args()

    make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
# --------------------------------------------------------------------------------
# /allava/model/multimodal_encoder/builder.py:
# --------------------------------------------------------------------------------
import os
from .clip_encoder import CLIPVisionTower


def build_vision_tower(vision_tower_cfg, **kwargs):
    """Build the vision tower named by ``mm_vision_tower`` (falling back to ``vision_tower``) on the config.

    Accepts an existing local path or a name starting with ``openai`` / ``laion``. Anything else, including
    a missing name, raises ``ValueError``. Extra kwargs (e.g. ``delay_load``) go to ``CLIPVisionTower``.
    """
    vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
    if vision_tower is None:
        # os.path.exists(None) would raise an opaque TypeError; report it like any unknown tower.
        raise ValueError(f'Unknown vision tower: {vision_tower}')
    is_absolute_path_exists = os.path.exists(vision_tower)
    if is_absolute_path_exists or vision_tower.startswith(("openai", "laion")):
        return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)

    raise ValueError(f'Unknown vision tower: {vision_tower}')
# --------------------------------------------------------------------------------
# ------------------------------------------------------------------------------
# /allava/model/multimodal_encoder/clip_encoder.py
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn

from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from transformers import AutoModel


class CLIPVisionTower(nn.Module):
    """Frozen image encoder (CLIP, or InternViT via remote code) used as the vision tower.

    Args:
        vision_tower: checkpoint name or local path of the encoder.
        args: config object; reads `mm_vision_select_layer` and, optionally,
            `mm_vision_select_feature` (defaults to 'patch').
        delay_load: when True only the CLIP config is fetched now; the weights
            are loaded by a later explicit `load_model()` call.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            # keep just the config so `config` / `hidden_size` work before loading
            self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        """Load the image processor and encoder weights, then freeze the encoder.

        Raises:
            ValueError: if the checkpoint name matches neither 'clip' nor 'internvit'.
        """
        print(f'loading vision model from {self.vision_tower_name}')
        self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
        name = self.vision_tower_name.lower()
        if 'clip' in name:
            self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
        elif 'internvit' in name:
            self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, trust_remote_code=True)
        else:
            raise ValueError(
                f'Unsupported vision tower {self.vision_tower_name!r}: '
                f'please implement the loading of this vision encoder here'
            )

        self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Return the hidden states of `select_layer`, optionally without the first token.

        'patch' drops position 0 (presumably the CLS token), 'cls_patch' keeps
        every position; anything else raises ValueError.
        """
        image_features = image_forward_outs.hidden_states[self.select_layer]
        if self.select_feature == 'patch':
            return image_features[:, 1:]
        if self.select_feature == 'cls_patch':
            return image_features
        raise ValueError(f'Unexpected select feature: {self.select_feature}')

    @torch.no_grad()
    def forward(self, images):
        """Encode images without tracking gradients.

        `images` is either one batched tensor (features returned as a tensor) or
        a list/tuple of per-image tensors (features returned as a list, each
        entry carrying a leading batch dim of 1). Outputs are cast back to the
        dtype of the corresponding input.
        """
        if isinstance(images, (list, tuple)):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_features.append(self.feature_select(image_forward_out).to(image.dtype))
            return image_features

        image_forward_outs = self.vision_tower(
            images.to(device=self.device, dtype=self.dtype),
            output_hidden_states=True,
        )
        return self.feature_select(image_forward_outs).to(images.dtype)

    @property
    def dummy_feature(self):
        # a single all-zero feature row on the tower's device/dtype
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return self.vision_tower.dtype

    @property
    def device(self):
        return self.vision_tower.device

    @property
    def config(self):
        # before load_model() (delay_load=True) only the cached config exists
        if self.is_loaded:
            return self.vision_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        return self.config.hidden_size

    @property
    def num_patches(self):
        return (self.config.image_size // self.config.patch_size) ** 2


# ------------------------------------------------------------------------------
# /allava/model/multimodal_projector/builder.py
# ------------------------------------------------------------------------------
import torch
import torch.nn as nn
import re


class IdentityMap(nn.Module):
    """Projector that returns its input unchanged (extra arguments are ignored)."""

    def __init__(self):
        super().__init__()

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": 'identity'}


class SimpleResBlock(nn.Module):
    """LayerNorm followed by a residual Linear-GELU-Linear block of width `channels`."""

    def __init__(self, channels):
        super().__init__()
        self.pre_norm = nn.LayerNorm(channels)

        self.proj = nn.Sequential(
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels)
        )

    def forward(self, x):
        # the residual is taken from the *normalized* input
        x = self.pre_norm(x)
        return x + self.proj(x)
32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /allava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /allava/serve/cli.py: -------------------------------------------------------------------------------- 1 | from allava.constants import IMAGE_TOKEN_INDEX 2 | from allava.model import * 3 | 4 | from transformers import TextStreamer, AutoTokenizer, AutoModelForCausalLM, AutoConfig 5 | import os 6 | import torch 7 | 8 | from PIL import Image 9 | import pdb 10 | 11 | KEYWORDS_IN_PATH = ['allava-3b', 'allava-3b-longer', 'phi'] 12 | 13 | 14 | class Chatbot(): 15 | def __init__(self, config): 16 | self.config = config 17 | 18 | 19 | self.gen_kwargs = { 20 | 'do_sample': False, 21 | 'max_new_tokens': 768, 22 | 'min_new_tokens': 1, 23 | } 24 | 25 | self.device = getattr(config, 'device', 'cuda') 26 | self.init_components() 27 | 28 | self.history = [] 29 | self.images = [] 30 | 31 | # although we support multiple image inputs at inference, this feature is NOT trained. Therefore, inputing multiple images may cause a degraded model performance. 
32 | self.max_images_per_round = getattr(config, 'max_images_per_round', 3) 33 | 34 | def init_components(self): 35 | d = self.config.model_dir 36 | 37 | 38 | if any([name in d.lower() for name in KEYWORDS_IN_PATH]): 39 | print(f'loading from {self.config.model_dir}') 40 | model, loading_info = LlavaPhiForCausalLM.from_pretrained(self.config.model_dir, init_vision_encoder_from_ckpt=True, output_loading_info=True, trust_remote_code=True) 41 | 42 | missing_keys = loading_info['missing_keys'] # keys exists in model architecture but does not exist in ckpt 43 | unexpected_keys = loading_info['unexpected_keys'] # keys exists in ckpt but are not loaded by the model 44 | assert missing_keys == [] and unexpected_keys == [] # both should be empty 45 | 46 | self.maxlen = getattr(self.config, 'maxlen', model.config.max_position_embeddings) 47 | tokenizer = AutoTokenizer.from_pretrained(self.config.model_dir, model_max_length=self.maxlen, trust_remote_code=True) 48 | vision_tower = model.get_vision_tower() 49 | 50 | if not vision_tower.is_loaded: 51 | vision_tower.load_model() 52 | vision_tower.to(device=self.device).half() 53 | image_processor = vision_tower.image_processor 54 | eos_token_id = tokenizer.eos_token_id 55 | tokenizer.pad_token_id = tokenizer.eos_token_id 56 | self.gen_kwargs['eos_token_id'] = tokenizer.eos_token_id # new features in transformers 4.37, where you need to explicitly pass the eos_token_id as a param 57 | self.gen_kwargs['pad_token_id'] = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id 58 | print(f'setting eos_token_id to {eos_token_id}') 59 | 60 | 61 | else: 62 | print(f'please load your model properly.') 63 | raise NotImplementedError 64 | 65 | model.eval() 66 | self.model = model.half().to(self.device) 67 | self.tokenizer = tokenizer 68 | self.processor = image_processor 69 | 70 | 71 | def clear_history(self,): 72 | self.images = [] 73 | self.history = [] 74 | self.model.cached_image_features = None 75 | 76 | 77 | # 
copied from llava 78 | def tokenizer_image_token(self, prompt, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 79 | prompt_chunks = [self.tokenizer(chunk, add_special_tokens=False).input_ids for chunk in prompt.split('')] 80 | 81 | def insert_separator(X, sep): 82 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 83 | 84 | input_ids = [] 85 | offset = 0 86 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == self.tokenizer.bos_token_id: 87 | offset = 1 88 | input_ids.append(prompt_chunks[0][0]) 89 | 90 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 91 | input_ids.extend(x[offset:]) 92 | 93 | if return_tensors is not None: 94 | if return_tensors == 'pt': 95 | return torch.tensor(input_ids, dtype=torch.long) 96 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 97 | return input_ids 98 | 99 | 100 | def preprocess(self, data: list, return_tensors='pt'): 101 | ''' 102 | [ 103 | { 104 | 'from': 'human', 105 | 'value': xxx, 106 | }, 107 | { 108 | 'from': 'gpt', 109 | 'value': xxx 110 | } 111 | ] 112 | ''' 113 | # needs update 114 | if not isinstance(data, list): 115 | raise ValueError('must be a list') 116 | 117 | d = self.config.model_dir 118 | 119 | # this is per model (tokenizer) 120 | if any([name in d.lower() for name in KEYWORDS_IN_PATH]): 121 | return self.preprocess_allava(data, return_tensors=return_tensors) 122 | 123 | elif d in ['/path/to/llava-v1.5-13b']: 124 | return self.preprocess_vicuna_v1(data, return_tensors=return_tensors) 125 | 126 | else: 127 | raise NotImplementedError 128 | 129 | 130 | 131 | def preprocess_vicuna_v1(self, convs: list, return_tensors) -> list: # tokenize and concat the coversations 132 | input_ids = None 133 | for ind, conv in enumerate(convs): 134 | if ind % 2 == 0: # human 135 | h = conv['value'].strip() 136 | h = f"USER: {h} " 137 | cur_input_ids = self.tokenizer_image_token(prompt=h, return_tensors=return_tensors) 
138 | 139 | if input_ids is None: 140 | input_ids = cur_input_ids 141 | else: 142 | input_ids = torch.cat([input_ids, cur_input_ids]) 143 | 144 | else: # gpt 145 | g = conv['value'] 146 | if g is not None: 147 | cur_input_ids = self.tokenizer(f"ASSISTANT: {g}", add_special_tokens= False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0] 148 | input_ids = torch.cat([input_ids, cur_input_ids]) 149 | else: 150 | cur_input_ids = self.tokenizer(f"ASSISTANT:", add_special_tokens= False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0] 151 | input_ids = torch.cat([input_ids, cur_input_ids]) 152 | 153 | 154 | return input_ids 155 | 156 | def preprocess_allava(self, convs: list, return_tensors) -> list: # tokenize and concat the coversations 157 | input_ids = None 158 | 159 | for ind, conv in enumerate(convs): 160 | if ind % 2 == 0: # human 161 | h = conv['value'].strip() 162 | h = f"[INST] {h} [/INST] " 163 | cur_input_ids = self.tokenizer_image_token(prompt=h, return_tensors=return_tensors) 164 | 165 | if input_ids is None: 166 | input_ids = cur_input_ids 167 | else: 168 | input_ids = torch.cat([input_ids, cur_input_ids]) 169 | 170 | else: # gpt 171 | g = conv['value'] 172 | if g is not None: 173 | cur_input_ids = self.tokenizer(f"{g}{self.tokenizer.eos_token}", add_special_tokens= False, max_length=self.maxlen, truncation=True, return_tensors='pt').input_ids[0] 174 | input_ids = torch.cat([input_ids, cur_input_ids]) 175 | 176 | return input_ids 177 | 178 | 179 | def input_moderation(self, t: str): 180 | 181 | blacklist = ['', '', ''] 182 | for b in blacklist: 183 | t = t.replace(b, '') 184 | return t 185 | 186 | def insert_image_placeholder(self, t, num_images, placeholder='', sep='\n'): 187 | for _ in range(num_images): 188 | t = f"{placeholder}{sep}" + t 189 | 190 | return t 191 | 192 | def get_conv(self, text): 193 | ret = [] 194 | if self.history is None: 195 | self.history = [] 196 | 197 | for conv in self.history: 
198 | ret.append({'from': 'human', 'value': conv[0]}) 199 | ret.append({'from': 'gpt', 'value': conv[1]}) 200 | 201 | ret.append({'from': 'human', 'value': text}) 202 | ret.append({'from': 'gpt', 'value': None}) 203 | return ret 204 | 205 | # copied from llava 206 | def get_image_tensors(self, images): 207 | list_image_tensors = [] 208 | crop_size = self.processor.crop_size 209 | processor = self.processor 210 | for fp in images: 211 | if fp is None: # None is used as a placeholder 212 | list_image_tensors.append(torch.zeros(3, crop_size['height'], crop_size['width']).to(self.device)) 213 | continue 214 | elif isinstance(fp, str): 215 | image = Image.open(fp).convert('RGB') 216 | elif isinstance(fp, Image.Image): 217 | image = fp # already an image 218 | else: 219 | raise TypeError(f'Unsupported type {type(fp)}') 220 | 221 | # this is the way of preprocessing images we used in training, so we impose it here 222 | if True: 223 | # self.data_args.image_aspect_ratio == 'pad' 224 | def expand2square(pil_img, background_color): 225 | width, height = pil_img.size 226 | if pil_img.mode == 'L': 227 | pil_img = pil_img.convert('RGB') 228 | 229 | if width == height: 230 | return pil_img 231 | elif width > height: 232 | result = Image.new(pil_img.mode, (width, width), background_color) 233 | result.paste(pil_img, (0, (width - height) // 2)) 234 | return result 235 | else: 236 | result = Image.new(pil_img.mode, (height, height), background_color) 237 | result.paste(pil_img, ((height - width) // 2, 0)) 238 | return result 239 | 240 | image = expand2square(image, tuple(int(x*255) for x in processor.image_mean)) 241 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 242 | else: 243 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] # a tensor 244 | list_image_tensors.append(image.to(self.device)) 245 | return list_image_tensors 246 | 247 | 248 | def chat(self, text: str, images: list[str]=None, ): 249 | ''' 250 | images: 
list[str], images for the *current* round 251 | text: text input for the *current* round 252 | ''' 253 | 254 | ############################ 255 | # 1. preprocess texts 256 | ############################ 257 | text = self.input_moderation(text) 258 | if text == '': 259 | return 'Please type in something' 260 | 261 | if isinstance(images, str) or isinstance(images, Image.Image): 262 | images = [images] 263 | 264 | 265 | ############################ 266 | # 2. preprocess images 267 | ############################ 268 | valid_images = [] 269 | if images is None: 270 | images = [None] 271 | 272 | for img in images: 273 | try: 274 | if isinstance(img, str): 275 | Image.open(img).convert('RGB') # make sure that the path exists 276 | valid_images.append(img) 277 | except: 278 | continue 279 | 280 | images = valid_images 281 | 282 | if images == [] and self.images == []: 283 | self.images = [None] 284 | 285 | self.images.extend(images) 286 | 287 | assert len(images) < self.max_images_per_round, f'at most {self.max_images_per_round} images' 288 | 289 | ############################ 290 | # 3. collate conv 291 | ############################ 292 | 293 | # insert 294 | text = self.insert_image_placeholder(text, len(images) if None not in images else 0) 295 | 296 | # collate strings into conv 297 | conv = self.get_conv(text) 298 | 299 | # make input ids 300 | input_ids = self.preprocess(conv, return_tensors='pt').unsqueeze(0).to(self.device) 301 | 302 | list_image_tensors = self.get_image_tensors(self.images) 303 | image_tensors = torch.stack(list_image_tensors) 304 | 305 | try: 306 | dtype = torch.bfloat16 307 | # if your hardware does not support bf16, the following line raises an error 308 | torch.tensor(1, dtype=dtype).cuda() 309 | except: 310 | # default using fp16 311 | dtype = torch.float16 312 | 313 | ############################ 314 | # 4. 
generate response 315 | ############################ 316 | with torch.autocast(device_type='cuda', dtype=dtype): 317 | output_ids = self.model.generate( 318 | inputs=input_ids, 319 | images=image_tensors, 320 | use_cache=getattr(self, 'use_cache', True), 321 | **self.gen_kwargs) 322 | 323 | answer = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip() 324 | 325 | # update history 326 | self.history.append([text, answer]) 327 | return answer 328 | 329 | 330 | 331 | 332 | if __name__ =="__main__": 333 | 334 | import argparse 335 | parser = argparse.ArgumentParser(description='Args of Data Preprocess') 336 | 337 | # Model Args 338 | parser.add_argument('--model_dir', default='', type=str) 339 | parser.add_argument('--max_images_per_round', default=4, type=int) 340 | parser.add_argument('--maxlen', default=3500, type=int) 341 | parser.add_argument('--device', default='cuda:0', type=str) 342 | args = parser.parse_args() 343 | 344 | bot = Chatbot(args) 345 | 346 | image_prompt = 'image pth, (split by "," for multiple images): ' 347 | 348 | images = input(image_prompt) 349 | images = [i.strip() for i in images.split(',')] 350 | while True: 351 | text = input('USER ("clear" to clear history, "q" to exit): ') 352 | if text.lower() in ['q', 'quit']: 353 | exit() 354 | if text.lower() == 'clear': 355 | bot.clear_history() 356 | images = input(image_prompt) 357 | images = [i.strip() for i in images.split(',')] 358 | continue 359 | answer = bot.chat(images=images, text=text) 360 | images = None # already in the history 361 | print() 362 | print(f'GPT: {answer}') 363 | print() -------------------------------------------------------------------------------- /allava/serve/run_inference_allava-phi2.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM 2 | from transformers import AutoTokenizer 3 | import torch 4 | import pdb 5 | 6 | dir = 
"FreedomIntelligence/ALLaVA-Phi2-2_7B" 7 | 8 | device = 'cuda' 9 | model = AutoModelForCausalLM.from_pretrained(dir, trust_remote_code=True, device_map=device, torch_dtype=torch.bfloat16) 10 | tokenizer = AutoTokenizer.from_pretrained(dir) 11 | model.tokenizer = tokenizer 12 | 13 | gen_kwargs = { 14 | 'min_new_tokens': 20, 15 | 'max_new_tokens': 100, 16 | 'do_sample': False, 17 | 'eos_token_id': tokenizer.eos_token_id, 18 | } 19 | 20 | ################################################################################# 21 | # first round 22 | ################################################################################# 23 | response, history = model.chat( 24 | texts='What is in the image?', 25 | images=['https://cdn-icons-png.flaticon.com/256/6028/6028690.png'], 26 | return_history=True, 27 | **gen_kwargs 28 | ) 29 | print('response:') 30 | print(response) 31 | print() 32 | print('history:') 33 | print(history) 34 | 35 | ''' 36 | response: 37 | The image contains a large, stylized "HI!" in a bright pink color with yellow outlines. The "HI!" is placed within a speech bubble shape. 38 | 39 | history: 40 | [['What is in the image?', 'The image contains a large, stylized "HI!" in a bright pink color with yellow outlines. The "HI!" is placed within a speech bubble shape.']] 41 | ''' 42 | 43 | 44 | ################################################################################# 45 | # second round 46 | ################################################################################# 47 | response, history = model.chat( 48 | texts='Are you sure?', 49 | images=['https://cdn-icons-png.flaticon.com/256/6028/6028690.png'], # images need to be passed again in multi-round conversations 50 | history=history, 51 | return_history=True, 52 | **gen_kwargs 53 | ) 54 | 55 | print('response:') 56 | print(response) 57 | print() 58 | print('history:') 59 | print(history) 60 | 61 | ''' 62 | response: 63 | Yes, I'm certain. The image is a graphic representation of the word "HI!" 
in a speech bubble. 64 | 65 | history: 66 | [['What is in the image?', 'The image contains a large, stylized "HI!" in a bright pink color with yellow outlines. The "HI!" is placed within a speech bubble shape.'], ['Are you sure?', 'Yes, I\'m certain. The image is a graphic representation of the word "HI!" in a speech bubble.']] 67 | ''' -------------------------------------------------------------------------------- /allava/serve/run_inference_allava-phi3.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM 2 | from transformers import AutoTokenizer 3 | import torch 4 | import pdb 5 | 6 | dir = "FreedomIntelligence/ALLaVA-Phi3-mini-128k" 7 | 8 | device = 'cuda' 9 | model = AutoModelForCausalLM.from_pretrained(dir, trust_remote_code=True, device_map=device, torch_dtype=torch.bfloat16) 10 | tokenizer = AutoTokenizer.from_pretrained(dir) 11 | model.tokenizer = tokenizer 12 | 13 | gen_kwargs = { 14 | 'min_new_tokens': 20, 15 | 'max_new_tokens': 100, 16 | 'do_sample': False, 17 | # eos_token_id is not needed for this model 18 | } 19 | 20 | ################################################################################# 21 | # first round 22 | ################################################################################# 23 | response, history = model.chat( 24 | texts='What is in the image?', 25 | images=['https://cdn-icons-png.flaticon.com/256/6028/6028690.png'], 26 | return_history=True, 27 | **gen_kwargs 28 | ) 29 | print('response:') 30 | print(response) 31 | print() 32 | print('history:') 33 | print(history) 34 | 35 | ''' 36 | response: 37 | - There is a speech bubble in the image. 38 | - The speech bubble contains the word "HI!" in bold, yellow letters. 39 | 40 | history: 41 | [['What is in the image?', '- There is a speech bubble in the image.\n- The speech bubble contains the word "HI!" 
in bold, yellow letters.']] 42 | ''' 43 | 44 | 45 | ################################################################################# 46 | # second round 47 | ################################################################################# 48 | response, history = model.chat( 49 | texts='Are you sure?', 50 | images=['https://cdn-icons-png.flaticon.com/256/6028/6028690.png'], # images need to be passed again in multi-round conversations 51 | history=history, 52 | return_history=True, 53 | **gen_kwargs 54 | ) 55 | 56 | print('response:') 57 | print(response) 58 | print() 59 | print('history:') 60 | print(history) 61 | 62 | ''' 63 | response: 64 | - Yes, I am certain. The image prominently features a speech bubble with the word "HI!" inside it. 65 | 66 | history: 67 | [['What is in the image?', '- There is a speech bubble in the image.\n- The speech bubble contains the word "HI!" in bold, yellow letters.'], ['Are you sure?', '- Yes, I am certain. The image prominently features a speech bubble with the word "HI!" 
inside it.']] 68 | ''' -------------------------------------------------------------------------------- /allava/serve/run_inference_allava-stablelm2.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM 2 | from transformers import AutoTokenizer 3 | import torch 4 | import pdb 5 | 6 | 7 | dir = "FreedomIntelligence/ALLaVA-StableLM2-1_6B" 8 | 9 | device = 'cuda' 10 | model = AutoModelForCausalLM.from_pretrained(dir, trust_remote_code=True, device_map=device, torch_dtype=torch.bfloat16) 11 | tokenizer = AutoTokenizer.from_pretrained(dir, trust_remote_code=True) 12 | model.tokenizer = tokenizer 13 | 14 | gen_kwargs = { 15 | 'min_new_tokens': 20, 16 | 'max_new_tokens': 100, 17 | 'do_sample': False, 18 | 'eos_token_id': tokenizer.eos_token_id, 19 | } 20 | 21 | ################################################################################# 22 | # first round 23 | ################################################################################# 24 | response, history = model.chat( 25 | texts='What is in the image?', 26 | images=['https://cdn-icons-png.flaticon.com/256/6028/6028690.png'], 27 | return_history=True, 28 | **gen_kwargs 29 | ) 30 | print('response:') 31 | print(response) 32 | print() 33 | print('history:') 34 | print(history) 35 | 36 | ''' 37 | response: 38 | The image contains a graphic design of a speech bubble with the word "HI!" written inside it. The speech bubble is colored in pink and has yellow outlines. 39 | 40 | history: 41 | [['What is in the image?', 'The image contains a graphic design of a speech bubble with the word "HI!" written inside it. 
The speech bubble is colored in pink and has yellow outlines.']] 42 | ''' 43 | 44 | 45 | ################################################################################# 46 | # second round 47 | ################################################################################# 48 | response, history = model.chat( 49 | texts='Are you sure?', 50 | images=['https://cdn-icons-png.flaticon.com/256/6028/6028690.png'], # images need to be passed again in multi-round conversations 51 | history=history, 52 | return_history=True, 53 | **gen_kwargs 54 | ) 55 | 56 | print('response:') 57 | print(response) 58 | print() 59 | print('history:') 60 | print(history) 61 | 62 | ''' 63 | response: 64 | Yes, I am certain. The image displays a graphic design of a speech bubble with the word "HI!" written inside it. The speech bubble is colored in pink and has yellow outlines. 65 | 66 | history: 67 | [['What is in the image?', 'The image contains a graphic design of a speech bubble with the word "HI!" written inside it. The speech bubble is colored in pink and has yellow outlines.'], ['Are you sure?', 'Yes, I am certain. The image displays a graphic design of a speech bubble with the word "HI!" written inside it. 
The speech bubble is colored in pink and has yellow outlines.']] 68 | ''' -------------------------------------------------------------------------------- /assets/llavas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/ALLaVA/896b77f0ac48c95031cf31ce1b95ccaa380e68bf/assets/llavas.png -------------------------------------------------------------------------------- /assets/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/ALLaVA/896b77f0ac48c95031cf31ce1b95ccaa380e68bf/assets/pipeline.jpg -------------------------------------------------------------------------------- /assets/pipeline.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/ALLaVA/896b77f0ac48c95031cf31ce1b95ccaa380e68bf/assets/pipeline.pdf -------------------------------------------------------------------------------- /assets/training_datasets_by_stage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FreedomIntelligence/ALLaVA/896b77f0ac48c95031cf31ce1b95ccaa380e68bf/assets/training_datasets_by_stage.jpg -------------------------------------------------------------------------------- /download/download_laion.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | laion_root="allava_laion" 4 | 5 | mkdir $laion_root 6 | cd $laion_root 7 | 8 | 9 | # 1. 
## 1.1 caption
wget -c -O ALLaVA-Caption-LAION-4V.json "https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Caption-LAION-4V.json?download=true"

## 1.2 instruction
wget -c -O ALLaVA-Instruct-LAION-4V.json "https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/ALLaVA-Instruct-LAION-4V.json?download=true"


# 2. download and unzip images
mkdir -p image_chunks

## 2.1 download (the 10 chunks are fetched in parallel)
for ((i=0; i<10; i++))
do
    wget -c -O "image_chunks/images_$i.zip" "https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_laion/image_chunks/images_$i.zip?download=true" &
done
# every archive must be complete before unzipping starts
wait

## 2.2 unzip (in parallel)
for ((i=0; i<10; i++))
do
    unzip -j "image_chunks/images_$i.zip" -d images/ &  # wait patiently, it takes a while...
done
# do not return before all archives are extracted
wait




# for ((i=1; i<3; i++))
# do
#     unzip -j i$i.zip -d i/ & # wait patiently, it takes a while...
# done


# ------------------------------------------------------------------------------
# /download/download_text.sh
# ------------------------------------------------------------------------------
text_root="allava_text"

mkdir -p "$text_root"
cd "$text_root" || exit 1

wget -c -O Evol-Instruct-GPT4-Turbo-143K.json "https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_text/Evol-Instruct-GPT4-Turbo-143K.json?download=true"


# ------------------------------------------------------------------------------
# /download/download_vflan.sh
# ------------------------------------------------------------------------------
vflan_root="allava_vflan"

mkdir -p "$vflan_root"
cd "$vflan_root" || exit 1

# 1. download annotation files
download annotation files 7 | ## 1.1 caption 8 | wget -c -O ALLaVA-Caption-VFLAN-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_vflan/ALLaVA-Caption-VFLAN-4V.json?download=true 9 | 10 | ## 1.2 instruction 11 | wget -c -O ALLaVA-Instruct-VFLAN-4V.json https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/allava_vflan/ALLaVA-Instruct-VFLAN-4V.json?download=true 12 | 13 | 14 | # 2. download and upzip images 15 | mkdir images 16 | cd images 17 | 18 | wget -c -O "image_191-task_1k.zip" "https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k/resolve/main/image_191-task_1k.zip?download=true" 19 | 20 | unzip image_191-task_1k.zip 21 | -------------------------------------------------------------------------------- /download/legacy/laion/download_images_from_url.py: -------------------------------------------------------------------------------- 1 | import json 2 | import requests 3 | import os 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from tqdm import tqdm 6 | from multiprocessing import Pool 7 | 8 | 9 | ############### INPUT and OUTPUT path ############### 10 | hf_laion_caption_path = '/path/to/ALLaVA-Caption-LAION-4V.json' 11 | laion_caption_output_path = '/path/to/ALLaVA-Caption-LAION-4V_with_image.json' 12 | 13 | hf_laion_inst_path = '/path/to/ALLaVA-Instruct-LAION-4V.json' # 14 | laion_inst_output_path = '/path/to/ALLaVA-Instruct-LAION-4V_with_image.json' 15 | 16 | image_dir = '/path/to/image_dir' 17 | ############### INPUT and OUTPUT path ############### 18 | 19 | 20 | 21 | 22 | 23 | def download_single_image(line): 24 | try: 25 | url = line['url'] 26 | image_path = os.path.join(args.image_dir, f'allava_laion_{line["id"].split("_")[-1]}') 27 | # allava_laion_0, allava_laion_1, allava_laion_2, ... 28 | # note that they are saved as binary files. 
29 | # each file can be loaded with Image.open() 30 | 31 | if os.path.exists(image_path): 32 | line['image'] = image_path 33 | return line 34 | 35 | response = requests.get(url, timeout=60) 36 | 37 | if response.status_code == 200: 38 | # save as a binary file 39 | with open(image_path, 'wb') as file: 40 | file.write(response.content) 41 | line['image'] = image_path 42 | return line 43 | else: 44 | return None 45 | 46 | except Exception as e: 47 | # remove the binary image file 48 | if os.path.exists(image_path): 49 | os.remove(image_path) 50 | return None 51 | 52 | 53 | 54 | if __name__ == '__main__': 55 | import argparse 56 | parser = argparse.ArgumentParser() 57 | 58 | parser.add_argument('--image_dir', default='', type=str) 59 | 60 | parser.add_argument('--hf_laion_caption_path', required=True) 61 | parser.add_argument('--laion_caption_output_path', required=True) 62 | 63 | parser.add_argument('--hf_laion_inst_path', required=True) 64 | parser.add_argument('--laion_inst_output_path', required=True) 65 | 66 | parser.add_argument('--num_processes', default=200, type=int) 67 | 68 | args = parser.parse_args() 69 | 70 | os.makedirs(args.image_dir, exist_ok=True) 71 | 72 | 73 | for input_path, output_path in ( 74 | [args.hf_laion_caption_path, args.laion_caption_output_path], # this step takes long time to run. The code supports continual download so you can interupt and rerun at anytime. 
75 | [args.hf_laion_inst_path, args.laion_inst_output_path] # this step takes little time to run since it shares the same set of images with caption 76 | ): 77 | 78 | with open(input_path) as f: 79 | data = json.load(f) 80 | 81 | with Pool(processes=args.num_processes) as pool: 82 | results = list(tqdm(pool.imap_unordered(download_single_image, data), total=len(data))) 83 | 84 | # filter None 85 | results = [da for da in results if da is not None] 86 | 87 | print('downloaded image:', len(results)) 88 | 89 | # save 90 | os.path.makedirs(os.path.dirname(output_path), exist_ok=True) 91 | with open(output_path, 'w') as fw: 92 | json.dump(results, fw, ensure_ascii=False, indent=2) 93 | -------------------------------------------------------------------------------- /download/legacy/laion/download_laion_from_url.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | ################################################################################ 4 | hf_cap_ann_path="" # path to store hf caption annotation file 5 | hf_inst_ann_path="" # path to store hf instruction annotation file 6 | 7 | image_dir="" # directory to store images 8 | cap_ann_with_image_path="" # path to store new *caption* annotation files with local image path 9 | inst_ann_with_image_path="" # path to store new *instruction* annotation files with local image path 10 | ################################################################################ 11 | 12 | 13 | # 0. check file path 14 | if [ "$hf_cap_ann_path" = "$cap_ann_with_image_path" ]; then 15 | echo "Input and output path are equal, exiting..." 16 | return 1 2>/dev/null 17 | fi 18 | 19 | if [ "$hf_inst_ann_path" = "$inst_ann_with_image_path" ]; then 20 | echo "Input and output path are equal, exiting..." 21 | return 1 2>/dev/null 22 | fi 23 | 24 | 25 | # 1. 
download annotation files from huggingface 26 | ## 1.1 caption 27 | wget -c -O $hf_cap_ann_path https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/ALLaVA-Caption-LAION-4V.json?download=true 28 | 29 | ## 1.2 instruction 30 | wget -c -O $hf_inst_ann_path https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V/resolve/main/ALLaVA-Instruct-LAION-4V.json?download=true 31 | 32 | 33 | # 2. download images from url 34 | python ./download/laion/download_images_from_url.py \ 35 | --hf_laion_caption_path $hf_cap_ann_path \ 36 | --laion_caption_output_path $cap_ann_with_image_path \ 37 | --hf_laion_inst_path $hf_inst_ann_path \ 38 | --laion_inst_output_path $inst_ann_with_image_path \ 39 | --image_dir $image_dir \ 40 | --num_processes 200 41 | 42 | -------------------------------------------------------------------------------- /prompts/instructions_for_captions.txt: -------------------------------------------------------------------------------- 1 | Write a comprehensive caption for the image provided. 2 | Can you help me understand the image by providing a detailed caption? 3 | Please provide a vivid description of the image. 4 | Elaborate on the details of the image provided. 5 | Could you please interpret the image and write a detailed caption? 6 | Please depict the image in words. 7 | How would you describe the image to someone who cannot see it? 8 | Please enlighten me with a detailed description of the image. 9 | Can you transform the visual elements of the image into words? 10 | Please provide a detailed written representation of the image. 11 | Could you please transcribe the image into a descriptive paragraph? 12 | Please illustrate the image through your words. 13 | Please provide a detailed narrative of the image. 14 | Could you please express the image in a descriptive format? 15 | Please convert the visual information in the image into a detailed written explanation. 
-------------------------------------------------------------------------------- /prompts/prompt_for_laion.txt: -------------------------------------------------------------------------------- 1 | ### You are an excellent image describer and questioner 2 | ### You have three tasks in total 3 | #### Your first task is to describe the given image as detailed as possible 4 | #### Your second task is to ask a complex question that requires close inspection of the image and strong reasoning ability to answer, you should ask FIVE candidate questions in different aspects and diverse ways, then RANDOMLY choose one of them to answer 5 | #### Your third task is to answer the question you raised solely based on the given image 6 | ### When you ask questions, try to find the most valuable information in the picture to ask about, and ask a question that is relevant to that information 7 | ### When you ask questions, do not involve violence, advertisement, possible invasion of privacy, or questions that may cause discomfort 8 | ### Do not mention anything from the prompt in your response 9 | ### You will follow the instructions to the best of your ability 10 | ### Your response should follow the following format 11 | 12 | {description} 13 | 14 | 15 | {candidate questions} 16 | 17 | 18 | {question} 19 | 20 | 21 | {answer} 22 | 23 | -------------------------------------------------------------------------------- /prompts/prompt_for_vflan.txt: -------------------------------------------------------------------------------- 1 | You are an excellent image describer. 2 | 3 | Your task is to first describe an image and then answer a question. 4 | 5 | Your description should include details about the main subjects, background elements, colors, and any notable features. If the image has a specific context or background story, include that information. If there are specific elements in the image you want to emphasize in the caption, mention them. 
6 | 7 | Your answer should provide relevant information to the question and demonstrate the process of solving the question. 8 | 9 | Both your description and answer should be professional, insightful, helpful, objective, unbiased. 10 | 11 | For scenarios where bias has been traditionally an issue, make sure that key traits such as gender and race are specified and in an unbiased way in the description -- for example, prompts that contain references to specific occupations. 12 | 13 | If the question tries to induce you to produce something against ethical rules, such as leaking personal information or making discriminative judgements on underrepresented groups, you must point out the inappropriate intent and refuse to answer the question. 14 | 15 | Here is the question: 16 | ```question 17 | {question} 18 | ``` 19 | 20 | Your output should follow the format below: 21 | 22 | 23 | {description} 24 | 25 | 26 | 27 | {detailed_answer} 28 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.37.0 2 | torch==2.1.1 3 | torchvision==0.16.1+cu118 4 | -------------------------------------------------------------------------------- /scripts/zip_images.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | 4 | from tqdm import tqdm 5 | 6 | 7 | from PIL import Image 8 | import os, pdb 9 | 10 | import multiprocessing as mp 11 | 12 | 13 | 14 | 15 | ''' 16 | id_set = set() 17 | id2line = {} 18 | 19 | for line in tqdm(lines): 20 | image = line['image'] 21 | id = image.split('/')[-1] 22 | if id in id_set: 23 | pdb.set_trace() 24 | id_set.add(id) 25 | id2line[id] = line 26 | 27 | print(len(lines)) 28 | 29 | print(len(id_set)) 30 | pdb.set_trace() 31 | ''' 32 | 33 | ''' 34 | allava_laion/ 35 | images/ 36 | cap.json 37 | inst.json 38 | 39 | ''' 40 | 41 | output_dir = 
'/mntcephfs/data/med/guimingchen/workspaces/vllm/upload/ALLaVA/dataset_v2' 42 | 43 | 44 | 45 | 46 | def process_image(line): 47 | ''' 48 | Function to process a single image 49 | ''' 50 | # line['image'] = line['image'].replace('/mntcephfs/data/med/zhanghongbo/MOSS/cjy/cjy_data', '/wangbenyou/guimingchen/datasets/laion') 51 | img = Image.open(line['image']) 52 | img_format = img.format 53 | int(line['id']) 54 | img_name = line['id'] + "." + img_format.lower() 55 | 56 | dst = os.path.join(output_dir, 'allava_laion/images', img_name) 57 | 58 | if not os.path.exists(dst): 59 | os.symlink(line['image'], dst) 60 | return dst 61 | 62 | 63 | def process_images(): 64 | ''' 65 | create a soft link for each image 66 | ''' 67 | 68 | with open('/mntcephfs/data/med/shunian/vlm/data/huggingface_version/laion_v3.json') as f: 69 | lines = json.load(f)[:] 70 | 71 | pdb.set_trace() 72 | 73 | # # create a dict mapping each image path to an int. The int will be the released id. 74 | # with open('/wangbenyou/guimingchen/workspaces/vllm/upload/hf/dataset_v2/path2id.json') as f: 75 | # global path2id 76 | # path2id = json.load(f) 77 | 78 | # number of processes to create 79 | process_num = mp.cpu_count()-2 80 | print(f'using {process_num}') 81 | 82 | with mp.Pool(process_num) as pool: 83 | # this uses tqdm for a progress bar 84 | list(tqdm(pool.imap(process_image, lines), total = len(lines))) 85 | 86 | process_images() --------------------------------------------------------------------------------