├── LICENSE ├── README.md ├── configs ├── test_alignment_prompt.json └── test_faithfulness_prompt.json ├── evalalign ├── __init__.py ├── constants.py ├── conversation.py ├── eval │ ├── test_alignment.py │ └── test_faithfulness.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── __init__.py │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ └── llava_mpt.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ ├── s2wrapper │ │ ├── __init__.py │ │ ├── core.py │ │ └── utils.py │ └── utils.py └── utils.py ├── pyproject.toml └── scripts ├── inference_alignment.sh └── inference_faithfulness.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # EvalAlign 3 | 4 | # EVALALIGN: Supervised Fine-Tuning Multimodal LLMs with Human-Aligned Data for Evaluating Text-to-Image Models 5 | 6 | [![arXiv](https://img.shields.io/badge/arXiv-2406.16562-b31b1b.svg)](https://arxiv.org/abs/2406.16562) 7 | [![Project Page](https://img.shields.io/badge/Project-Website-blue)](https://sais-fuxi.github.io/projects/evalalign/) 8 | [![Hugging Face weight](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Weights-yellow)](https://huggingface.co/Fudan-FUXI/evalalign-v1.0-13b) 9 | [![HF Datasets](https://img.shields.io/static/v1?label=Datasets&message=datasets&color=green)](https://huggingface.co/datasets/Fudan-FUXI/EvalAlign-datasets) 10 | 11 | ## Contents 12 | - [Install](#install) 13 | - [EvalAlign Dataset](#evalalign-dataset) 14 | - [EvalAlign Weights](#evalalign-weights) 15 | - [Evaluation](#evaluation) 16 | 17 | ## Install 18 | 1. Clone this repository and navigate to the EvalAlign folder 19 | ```bash 20 | git clone https://github.com/SAIS-FUXI/EvalAlign.git 21 | cd EvalAlign 22 | ``` 23 | 2. Install the package 24 | ```Shell 25 | conda create -n evalalign python=3.10 -y 26 | conda activate evalalign 27 | pip install --upgrade pip 28 | pip install -e . 29 | ``` 30 | ## EvalAlign Dataset 31 | The human-feedback dataset for evaluating synthesized images, which is also the fine-tuning data of the EvalAlign evaluation models, has been released on [Hugging Face](https://huggingface.co/datasets/Fudan-FUXI/EvalAlign-datasets). 32 | 33 | 34 | ## EvalAlign Weights 35 | We provide two versions of the EvalAlign evaluation model on Hugging Face: 36 | 37 | [EvalAlign-v1.0-13B](https://huggingface.co/Fudan-FUXI/evalalign-v1.0-13b) 38 | 39 | [EvalAlign-v1.0-34B](https://huggingface.co/Fudan-FUXI/evalalign-v1.0-34b) 40 | 41 | If you have sufficient computational resources, we strongly recommend using EvalAlign-v1.0-34B for superior evaluation performance. However, if resources are limited, the 13B version of EvalAlign-v1.0 also provides acceptable evaluation capabilities. 42 | ## Evaluation 43 | ### Image Faithfulness evaluation 44 | You must use the faithfulness [prompts](https://github.com/SAIS-FUXI/EvalAlign/tree/main/configs/test_faithfulness_prompt.json) provided in this repository to generate images with your own model or an open-source model. The file name of each image must match its prompt_id. 45 | ```shell 46 | { 47 | "prompt_id": "259_2_2", 48 | "prompt": "A young man was painting a beautiful landscape with a yellow brush and a black canvas." 49 | } 50 | ``` 51 | For example, for the entry above you would generate an image from the prompt and name it "259_2_2.jpg".
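If you do not already have a generation loop, a minimal sketch is shown below. The `diffusers` pipeline and the SDXL model ID are only stand-ins for whatever text-to-image model you want to evaluate; the one requirement EvalAlign imposes is the file-naming convention.

```python
# Sketch: generate one image per faithfulness prompt and name it after prompt_id.
# The pipeline and model ID below are placeholders for your own text-to-image model.
import json
import os

import torch
from diffusers import AutoPipelineForText2Image

with open("configs/test_faithfulness_prompt.json", "r", encoding="utf-8") as f:
    prompts = json.load(f)

pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

out_dir = "./my_model_faithfulness_images"  # later passed as --images-dir
os.makedirs(out_dir, exist_ok=True)
for item in prompts:
    image = pipe(item["prompt"]).images[0]
    # The evaluator keys on the part of the file name before the extension,
    # so this must be exactly the prompt_id.
    image.save(os.path.join(out_dir, f"{item['prompt_id']}.jpg"))
```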
52 | ```shell 53 | # Run the script 54 | ./scripts/inference_faithfulness.sh 55 | ``` 56 | You need to modify the paths in the script: 57 | ```shell 58 | CUDA_VISIBLE_DEVICES=0 python evalalign/eval/test_faithfulness.py \ 59 | --model-path Fudan-FUXI/evalalign-v1.0-13b \ # Downloaded model weights 60 | --images-dir ./PixArt-XL-2-1024-MS \ # The folder containing the generated images 61 | --output-dir ./results_faithfulness 62 | ``` 63 | - Faithfulness results 64 | 65 | You will get scores for five dimensions (body, hand, face, object, and common sense) plus the overall average score of the model: 66 | ```shell 67 | { 68 | "body_score": 217, 69 | "body_num": 100, 70 | "body_average": 2.17, 71 | "Hand_score": 60, 72 | "Hand_num": 89, 73 | "Hand_average": 0.6741573033707865, 74 | "face_score": 137, 75 | "face_num": 81, 76 | "face_average": 1.691358024691358, 77 | "object_score": 250, 78 | "object_num": 100, 79 | "object_average": 2.5, 80 | "common_score": 105, 81 | "common_num": 100, 82 | "common_average": 1.05, 83 | "total_score": 769, 84 | "num": 470, 85 | "avg_score": 1.6361702127659574 86 | } 87 | ``` 88 | ### Text-to-Image Alignment evaluation 89 | Same as for faithfulness: you must use the alignment [prompts](https://github.com/SAIS-FUXI/EvalAlign/tree/main/configs/test_alignment_prompt.json) provided in this repository to generate images with your own model or an open-source model. The file name of each image must match its prompt_id. 90 | ```shell 91 | { 92 | "prompt_id": "99", 93 | "prompt": "two refrigerators stand side-by-side in a kitchen, with two potted plants on either side of them." 94 | } 95 | ``` 96 | For example, for the entry above you would generate an image from the prompt and name it "99.jpg". 97 | ```shell 98 | # Run the script 99 | ./scripts/inference_alignment.sh 100 | ``` 101 | You need to modify the paths in the script: 102 | ```shell 103 | CUDA_VISIBLE_DEVICES=0 python evalalign/eval/test_alignment.py \ 104 | --model-path Fudan-FUXI/evalalign-v1.0-13b \ # Downloaded model weights 105 | --images-dir ./IF-I-XL-v1.0 \ # The folder containing the generated images 106 | --output-dir ./results_alignment 107 | ``` 108 | - Alignment results 109 | You will get scores for six dimensions (object, count, spatial, action, color, and style) plus the overall average score of the model, as in the sample output that follows the sketch below.
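If you evaluate several text-to-image models, a small helper along these lines can gather the per-dimension averages into one table. It is only a sketch: it assumes you gave each model its own sub-directory under a common results root (e.g. `results_alignment/<model_name>/result_test_alignment.json`), which the scripts do not enforce, and it assumes the faithfulness script names its output file following the same `result_test_<task>.json` pattern as the alignment script.

```python
# Sketch: collect EvalAlign result JSONs from several models into one DataFrame.
# Assumed layout (not enforced by the repo): <root>/<model_name>/result_test_<task>.json
import json
from pathlib import Path

import pandas as pd


def collect_results(root: str, task: str) -> pd.DataFrame:
    rows = {}
    for result_file in Path(root).glob(f"*/result_test_{task}.json"):
        with open(result_file, "r", encoding="utf-8") as f:
            scores = json.load(f)
        # keep only the average-style entries; key names follow the script output
        rows[result_file.parent.name] = {
            k: v for k, v in scores.items() if "average" in k or "avg" in k
        }
    return pd.DataFrame(rows).T  # one row per model


print(collect_results("./results_alignment", "alignment"))
print(collect_results("./results_faithfulness", "faithfulness"))
```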
110 | ```shell 111 | { 112 | "Object_score": 209, 113 | "Object_num": 118, 114 | "Object_avgerage": 1.771186440677966, 115 | "Count_score": 160, 116 | "Count_num": 109, 117 | "Count_avgerage": 1.4678899082568808, 118 | "Spatial_score": 155, 119 | "Spatial_num": 85, 120 | "Spatial_avgerage": 1.8235294117647058, 121 | "Action_score": 102, 122 | "Action_num": 54, 123 | "Action_avgerage": 1.8888888888888888, 124 | "Color_score": 51, 125 | "Color_num": 26, 126 | "Color_avgerage": 1.9615384615384615, 127 | "Style_score": 50, 128 | "Style_num": 25, 129 | "Style_avgerage": 2.0, 130 | "total_score": 727, 131 | "total_avgerage": 6.058333333333334 132 | } 133 | ``` 134 | 135 | 136 | 137 | ## Citation 138 | ```bibtex 139 | @article{tan2024evalalign, 140 | title={EVALALIGN: Supervised Fine-Tuning Multimodal LLMs with Human-Aligned Data for Evaluating Text-to-Image Models}, 141 | author={Tan, Zhiyu and Yang, Xiaomeng and Qin, Luozheng and Yang, Mengping and Zhang, Cheng and Li, Hao}, 142 | journal={arXiv preprint arXiv:2406.16562}, 143 | year={2024}, 144 | institution={Shanghai Academy of AI for Science and Carnegie Mellon University and Fudan University}, 145 | } 146 | ``` 147 | ## Acknowledgement 148 | - [Llava](https://github.com/haotian-liu/LLaVA): Our model is trained on llava and has excellent multimodal reasoning ability! 149 | 150 | -------------------------------------------------------------------------------- /configs/test_faithfulness_prompt.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "prompt_id": "1852", 4 | "prompt": "A woman is practicing calligraphy at her desk." 5 | }, 6 | { 7 | "prompt_id": "380_v1_172_157", 8 | "prompt": "Three firefighters bravely fought to extinguish a raging fire, while two paramedics tended to the injured individuals nearby." 9 | }, 10 | { 11 | "prompt_id": "304", 12 | "prompt": "A man carefully folds a white piece of paper and places it in a black wastebasket." 13 | }, 14 | { 15 | "prompt_id": "1681_1681_v2", 16 | "prompt": "A person who is larger than a car and taller than a horse, but smaller than a cat." 17 | }, 18 | { 19 | "prompt_id": "1726", 20 | "prompt": "A person who is taller than a horse and wider than a car, but smaller than a cat." 21 | }, 22 | { 23 | "prompt_id": "745_745_v2", 24 | "prompt": "At the local tennis court, a group of friends were engaged in a friendly game of doubles. Meanwhile, a jogger was running laps around the court, taking advantage of the open space for his exercise routine. Nearby, a young boy was chasing after a stray ball that had rolled off the court, eager to return it to the players." 25 | }, 26 | { 27 | "prompt_id": "1421", 28 | "prompt": "A group of friends are gathered around a campfire, sharing spooky stories." 29 | }, 30 | { 31 | "prompt_id": "2060", 32 | "prompt": "A woman is carrying a basket of dirty clothes and walking towards the washing machine." 33 | }, 34 | { 35 | "prompt_id": "1466", 36 | "prompt": "A couple, John and Sarah, were taking a romantic stroll on the beach, holding hands and watching the waves." 37 | }, 38 | { 39 | "prompt_id": "3281", 40 | "prompt": "A man stood atop a mountain peak, gazing out at the breathtaking view." 41 | }, 42 | { 43 | "prompt_id": "1365", 44 | "prompt": "A man is playing a guitar and singing a love song to his girlfriend." 45 | }, 46 | { 47 | "prompt_id": "719", 48 | "prompt": "As I stood in front of the mirror, I admired the smooth, polished surface that reflected my image back to me. 
It was a perfect reflection of myself, and I couldn't help but smile at the sight." 49 | }, 50 | { 51 | "prompt_id": "707", 52 | "prompt": "A man wearing a hat is walking alongside a woman holding an umbrella. They are observing their children playing in the park." 53 | }, 54 | { 55 | "prompt_id": "380", 56 | "prompt": "A man with a skateboard in hand nimbly leaped over a puddle, deftly avoiding a pedestrian, and continued on his journey." 57 | }, 58 | { 59 | "prompt_id": "1257", 60 | "prompt": "A person is sitting at a desk, holding a laptop and working on a project." 61 | }, 62 | { 63 | "prompt_id": "2381", 64 | "prompt": "A woman was walking along the beach when she noticed a turtle struggling to make its way back to the ocean. She decided to help the turtle by gently lifting it onto its back and guiding it towards the water." 65 | }, 66 | { 67 | "prompt_id": "56_v2_10_10", 68 | "prompt": "A young boy, excited to play near the river, jumps away from the fence and runs towards the water's edge." 69 | }, 70 | { 71 | "prompt_id": "1560", 72 | "prompt": "A banana is larger than a cat and smaller than a person." 73 | }, 74 | { 75 | "prompt_id": "1251357", 76 | "prompt": "A woman with closed eyes, lost in her intuition, stands in a moonlit night, surrounded by an occult style and a dark background." 77 | }, 78 | { 79 | "prompt_id": "231_438_211", 80 | "prompt": "A woman wearing a scarf sits on a train, opposite a man holding a briefcase." 81 | }, 82 | { 83 | "prompt_id": "673_673_v2", 84 | "prompt": "A group of friends are gathered around a table, playing a lively game of poker. As the cards are dealt and the bets are placed, their loyal dog sits nearby, watching the game with great interest. Its tail wags excitedly, as if it too is eager to join in the fun." 85 | }, 86 | { 87 | "prompt_id": "1430", 88 | "prompt": "A man is enjoying a round of golf at a picturesque golf course, surrounded by lush greenery and stunning views." 89 | }, 90 | { 91 | "prompt_id": "231_229_110", 92 | "prompt": "A man wearing a cowboy hat was standing behind another man wearing a baseball cap." 93 | }, 94 | { 95 | "prompt_id": "50_v2_243_115", 96 | "prompt": "A young man with black hair is strolling leisurely, while his friend with blonde hair is jogging at a quicker pace." 97 | }, 98 | { 99 | "prompt_id": "3115", 100 | "prompt": "A man stood on the side of a clock tower, observing the bustling city below." 101 | }, 102 | { 103 | "prompt_id": "1146", 104 | "prompt": "A man is practicing his jump shot with a basketball in his hands." 105 | }, 106 | { 107 | "prompt_id": "331", 108 | "prompt": "The woman passionately kissed her beloved dog, hugged it tightly, and whispered sweet nothings into its ear." 109 | }, 110 | { 111 | "prompt_id": "273_v2_9_8", 112 | "prompt": "A young girl wearing a purple dress is holding a pink balloon as she walks down the street." 113 | }, 114 | { 115 | "prompt_id": "26", 116 | "prompt": "A woman is sitting at a table with a bowl of food in front of her. She picks up a spoon and begins to feed a cow that is standing nearby." 117 | }, 118 | { 119 | "prompt_id": "1992", 120 | "prompt": "A man is practicing drums with his garage band." 121 | }, 122 | { 123 | "prompt_id": "154", 124 | "prompt": "A man is walking along the beach, carrying a surfboard under his arm, while a woman is standing on the sand, skillfully flying a colorful kite in the wind." 
125 | }, 126 | { 127 | "prompt_id": "357_189_131", 128 | "prompt": "A hiker with a backpack is trekking through the forest, while a woman with a purse is following behind." 129 | }, 130 | { 131 | "prompt_id": "284", 132 | "prompt": "A woman is brushing her hair while a man is reading a book." 133 | }, 134 | { 135 | "prompt_id": "3665", 136 | "prompt": "A bee landed on the top of a boy's head while he was playing outside." 137 | }, 138 | { 139 | "prompt_id": "636", 140 | "prompt": "In the room, people were singing along to the music, while a child clapped their hands to the beat and a dog barked in time with the rhythm." 141 | }, 142 | { 143 | "prompt_id": "341_v2_308_195", 144 | "prompt": "A woman wearing a yellow raincoat held her umbrella high in the air, while a man wearing a gray coat stood nearby." 145 | }, 146 | { 147 | "prompt_id": "2599", 148 | "prompt": "A woman was sitting at a table in a restaurant, enjoying her meal." 149 | }, 150 | { 151 | "prompt_id": "238", 152 | "prompt": "A woman is hugging her teddy bears, while a man is standing on a dock by the water." 153 | }, 154 | { 155 | "prompt_id": "1921", 156 | "prompt": "A person is admiring a cherry blossom tree in full bloom, captivated by its delicate pink flowers and the gentle breeze that rustles through its branches." 157 | }, 158 | { 159 | "prompt_id": "1807_1807_v2", 160 | "prompt": "A man wearing sunglasses is watching a parade with his family." 161 | }, 162 | { 163 | "prompt_id": "2961", 164 | "prompt": "A man was sitting at a table on the right side of the room, enjoying a meal with his family." 165 | }, 166 | { 167 | "prompt_id": "2743", 168 | "prompt": "A woman was walking on the right side of a busy street, carrying a backpack on her shoulder." 169 | }, 170 | { 171 | "prompt_id": "16171308", 172 | "prompt": "I attended a beach party with a group of friends, including a stunning woman who caught everyone's attention. The event was beautifully lit with colorful lanterns and fairy lights, creating a magical atmosphere. The party was well-organized, with a variety of activities and games for everyone to enjoy. The food was delicious, and the drinks were flowing. The music was lively, and everyone was dancing and having a great time. The night was filled with laughter, good conversation, and unforgettable memories." 173 | }, 174 | { 175 | "prompt_id": "1811", 176 | "prompt": "A person who is taller than a cat and a car, but shorter than a dog." 177 | }, 178 | { 179 | "prompt_id": "1144", 180 | "prompt": "A man is wearing a thick winter coat and walking through the snow, his boots crunching on the icy ground." 181 | }, 182 | { 183 | "prompt_id": "76", 184 | "prompt": "A man is playing with his dog, tossing a frisbee in the air. The dog catches the frisbee and brings it back to the man, who then pets the dog affectionately." 185 | }, 186 | { 187 | "prompt_id": "2657", 188 | "prompt": "A woman was standing on the left side of a clock tower, admiring the intricate design and the way the sunlight reflected off the metal." 189 | }, 190 | { 191 | "prompt_id": "258", 192 | "prompt": "A man is reading a book while sitting on a bench in the park, and nearby, an elephant is enjoying a refreshing dip in the pond." 193 | }, 194 | { 195 | "prompt_id": "2736", 196 | "prompt": "A woman is sitting at a table, and on the left side of her is a bowl filled with fruit." 197 | }, 198 | { 199 | "prompt_id": "67_63_61", 200 | "prompt": "A man wearing a yellow hat is jogging, while a woman in a blue dress is strolling." 
201 | }, 202 | { 203 | "prompt_id": "1115", 204 | "prompt": "A man in a green chair, a woman driving a blue car, and a child wearing a red shirt." 205 | }, 206 | { 207 | "prompt_id": "8", 208 | "prompt": "The man held the black remote in his right hand, while his left hand remained empty." 209 | }, 210 | { 211 | "prompt_id": "907_907_v2", 212 | "prompt": "The worn, leather saddle was a symbol of the cowboy's hardworking life." 213 | }, 214 | { 215 | "prompt_id": "220", 216 | "prompt": "A man and a woman are sitting on a bench in a park, enjoying the warm sunshine. Their dog, a friendly golden retriever, is lying at their feet, contentedly wagging its tail. Nearby, a farmer is milking a cow in a nearby field, providing fresh milk for the local community." 217 | }, 218 | { 219 | "prompt_id": "1614", 220 | "prompt": "A miniature horse, smaller than a banana, is grazing in a field while a person nearby watches in amazement." 221 | }, 222 | { 223 | "prompt_id": "341_v2_502_311", 224 | "prompt": "A magician wearing a black cape and holding a wand in the air, performs a trick with a rabbit wearing a hat nearby." 225 | }, 226 | { 227 | "prompt_id": "2968", 228 | "prompt": "A person on the right side of a train was enjoying the scenic view as it passed through the countryside." 229 | }, 230 | { 231 | "prompt_id": "36", 232 | "prompt": "A man is playing with his dog, tossing a frisbee for it to catch. Meanwhile, another man is milking a cow on a nearby farm." 233 | }, 234 | { 235 | "prompt_id": "279_220_160", 236 | "prompt": "The man on the left is sprinting quickly, while the woman on the right is strolling leisurely." 237 | }, 238 | { 239 | "prompt_id": "341_v2_266_163", 240 | "prompt": "A fisherman wearing waders and holding a fishing rod in the air, and a hunter wearing camouflage gear, were both enjoying their respective outdoor activities." 241 | }, 242 | { 243 | "prompt_id": "183", 244 | "prompt": "A man accidentally drops his tennis racket while playing a game, and another man nearby is holding his glasses, which he had taken off to watch the match." 245 | }, 246 | { 247 | "prompt_id": "126_v2_469_754", 248 | "prompt": "The shorter boy on the left is standing next to his taller friend on the right." 249 | }, 250 | { 251 | "prompt_id": "1260", 252 | "prompt": "A yellow person, a green banana, a red dog, and an orange airplane were all present at the park. The person was enjoying a sunny day, the banana was being peeled and eaten, the dog was playing fetch, and the airplane was flying overhead." 253 | }, 254 | { 255 | "prompt_id": "3141", 256 | "prompt": "A person on the right of a refrigerator is likely standing in a kitchen, preparing a meal or looking for something to eat." 257 | }, 258 | { 259 | "prompt_id": "144_144_144_v2", 260 | "prompt": "A man is holding a fork, while bananas are in his other hand." 261 | }, 262 | { 263 | "prompt_id": "1048", 264 | "prompt": "A person is gazing at the full moon, feeling a surge of inspiration." 265 | }, 266 | { 267 | "prompt_id": "1067_1067_v2", 268 | "prompt": "A family is enjoying a picnic in the park on a sunny afternoon." 269 | }, 270 | { 271 | "prompt_id": "328_241_218", 272 | "prompt": "A young girl wearing a red hat playfully jumps over her friend, who is not wearing a hat." 273 | }, 274 | { 275 | "prompt_id": "3573", 276 | "prompt": "A man walked his dog on the right side of the sidewalk." 
277 | }, 278 | { 279 | "prompt_id": "1463", 280 | "prompt": "A team of colleagues is collaborating on a new project, sharing ideas and discussing potential solutions." 281 | }, 282 | { 283 | "prompt_id": "1004200", 284 | "prompt": "A young black man with a stylish black power haircut and headphones on, standing confidently in a full-body Pixar-style pose." 285 | }, 286 | { 287 | "prompt_id": "328_38_37", 288 | "prompt": "A young man carrying a backpack approaches another young man who is not carrying a backpack." 289 | }, 290 | { 291 | "prompt_id": "1932", 292 | "prompt": "A person is standing on a diving board, preparing to make a splash." 293 | }, 294 | { 295 | "prompt_id": "89", 296 | "prompt": "A man is sitting by a pond, drinking water from a bottle, while a bird perches on a nearby fence, watching him curiously." 297 | }, 298 | { 299 | "prompt_id": "963", 300 | "prompt": "A cheerful orange cat and a friendly red-haired person were enjoying a sunny day in the park." 301 | }, 302 | { 303 | "prompt_id": "273_v2_234_221", 304 | "prompt": "A salesperson in a blue shirt is holding a green pen, ready to take notes during a meeting with a potential client." 305 | }, 306 | { 307 | "prompt_id": "1253", 308 | "prompt": "A dog is sitting on its owner's lap, enjoying the scenic view as they take a leisurely drive through the countryside." 309 | }, 310 | { 311 | "prompt_id": "981", 312 | "prompt": "A young woman wearing a green dress and a yellow Labrador retriever were walking together in the park." 313 | }, 314 | { 315 | "prompt_id": "2991", 316 | "prompt": "A young girl was standing on the left side of a pig, which was being led by a farmer." 317 | }, 318 | { 319 | "prompt_id": "921", 320 | "prompt": "A woman is watering her garden nearby, a boy is flying a kite in the sky, and a dog is running around the tree where a man is cutting wood." 321 | }, 322 | { 323 | "prompt_id": "109_v2_105_92", 324 | "prompt": "The businessman, dressed in a brown suit and gray tie, was attending a business meeting." 325 | }, 326 | { 327 | "prompt_id": "755", 328 | "prompt": "A man is sitting at a table, holding a laptop and typing away. A woman is seated next to him, typing on a keyboard. A child is playing a game on a tablet, while another man is checking his phone." 329 | }, 330 | { 331 | "prompt_id": "273_v2_27_24", 332 | "prompt": "A woman wearing a black blouse is sitting in her living room, skillfully knitting a beautiful purple scarf." 333 | }, 334 | { 335 | "prompt_id": "41", 336 | "prompt": "A man is playing the piano while a woman is riding a skateboard." 337 | }, 338 | { 339 | "prompt_id": "1767", 340 | "prompt": "A man is driving a convertible down a picturesque country road, enjoying the fresh air and beautiful scenery." 341 | }, 342 | { 343 | "prompt_id": "2121", 344 | "prompt": "A man is sitting on a park bench, observing passersby as they stroll by." 345 | }, 346 | { 347 | "prompt_id": "772", 348 | "prompt": "A man is holding a cat, while a woman stands nearby, admiring the feline. A boy plays with a ball nearby, and a dog barks in the background." 349 | }, 350 | { 351 | "prompt_id": "270_38_20", 352 | "prompt": "The child giggles with joy as they play with their favorite toy, while the adult nearby wipes away tears of sadness, reflecting on a recent loss." 353 | }, 354 | { 355 | "prompt_id": "56_v2_377_288", 356 | "prompt": "A baby crawls out of their crib and makes their way towards the colorful toys on the floor." 
357 | }, 358 | { 359 | "prompt_id": "357_378_221", 360 | "prompt": "A bearded man is playing guitar in front of a woman who is not bearded." 361 | }, 362 | { 363 | "prompt_id": "1633_1633_v2", 364 | "prompt": "A man named John was taking his dog, Max, for a walk in the park. They both enjoyed the fresh air and the beautiful scenery around them." 365 | }, 366 | { 367 | "prompt_id": "1110", 368 | "prompt": "A blue dog, a red car, and an orange person were all present at the busy intersection. The dog was eagerly waiting for its owner to return from the store, the car was stuck in traffic, and the person was trying to cross the street." 369 | }, 370 | { 371 | "prompt_id": "1544", 372 | "prompt": "A tourist is visiting a renowned historical site in a foreign country." 373 | }, 374 | { 375 | "prompt_id": "1655", 376 | "prompt": "A dog patiently waits on a mat for its owner to come home." 377 | }, 378 | { 379 | "prompt_id": "3120", 380 | "prompt": "A man was sitting at a table, with a bowl of soup placed on the left side of him." 381 | }, 382 | { 383 | "prompt_id": "273_v2_335_310", 384 | "prompt": "A woman wearing a yellow dress is holding a purple umbrella to protect herself from the rain." 385 | }, 386 | { 387 | "prompt_id": "3336", 388 | "prompt": "A woman sat on the left side of the table, sipping her coffee and chatting with her friends." 389 | }, 390 | { 391 | "prompt_id": "848", 392 | "prompt": "At a gathering, some friends were playing cards while a dog was eagerly devouring a slice of pizza. In the corner, a cat was peacefully napping, undisturbed by the commotion." 393 | }, 394 | { 395 | "prompt_id": "1208", 396 | "prompt": "A dog is playing tug-of-war with its owner, who is also wagging its tail." 397 | }, 398 | { 399 | "prompt_id": "418", 400 | "prompt": "The man gently patted his loyal dog, affectionately scratching its head, and smiled as they both enjoyed a peaceful moment together." 401 | } 402 | ] -------------------------------------------------------------------------------- /evalalign/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /evalalign/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /evalalign/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | import base64 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | 9 | class SeparatorStyle(Enum): 10 | """Different separator style.""" 11 | SINGLE = auto() 12 | TWO = auto() 13 | MPT = auto() 14 | PLAIN = auto() 15 | LLAMA_2 = auto() 16 | 17 | 18 | @dataclasses.dataclass 19 | class Conversation: 20 | """A class that keeps all conversation history.""" 21 | system: str 22 | roles: List[str] 23 | messages: List[List[str]] 24 | offset: int 25 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 26 | sep: str = "###" 27 | sep2: str = None 28 | version: str = "Unknown" 29 | 30 | skip_next: bool = False 31 | 32 | def get_prompt(self): 33 | messages = self.messages 34 | if len(messages) > 0 and type(messages[0][1]) is tuple: 35 | messages = self.messages.copy() 36 | init_role, init_msg = messages[0].copy() 37 | init_msg = init_msg[0].replace("<image>", "").strip() 38 | if 'mmtag' in self.version: 39 | messages[0] = (init_role, init_msg) 40 | messages.insert(0, (self.roles[0], "<Image><image></Image>")) 41 | messages.insert(1, (self.roles[1], "Received.")) 42 | else: 43 | messages[0] = (init_role, "<image>\n" + init_msg) 44 | 45 | if self.sep_style == SeparatorStyle.SINGLE: 46 | ret = self.system + self.sep 47 | for role, message in messages: 48 | if message: 49 | if type(message) is tuple: 50 | message, _, _ = message 51 | ret += role + ": " + message + self.sep 52 | else: 53 | ret += role + ":" 54 | elif self.sep_style == SeparatorStyle.TWO: 55 | seps = [self.sep, self.sep2] 56 | ret = self.system + seps[0] 57 | for i, (role, message) in enumerate(messages): 58 | if message: 59 | if type(message) is tuple: 60 | message, _, _ = message 61 | ret += role + ": " + message + seps[i % 2] 62 | else: 63 | ret += role + ":" 64 | elif self.sep_style == SeparatorStyle.MPT: 65 | ret = self.system + self.sep 66 | for role, message in messages: 67 | if message: 68 | if type(message) is tuple: 69 | message, _, _ = message 70 | ret += role + message + self.sep 71 | else: 72 | ret += role 73 | elif self.sep_style == SeparatorStyle.LLAMA_2: 74 | wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg 75 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]" 76 | ret = "" 77 | 78 | for i, (role, message) in enumerate(messages): 79 | if i == 0: 80 | assert message, "first message should not be none" 81 | assert role == self.roles[0], "first message should come from user" 82 | if message: 83 | if type(message) is tuple: 84 | message, _, _ = message 85 | if i == 0: message = wrap_sys(self.system) + message 86 | if i % 2 == 0: 87 | message = wrap_inst(message) 88 | ret += self.sep + message 89 | else: 90 | ret += " " + message + " " + self.sep2 91 | else: 92 | ret += "" 93 | ret = ret.lstrip(self.sep) 94 | elif self.sep_style == SeparatorStyle.PLAIN: 95 | seps = [self.sep, self.sep2] 96 | ret = self.system 97 | for i, (role, message) in enumerate(messages): 98 | if message: 99 | if type(message) is tuple: 100 | message, _, _ = message 101 | ret += message + seps[i % 2] 102 | else: 103 | ret += "" 104 | else: 105 | raise
ValueError(f"Invalid style: {self.sep_style}") 106 | 107 | return ret 108 | 109 | def append_message(self, role, message): 110 | self.messages.append([role, message]) 111 | 112 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): 113 | if image_process_mode == "Pad": 114 | def expand2square(pil_img, background_color=(122, 116, 104)): 115 | width, height = pil_img.size 116 | if width == height: 117 | return pil_img 118 | elif width > height: 119 | result = Image.new(pil_img.mode, (width, width), background_color) 120 | result.paste(pil_img, (0, (width - height) // 2)) 121 | return result 122 | else: 123 | result = Image.new(pil_img.mode, (height, height), background_color) 124 | result.paste(pil_img, ((height - width) // 2, 0)) 125 | return result 126 | image = expand2square(image) 127 | elif image_process_mode in ["Default", "Crop"]: 128 | pass 129 | elif image_process_mode == "Resize": 130 | image = image.resize((336, 336)) 131 | else: 132 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 133 | if max(image.size) > max_len: 134 | max_hw, min_hw = max(image.size), min(image.size) 135 | aspect_ratio = max_hw / min_hw 136 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 137 | longest_edge = int(shortest_edge * aspect_ratio) 138 | W, H = image.size 139 | if H > W: 140 | H, W = longest_edge, shortest_edge 141 | else: 142 | H, W = shortest_edge, longest_edge 143 | image = image.resize((W, H)) 144 | if return_pil: 145 | return image 146 | else: 147 | buffered = BytesIO() 148 | image.save(buffered, format=image_format) 149 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 150 | return img_b64_str 151 | 152 | def get_images(self, return_pil=False): 153 | images = [] 154 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 155 | if i % 2 == 0: 156 | if type(msg) is tuple: 157 | msg, image, image_process_mode = msg 158 | image = self.process_image(image, image_process_mode, return_pil=return_pil) 159 | images.append(image) 160 | return images 161 | 162 | def to_gradio_chatbot(self): 163 | ret = [] 164 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 165 | if i % 2 == 0: 166 | if type(msg) is tuple: 167 | msg, image, image_process_mode = msg 168 | img_b64_str = self.process_image( 169 | image, "Default", return_pil=False, 170 | image_format='JPEG') 171 | img_str = f'user upload image' 172 | msg = img_str + msg.replace('', '').strip() 173 | ret.append([msg, None]) 174 | else: 175 | ret.append([msg, None]) 176 | else: 177 | ret[-1][-1] = msg 178 | return ret 179 | 180 | def copy(self): 181 | return Conversation( 182 | system=self.system, 183 | roles=self.roles, 184 | messages=[[x, y] for x, y in self.messages], 185 | offset=self.offset, 186 | sep_style=self.sep_style, 187 | sep=self.sep, 188 | sep2=self.sep2, 189 | version=self.version) 190 | 191 | def dict(self): 192 | if len(self.get_images()) > 0: 193 | return { 194 | "system": self.system, 195 | "roles": self.roles, 196 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 197 | "offset": self.offset, 198 | "sep": self.sep, 199 | "sep2": self.sep2, 200 | } 201 | return { 202 | "system": self.system, 203 | "roles": self.roles, 204 | "messages": self.messages, 205 | "offset": self.offset, 206 | "sep": self.sep, 207 | "sep2": self.sep2, 208 | } 209 | 210 | 211 | conv_vicuna_v0 = Conversation( 212 | system="A chat between a curious human and an artificial intelligence assistant. 
" 213 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 214 | roles=("Human", "Assistant"), 215 | messages=( 216 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 217 | ("Assistant", 218 | "Renewable energy sources are those that can be replenished naturally in a relatively " 219 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 220 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 221 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 222 | "renewable and non-renewable energy sources:\n" 223 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 224 | "energy sources are finite and will eventually run out.\n" 225 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 226 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 227 | "and other negative effects.\n" 228 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 229 | "have lower operational costs than non-renewable sources.\n" 230 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 231 | "locations than non-renewable sources.\n" 232 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 233 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 234 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 235 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 236 | ), 237 | offset=2, 238 | sep_style=SeparatorStyle.SINGLE, 239 | sep="###", 240 | ) 241 | 242 | conv_vicuna_v1 = Conversation( 243 | system="A chat between a curious user and an artificial intelligence assistant. " 244 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 245 | roles=("USER", "ASSISTANT"), 246 | version="v1", 247 | messages=(), 248 | offset=0, 249 | sep_style=SeparatorStyle.TWO, 250 | sep=" ", 251 | sep2="", 252 | ) 253 | 254 | conv_llama_2 = Conversation( 255 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 256 | 257 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", 258 | roles=("USER", "ASSISTANT"), 259 | version="llama_v2", 260 | messages=(), 261 | offset=0, 262 | sep_style=SeparatorStyle.LLAMA_2, 263 | sep="", 264 | sep2="", 265 | ) 266 | 267 | conv_llava_llama_2 = Conversation( 268 | system="You are a helpful language and vision assistant. 
" 269 | "You are able to understand the visual content that the user provides, " 270 | "and assist the user with a variety of tasks using natural language.", 271 | roles=("USER", "ASSISTANT"), 272 | version="llama_v2", 273 | messages=(), 274 | offset=0, 275 | sep_style=SeparatorStyle.LLAMA_2, 276 | sep="", 277 | sep2="", 278 | ) 279 | 280 | conv_mpt = Conversation( 281 | system="""<|im_start|>system 282 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 283 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 284 | version="mpt", 285 | messages=(), 286 | offset=0, 287 | sep_style=SeparatorStyle.MPT, 288 | sep="<|im_end|>", 289 | ) 290 | 291 | conv_llava_plain = Conversation( 292 | system="", 293 | roles=("", ""), 294 | messages=( 295 | ), 296 | offset=0, 297 | sep_style=SeparatorStyle.PLAIN, 298 | sep="\n", 299 | ) 300 | 301 | conv_llava_v0 = Conversation( 302 | system="A chat between a curious human and an artificial intelligence assistant. " 303 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 304 | roles=("Human", "Assistant"), 305 | messages=( 306 | ), 307 | offset=0, 308 | sep_style=SeparatorStyle.SINGLE, 309 | sep="###", 310 | ) 311 | 312 | conv_llava_v0_mmtag = Conversation( 313 | system="A chat between a curious user and an artificial intelligence assistant. " 314 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 315 | "The visual content will be provided with the following format: visual content.", 316 | roles=("Human", "Assistant"), 317 | messages=( 318 | ), 319 | offset=0, 320 | sep_style=SeparatorStyle.SINGLE, 321 | sep="###", 322 | version="v0_mmtag", 323 | ) 324 | 325 | conv_llava_v1 = Conversation( 326 | system="A chat between a curious human and an artificial intelligence assistant. " 327 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 328 | roles=("USER", "ASSISTANT"), 329 | version="v1", 330 | messages=(), 331 | offset=0, 332 | sep_style=SeparatorStyle.TWO, 333 | sep=" ", 334 | sep2="", 335 | ) 336 | 337 | conv_llava_v1_mmtag = Conversation( 338 | system="A chat between a curious user and an artificial intelligence assistant. " 339 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
340 | "The visual content will be provided with the following format: visual content.", 341 | roles=("USER", "ASSISTANT"), 342 | messages=(), 343 | offset=0, 344 | sep_style=SeparatorStyle.TWO, 345 | sep=" ", 346 | sep2="", 347 | version="v1_mmtag", 348 | ) 349 | 350 | conv_mistral_instruct = Conversation( 351 | system="", 352 | roles=("USER", "ASSISTANT"), 353 | version="llama_v2", 354 | messages=(), 355 | offset=0, 356 | sep_style=SeparatorStyle.LLAMA_2, 357 | sep="", 358 | sep2="", 359 | ) 360 | 361 | 362 | conv_chatml_direct = Conversation( 363 | system="""<|im_start|>system 364 | Answer the questions.""", 365 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 366 | version="mpt", 367 | messages=(), 368 | offset=0, 369 | sep_style=SeparatorStyle.MPT, 370 | sep="<|im_end|>", 371 | ) 372 | 373 | default_conversation = conv_vicuna_v1 374 | conv_templates = { 375 | "default": conv_vicuna_v0, 376 | "v0": conv_vicuna_v0, 377 | "v1": conv_vicuna_v1, 378 | "vicuna_v1": conv_vicuna_v1, 379 | "llama_2": conv_llama_2, 380 | "mistral_instruct": conv_mistral_instruct, 381 | "chatml_direct": conv_chatml_direct, 382 | "mistral_direct": conv_chatml_direct, 383 | 384 | "plain": conv_llava_plain, 385 | "v0_plain": conv_llava_plain, 386 | "llava_v0": conv_llava_v0, 387 | "v0_mmtag": conv_llava_v0_mmtag, 388 | "llava_v1": conv_llava_v1, 389 | "v1_mmtag": conv_llava_v1_mmtag, 390 | "llava_llama_2": conv_llava_llama_2, 391 | 392 | "mpt": conv_mpt, 393 | } 394 | 395 | 396 | if __name__ == "__main__": 397 | print(default_conversation.get_prompt()) 398 | -------------------------------------------------------------------------------- /evalalign/eval/test_alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pandas as pd 5 | from evalalign.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 6 | 7 | import argparse 8 | 9 | from evalalign.constants import ( 10 | IMAGE_TOKEN_INDEX, 11 | DEFAULT_IMAGE_TOKEN, 12 | DEFAULT_IM_START_TOKEN, 13 | DEFAULT_IM_END_TOKEN, 14 | IMAGE_PLACEHOLDER, 15 | ) 16 | from evalalign.conversation import conv_templates, SeparatorStyle 17 | from evalalign.model.builder import load_pretrained_model 18 | from evalalign.utils import disable_torch_init 19 | from evalalign.mm_utils import ( 20 | process_images, 21 | tokenizer_image_token, 22 | get_model_name_from_path, 23 | ) 24 | 25 | from PIL import Image 26 | 27 | import requests 28 | from io import BytesIO 29 | import re 30 | 31 | 32 | def image_parser(args): 33 | out = args.image_file.split(args.sep) 34 | return out 35 | 36 | 37 | def load_image(image_file): 38 | if image_file.startswith("http") or image_file.startswith("https"): 39 | response = requests.get(image_file) 40 | image = Image.open(BytesIO(response.content)).convert("RGB") 41 | else: 42 | image = Image.open(image_file).convert("RGB") 43 | return image 44 | 45 | 46 | def eval_model(args): 47 | # Model 48 | disable_torch_init() 49 | 50 | prompt_json = "./configs/test_alignment_prompt.json" 51 | id2prompt = {} 52 | with open(prompt_json,"r") as f: 53 | datasets_prompt = json.load(f) 54 | for data in datasets_prompt: 55 | prompts_id = data["prompt_id"] 56 | id2prompt[prompts_id] = data["question_info"] 57 | 58 | model_name = args.model_path.split("/")[-1] 59 | 60 | tokenizer, model, image_processor, context_len = load_pretrained_model( 61 | args.model_path, None, model_name 62 | ) 63 | 64 | img_nums = len(os.listdir(args.images_dir)) 65 | 
#i = 0 66 | sum_res = {} 67 | for p in os.listdir(args.images_dir): 68 | 69 | imgp = os.path.join(args.images_dir,p) 70 | prompt_id = p.split(".")[0] 71 | 72 | if prompt_id not in id2prompt: 73 | raise ValueError("prompt_id must be used to name the generated image! The name of the image must match the promt_id!") 74 | question_info = id2prompt[prompt_id] 75 | ans_typed = {} 76 | 77 | for typed, question in question_info.items(): 78 | 79 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 80 | if IMAGE_PLACEHOLDER in question: 81 | if model.config.mm_use_im_start_end: 82 | question = re.sub(IMAGE_PLACEHOLDER, image_token_se, question) 83 | else: 84 | question = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, question) 85 | else: 86 | if model.config.mm_use_im_start_end: 87 | question = image_token_se + "\n" + question 88 | else: 89 | question = question 90 | 91 | conv_mode = "v1" 92 | 93 | if "evalalign-v1.0-34b" in model_name: 94 | conv_mode = "mpt" 95 | elif "evalalign-v1.0-13b" in model_name: 96 | conv_mode = "v1" 97 | 98 | conv = conv_templates[conv_mode].copy() 99 | conv.append_message(conv.roles[0], question) 100 | conv.append_message(conv.roles[1], None) 101 | prompt = conv.get_prompt() 102 | 103 | images = [load_image(imgp)] 104 | 105 | images_tensor = process_images( 106 | images, 107 | image_processor, 108 | model.config 109 | ).to(model.device, dtype=torch.float16) 110 | 111 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device) 112 | #print(input_ids) 113 | with torch.inference_mode(): 114 | output_ids = model.generate( 115 | input_ids, 116 | images=images_tensor, 117 | #image_sizes=image_sizes, 118 | do_sample=False, 119 | temperature=args.temperature, 120 | top_p=args.top_p, 121 | num_beams=args.num_beams, 122 | max_new_tokens=args.max_new_tokens, 123 | use_cache=True, 124 | ) 125 | output = tokenizer.decode(output_ids[0][1:]).strip() 126 | print("output",output) 127 | scores = re.findall(r'\d+', output) 128 | if len(scores) > 0: 129 | score = int(scores[0]) 130 | else: 131 | score = 0 132 | ans_typed[typed] = score 133 | if typed not in sum_res: 134 | sum_res[typed] = [score] 135 | else: 136 | sum_res[typed].append(score) 137 | 138 | grained_dict = {} 139 | sums = [] 140 | for td, tdres in sum_res.items(): 141 | grained_dict[f"{td}_score"] = sum(tdres) 142 | grained_dict[f"{td}_num"] = len(tdres) 143 | grained_dict[f"{td}_avgerage"] = sum(tdres)/len(tdres) if len(tdres) >0 else 0 144 | 145 | sums.append(sum(tdres)) 146 | print(td, sum(tdres)) 147 | grained_dict[f"total_score"] = sum(sums) 148 | grained_dict[f"total_avgerage"] = sum(sums)/img_nums 149 | 150 | os.makedirs(args.output_dir,exist_ok=True) 151 | df = pd.DataFrame.from_dict(grained_dict,orient='index') 152 | df.to_excel(f"{args.output_dir}/result_test_alignment.xlsx") 153 | with open(f'{args.output_dir}/result_test_alignment.json', 'w', encoding='utf-8') as f: 154 | f.write(json.dumps(grained_dict,indent=2)) 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser() 158 | parser.add_argument("--model-path", type=str, default="sais/evalalign-v1.0-13b") 159 | parser.add_argument("--model-base", type=str, default=None) 160 | parser.add_argument("--images-dir", type=str,default=None) 161 | parser.add_argument("--output-dir", type=str,default=None) 162 | parser.add_argument("--temperature", type=float, default=0.2) 163 | parser.add_argument("--top_p", type=float, default=None) 164 | 
parser.add_argument("--num_beams", type=int, default=1) 165 | parser.add_argument("--max_new_tokens", type=int, default=512) 166 | args = parser.parse_args() 167 | 168 | eval_model(args) 169 | 170 | -------------------------------------------------------------------------------- /evalalign/eval/test_faithfulness.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pandas as pd 5 | import numpy as np 6 | import argparse 7 | 8 | from evalalign.constants import ( 9 | IMAGE_TOKEN_INDEX, 10 | DEFAULT_IMAGE_TOKEN, 11 | DEFAULT_IM_START_TOKEN, 12 | DEFAULT_IM_END_TOKEN, 13 | IMAGE_PLACEHOLDER, 14 | ) 15 | from evalalign.conversation import conv_templates, SeparatorStyle 16 | from evalalign.model.builder import load_pretrained_model 17 | from evalalign.utils import disable_torch_init 18 | from evalalign.mm_utils import ( 19 | process_images, 20 | tokenizer_image_token, 21 | get_model_name_from_path, 22 | ) 23 | 24 | from PIL import Image 25 | import requests 26 | from io import BytesIO 27 | import re 28 | 29 | 30 | def image_parser(args): 31 | out = args.image_file.split(args.sep) 32 | return out 33 | 34 | 35 | def load_image(image_file): 36 | if image_file.startswith("http") or image_file.startswith("https"): 37 | response = requests.get(image_file) 38 | image = Image.open(BytesIO(response.content)).convert("RGB") 39 | else: 40 | image = Image.open(image_file).convert("RGB") 41 | return image 42 | 43 | 44 | 45 | def eval_model(args): 46 | 47 | disable_torch_init() 48 | 49 | model_name = args.model_path.split("/")[-1] 50 | 51 | 52 | tokenizer, model, image_processor, context_len = load_pretrained_model( 53 | args.model_path, None, model_name 54 | ) 55 | 56 | sum_res = {} 57 | save_res = {} 58 | 59 | img_nums = len(os.listdir(args.images_dir)) 60 | for p in os.listdir(args.images_dir): 61 | imgp = os.path.join(args.images_dir,p) 62 | img_score = [] 63 | question_info = { "body": "\nAre there any issues with the [human/animals] body structure in the image, such as multiple arms, missing limbs or legs when not obscured, multiple heads, limb amputations, etc? \nOptions: -1.There are no human or animal body in the picture, 0.The body structure of the people or animals in the picture has a very grievous problem that is unbearable, 1.The body structure of the people or animals in the picture has some serious problems and is not acceptable, 2.The body structure of the people or animals in the picture has a slight problem that does not affect the senses, 3.The body structure of the people or animals in the picture is basically fine, with only a few flaws, 4.The body structure of the people or animals in the picture is completely fine and close to reality, the answer is", 64 | "Hand": "\nAre there any issues with the [human/animals] hands in the image, such as having more or less than five fingers when not obscured, broken fingers, disproportionate finger sizes, abnormal nail size proportions, etc? 
\nOptions: -1.No human or animal hands are shown in the picture, 0.The hand in the picture has a very grievous problem that is unbearable, 1.The hand in the picture has some serious problems and is not acceptable, 2.The hand in the picture has a slight problem that does not affect the senses, 3.The hand in the picture is basically fine, with only a few flaws, 4.The hands in the picture are completely fine and close to reality, the answer is", 65 | "face": "\nAre there any issues with [human/animals] face in the image, such as facial distortion, asymmetrical faces, abnormal facial features, unusual expressions in the eyes, etc? \nOptions: -1.There is no face of any person or animal in the picture, 0.The face of the person or animal in the picture has a very grievous problem that is unbearable, 1.The face of the person or animal in the picture has some serious problems and is not acceptable, 2.The face of the person or animal in the picture has a slight problem that does not affect the senses, 3.The face of the person or animal in the picture is basically fine, with only a few flaws, 4.The face of the person or animal in the picture is completely fine and close to reality, the answer is", 66 | "object": "\nAre there any issues or tentative errors with objects in the image that do not correspond with the real world, such as distortion of items, etc? \nOptions: 0.There are objects in the image that completely do not match the real world, which is very serious and intolerable, 1.There are objects in the image that do not match the real world, which is quite serious and unacceptable, 2.There are slightly unrealistic objects in the image that do not affect the senses, 3.There are basically no objects in the image that do not match the real world, only some flaws, 4.All objects in the image match the real world, no problem, the answer is", 67 | "common": "\nDoes the generated image contain elements that violate common sense or logical rules? 
\nOptions: 0.The image contains elements that violate common sense or logical rules, which is very grievous and intolerable, 1.The presence of elements in the image that seriously violate common sense or logical rules is unacceptable, 2.The image contains elements that violate common sense or logical rules, which is slightly problematic and does not affect the senses, 3.There are basically no elements in the image that violate common sense or logical rules, only some flaws, 4.There are no elements in the image that violate common sense or logical rules, and they are close to reality, the answer is"} 68 | ans_typed = {} 69 | for typed, question in question_info.items(): 70 | 71 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 72 | if IMAGE_PLACEHOLDER in question: 73 | if model.config.mm_use_im_start_end: 74 | question = re.sub(IMAGE_PLACEHOLDER, image_token_se, question) 75 | else: 76 | question = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, question) 77 | else: 78 | if model.config.mm_use_im_start_end: 79 | question = image_token_se + "\n" + question 80 | else: 81 | question = question 82 | 83 | conv_mode = "v1" 84 | if "evalalign-v1.0-34b" in model_name: 85 | conv_mode = "mpt" 86 | elif "evalalign-v1.0-13b" in model_name: 87 | conv_mode = "v1" 88 | 89 | 90 | conv = conv_templates[conv_mode].copy() 91 | conv.append_message(conv.roles[0], question) 92 | conv.append_message(conv.roles[1], None) 93 | prompt = conv.get_prompt() 94 | 95 | images = [load_image(imgp)] 96 | 97 | images_tensor = process_images( 98 | images, 99 | image_processor, 100 | model.config 101 | ).to(model.device, dtype=torch.float16) 102 | 103 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(model.device) 104 | 105 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 106 | keywords = [stop_str] 107 | with torch.inference_mode(): 108 | output_ids = model.generate( 109 | input_ids, 110 | images=[images_tensor], 111 | #image_sizes=image_sizes, 112 | do_sample=False, 113 | temperature=args.temperature, 114 | num_beams=args.num_beams, 115 | max_new_tokens=args.max_new_tokens, 116 | use_cache=True, 117 | ) 118 | 119 | output = tokenizer.decode(output_ids[0][1:]).strip() 120 | print("output",output) 121 | output = output.strip() 122 | if output.endswith(stop_str): 123 | output = output[:-len(stop_str)] 124 | output = output.strip() 125 | 126 | scores = re.findall(r"-?\d+", output) 127 | if len(scores) > 0: 128 | score = int(scores[0]) 129 | else: 130 | score = 0 131 | 132 | ans_typed[typed] = score 133 | if typed not in sum_res: 134 | sum_res[typed] = [score] 135 | else: 136 | sum_res[typed].append(score) 137 | img_score.append(score) 138 | print(imgp, img_score) 139 | 140 | grained_dict = {} 141 | sums = [] 142 | sum_1 = [] 143 | 144 | print(sum_res) 145 | nums = [] 146 | for td, tdres in sum_res.items(): 147 | num_1 = int(img_nums - sum(np.array(tdres)==-1)) 148 | nums.append(num_1) 149 | #sums.append(sum(tdres)) 150 | value_no = [v if v!=-1 else 0 for v in tdres] 151 | sum_1.append(sum(value_no)) 152 | grained_dict[f"{td}_score"] = sum(value_no) 153 | grained_dict[f"{td}_num"] = num_1 154 | grained_dict[f"{td}_average"] = sum(value_no)/num_1 if num_1>0 else 0 155 | #print(td, sum(tdres)) 156 | #print(sum(sums)) 157 | #grained_dict["res"] = sum(sums) 158 | grained_dict["total_score"] = sum(sum_1) 159 | grained_dict["num"] = sum(nums) 160 | grained_dict["avg_score"] = sum(sum_1)/sum(nums) if sum(nums)>0 
else 0 161 | print(grained_dict) 162 | #save_res[classd] = grained_dict 163 | #save_dir = "/cpfs01/projects-HDD/cfff-4a8d9af84f66_HDD/yangxiaomeng/Code/AIGCbenchmark/benchmarkutils/inference_fidelity_model/prompt_pro/our_fidelity_test_result_fidelity_t2i" 164 | os.makedirs(args.output_dir,exist_ok=True) 165 | df = pd.DataFrame.from_dict(grained_dict,orient='index') 166 | df.to_excel(f"{args.output_dir}/result_test_faithfulness.xlsx") 167 | with open(f'{args.output_dir}/result_test_faithfulness.json', 'w', encoding='utf-8') as f: 168 | f.write(json.dumps(grained_dict,indent=2)) 169 | 170 | if __name__ == "__main__": 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("--model-path", type=str, default="sais/evalalign-v1.0-13b") 173 | parser.add_argument("--model-base", type=str, default=None) 174 | parser.add_argument("--images-dir", type=str,default=None) 175 | parser.add_argument("--output-dir", type=str,default=None) 176 | parser.add_argument("--temperature", type=float, default=0.2) 177 | parser.add_argument("--top_p", type=float, default=None) 178 | parser.add_argument("--num_beams", type=int, default=1) 179 | parser.add_argument("--max_new_tokens", type=int, default=512) 180 | args = parser.parse_args() 181 | 182 | eval_model(args) 183 | 184 | 185 | -------------------------------------------------------------------------------- /evalalign/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from evalalign.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 
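    Example (illustrative; sizes chosen arbitrarily): a 1000x500 input scaled to a (336, 336) target becomes 336x168 and is centred on a black canvas.
        >>> from PIL import Image
        >>> resize_and_pad_image(Image.new("RGB", (1000, 500)), (336, 336)).size
        (336, 336)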
52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 
130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = "pad" #getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | new_images.append(image) 178 | else: 179 | return image_processor(images, return_tensors='pt')['pixel_values'] 180 | if all(x.shape == new_images[0].shape for x in new_images): 181 | new_images = torch.stack(new_images, dim=0) 182 | return new_images 183 | 184 | 185 | 186 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None,local_rank=None): 187 | def rank0_print(*args): 188 | if local_rank == 0: 189 | print(*args) 190 | #rank0_print(prompt) 191 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 192 | #rank0_print(prompt_chunks) 193 | def insert_separator(X, sep): 194 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 195 | 196 | input_ids = [] 197 | offset = 0 198 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 199 | offset = 1 200 | input_ids.append(prompt_chunks[0][0]) 201 | 202 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 203 | input_ids.extend(x[offset:]) 204 | 205 | if return_tensors is not None: 206 | if return_tensors == 'pt': 207 | return torch.tensor(input_ids, dtype=torch.long) 208 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 209 | return input_ids 210 | 211 | 212 | def get_model_name_from_path(model_path): 213 | model_path = model_path.strip("/") 214 | model_paths = model_path.split("/") 215 | if model_paths[-1].startswith('checkpoint-'): 216 | return model_paths[-2] + "_" + model_paths[-1] 217 | else: 218 | return 
model_paths[-1] 219 | 220 | class KeywordsStoppingCriteria(StoppingCriteria): 221 | def __init__(self, keywords, tokenizer, input_ids): 222 | self.keywords = keywords 223 | self.keyword_ids = [] 224 | self.max_keyword_len = 0 225 | for keyword in keywords: 226 | cur_keyword_ids = tokenizer(keyword).input_ids 227 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 228 | cur_keyword_ids = cur_keyword_ids[1:] 229 | if len(cur_keyword_ids) > self.max_keyword_len: 230 | self.max_keyword_len = len(cur_keyword_ids) 231 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 232 | self.tokenizer = tokenizer 233 | self.start_len = input_ids.shape[1] 234 | 235 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 236 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 237 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 238 | for keyword_id in self.keyword_ids: 239 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 240 | if torch.equal(truncated_output_ids, keyword_id): 241 | return True 242 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 243 | for keyword in self.keywords: 244 | if keyword in outputs: 245 | return True 246 | return False 247 | 248 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 249 | outputs = [] 250 | for i in range(output_ids.shape[0]): 251 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 252 | return all(outputs) 253 | -------------------------------------------------------------------------------- /evalalign/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | except: 6 | pass 7 | -------------------------------------------------------------------------------- /evalalign/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from evalalign import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | 
f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /evalalign/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from evalalign.model import * 23 | from evalalign.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name=None, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if use_flash_attn: 46 | kwargs['attn_implementation'] = 'flash_attention_2' 47 | 48 | 49 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 50 | model = LlavaLlamaForCausalLM.from_pretrained( 51 | model_path, 52 | low_cpu_mem_usage=True, 53 | **kwargs 54 | ) 55 | 56 | image_processor = None 57 | 58 | #if 'evalalign' in model_path.lower(): 59 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 60 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 61 | if mm_use_im_patch_token: 62 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 63 | if mm_use_im_start_end: 64 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 65 | model.resize_token_embeddings(len(tokenizer)) 66 | 67 | vision_tower = model.get_vision_tower() 68 | if not vision_tower.is_loaded: 69 | vision_tower.load_model(device_map=device_map) 70 | if device_map != 'auto': 71 | 
vision_tower.to(device=device_map, dtype=torch.float16) 72 | image_processor = vision_tower.image_processor 73 | 74 | if hasattr(model.config, "max_sequence_length"): 75 | context_len = model.config.max_sequence_length 76 | else: 77 | context_len = 2048 78 | 79 | return tokenizer, model, image_processor, context_len 80 | -------------------------------------------------------------------------------- /evalalign/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from evalalign.model import * 10 | from evalalign.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /evalalign/model/language_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SAIS-FUXI/EvalAlign/d3fa7b49bd87e673d2a941d60c99b1d2c564205d/evalalign/model/language_model/__init__.py -------------------------------------------------------------------------------- /evalalign/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
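# LlavaLlamaForCausalLM couples the LLaMA backbone with the vision tower and projector built in llava_arch.py: generate() first turns the image placeholder tokens into projected vision features via prepare_inputs_labels_for_multimodal, then hands the resulting inputs_embeds to the standard Hugging Face generation loop.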
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | #print("input_ids",input_ids.size()) 91 | #print("inputs_embeds",inputs_embeds.size()) 92 | return super().forward( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | position_ids=position_ids, 96 | past_key_values=past_key_values, 97 | inputs_embeds=inputs_embeds, 98 | labels=labels, 99 | use_cache=use_cache, 100 | output_attentions=output_attentions, 101 | output_hidden_states=output_hidden_states, 102 | return_dict=return_dict 103 | ) 104 | 105 | @torch.no_grad() 106 | def generate( 107 | self, 108 | inputs: Optional[torch.Tensor] = None, 109 | images: Optional[torch.Tensor] = None, 110 | image_sizes: Optional[torch.Tensor] = None, 111 | **kwargs, 112 | ) -> Union[GenerateOutput, torch.LongTensor]: 113 | position_ids = kwargs.pop("position_ids", None) 114 | attention_mask = kwargs.pop("attention_mask", None) 115 | if "inputs_embeds" in kwargs: 116 | raise NotImplementedError("`inputs_embeds` is not supported") 117 | 118 | if images is not None: 119 | ( 120 | inputs, 121 | position_ids, 122 | attention_mask, 123 | _, 124 | inputs_embeds, 125 | _ 126 | ) = self.prepare_inputs_labels_for_multimodal( 127 | inputs, 128 | position_ids, 129 | attention_mask, 130 | None, 131 | None, 132 | images, 133 | image_sizes=image_sizes 134 | ) 135 
| else: 136 | raise ValueError("images is not None") 137 | inputs_embeds = self.get_model().embed_tokens(inputs) 138 | 139 | return super().generate( 140 | position_ids=position_ids, 141 | attention_mask=attention_mask, 142 | inputs_embeds=inputs_embeds, 143 | **kwargs 144 | ) 145 | 146 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 147 | inputs_embeds=None, **kwargs): 148 | images = kwargs.pop("images", None) 149 | image_sizes = kwargs.pop("image_sizes", None) 150 | inputs = super().prepare_inputs_for_generation( 151 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 152 | ) 153 | if images is not None: 154 | inputs['images'] = images 155 | if image_sizes is not None: 156 | inputs['image_sizes'] = image_sizes 157 | return inputs 158 | 159 | AutoConfig.register("llava_llama", LlavaConfig) 160 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 161 | -------------------------------------------------------------------------------- /evalalign/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: 
Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /evalalign/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
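# This MPT variant mirrors the LLaMA and Mistral wrappers above, swapping in MptModel/MptForCausalLM as the language backbone.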
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from evalalign.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /evalalign/model/llava_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | import torch.nn as nn 20 | import safetensors.torch 21 | from .multimodal_encoder.builder import build_vision_tower 22 | from .multimodal_projector.builder import build_vision_projector 23 | 24 | from evalalign.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | 26 | from evalalign.mm_utils import get_anyres_image_grid_shape 27 | 28 | 29 | class LlavaMetaModel: 30 | 31 | def __init__(self, config): 32 | super(LlavaMetaModel, self).__init__(config) 33 | 34 | if hasattr(config, "mm_vision_tower"): 35 | self.vision_tower = build_vision_tower(config, delay_load=True) 36 | self.mm_projector = build_vision_projector(config) 37 | 38 | if 'unpad' in getattr(config, 'mm_patch_merge_type', ''): 39 | self.image_newline = nn.Parameter( 40 | torch.empty(config.hidden_size, dtype=self.dtype) 41 | ) 42 | 43 | def get_vision_tower(self): 44 | vision_tower = getattr(self, 'vision_tower', None) 45 | if type(vision_tower) is list: 46 | vision_tower = vision_tower[0] 47 | return vision_tower 48 | 49 | def initialize_vision_modules(self, model_args, fsdp=None): 50 | vision_tower = model_args.vision_tower 51 | mm_vision_select_layer = model_args.mm_vision_select_layer 52 | mm_vision_select_feature = model_args.mm_vision_select_feature 53 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter 54 | mm_patch_merge_type = model_args.mm_patch_merge_type 55 | 56 | self.config.mm_vision_tower = vision_tower 57 | 58 | if self.get_vision_tower() is None: 59 | vision_tower = build_vision_tower(model_args) 60 | 61 | if fsdp is not None and len(fsdp) > 0: 62 | self.vision_tower = [vision_tower] 63 | else: 64 | self.vision_tower = vision_tower 65 | else: 66 | if fsdp is not None and len(fsdp) > 0: 67 | vision_tower = self.vision_tower[0] 68 | else: 69 | vision_tower = self.vision_tower 70 | vision_tower.load_model() 71 | 72 | self.config.use_mm_proj = True 73 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') 74 | self.config.mm_hidden_size = vision_tower.hidden_size 75 | self.config.mm_vision_select_layer = mm_vision_select_layer 76 | self.config.mm_vision_select_feature = mm_vision_select_feature 77 | self.config.mm_patch_merge_type = mm_patch_merge_type 78 | 79 | if getattr(self, 'mm_projector', None) is None: 80 | self.mm_projector = build_vision_projector(self.config) 81 | 82 | if 'unpad' in mm_patch_merge_type: 83 | embed_std = 1 / torch.sqrt(torch.tensor(self.config.hidden_size, dtype=self.dtype)) 84 | self.image_newline = nn.Parameter( 85 | torch.randn(self.config.hidden_size, dtype=self.dtype) * embed_std 86 | ) 87 | # else: 88 | # # In case it is frozen by LoRA 89 | # for p in self.mm_projector.parameters(): 90 | # p.requires_grad = False 91 | # if self.config.lora_mmproject: 92 | # llava16_mm = self.config.mmproject_path #"/cpfs01/projects-HDD/cfff-4a8d9af84f66_HDD/yangxiaomeng/pretrain_models/multimodal/llava-v1.6-vicuna-13b/model-00006-of-00006.safetensors"#config.mmproject_path 93 | # llava16_mm_weights = safetensors.torch.load_file(llava16_mm, device='cpu') 94 | 95 | # mm_state_dict = {} 96 | # for llava16_name, weights in llava16_mm_weights.items(): 97 | # if "mm_projector" in llava16_name: 98 | # mm_state_dict["linear"+llava16_name.replace("model.mm_projector.","")] = weights 99 | # 
self.mm_projector.load_state_dict(mm_state_dict) 100 | 101 | if pretrain_mm_mlp_adapter is not None: 102 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 103 | def get_w(weights, keyword): 104 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 105 | 106 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) 107 | 108 | 109 | def unpad_image(tensor, original_size): 110 | """ 111 | Unpads a PyTorch tensor of a padded and resized image. 112 | 113 | Args: 114 | tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format. 115 | original_size (tuple): The original size of PIL image (width, height). 116 | 117 | Returns: 118 | torch.Tensor: The unpadded image tensor. 119 | """ 120 | original_width, original_height = original_size 121 | current_height, current_width = tensor.shape[1:] 122 | 123 | original_aspect_ratio = original_width / original_height 124 | current_aspect_ratio = current_width / current_height 125 | 126 | if original_aspect_ratio > current_aspect_ratio: 127 | scale_factor = current_width / original_width 128 | new_height = int(original_height * scale_factor) 129 | padding = (current_height - new_height) // 2 130 | unpadded_tensor = tensor[:, padding:current_height - padding, :] 131 | else: 132 | scale_factor = current_height / original_height 133 | new_width = int(original_width * scale_factor) 134 | padding = (current_width - new_width) // 2 135 | unpadded_tensor = tensor[:, :, padding:current_width - padding] 136 | 137 | return unpadded_tensor 138 | 139 | 140 | class LlavaMetaForCausalLM(ABC): 141 | 142 | @abstractmethod 143 | def get_model(self): 144 | pass 145 | 146 | def get_vision_tower(self): 147 | return self.get_model().get_vision_tower() 148 | 149 | def encode_images(self, images): 150 | image_features = self.get_model().get_vision_tower()(images) 151 | image_features = self.get_model().mm_projector(image_features) 152 | return image_features 153 | 154 | def prepare_inputs_labels_for_multimodal( 155 | self, input_ids, position_ids, attention_mask, past_key_values, labels, 156 | images, image_sizes=None 157 | ): 158 | vision_tower = self.get_vision_tower() 159 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 160 | #print(vision_tower,images,input_ids) 161 | return input_ids, position_ids, attention_mask, past_key_values, None, labels 162 | 163 | if type(images) is list or images.ndim == 5: 164 | if type(images) is list: 165 | images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images] 166 | concat_images = torch.cat([image for image in images], dim=0) 167 | #print("concat_images",concat_images.size()) 168 | image_features = self.encode_images(concat_images) 169 | split_sizes = [image.shape[0] for image in images] 170 | #print(split_sizes) 171 | #print("image_features_0",image_features.size()) 172 | image_features = torch.split(image_features, split_sizes, dim=0) 173 | #print("image_features_1",len(image_features)) 174 | mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type', 'flat') 175 | image_aspect_ratio = getattr(self.config, 'image_aspect_ratio', 'square') 176 | #print("image_features_1",image_features[0].size()) 177 | #raise() 178 | if mm_patch_merge_type == 'flat': 179 | image_features = [x.flatten(0, 1) for x in image_features] 180 | elif mm_patch_merge_type.startswith('spatial'): 181 | new_image_features = [] 182 | for image_idx, image_feature in enumerate(image_features): 183 | if image_feature.shape[0] > 1: 184 | 
base_image_feature = image_feature[0] 185 | image_feature = image_feature[1:] 186 | height = width = self.get_vision_tower().num_patches_per_side 187 | assert height * width == base_image_feature.shape[0] 188 | if image_aspect_ratio == 'anyres': 189 | print("image_features_2",image_feature.size()) 190 | num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[image_idx], self.config.image_grid_pinpoints, self.get_vision_tower().config.image_size) 191 | image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) 192 | else: 193 | raise NotImplementedError 194 | print("image_features_3",image_feature.size()) 195 | if 'unpad' in mm_patch_merge_type: 196 | image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() 197 | print("image_features_4",image_feature.size()) 198 | image_feature = image_feature.flatten(1, 2).flatten(2, 3) 199 | print("image_features_5",image_feature.size()) 200 | image_feature = unpad_image(image_feature, image_sizes[image_idx]) 201 | print("image_features_6",image_feature.size()) 202 | image_feature = torch.cat(( 203 | image_feature, 204 | self.model.image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.device) 205 | ), dim=-1) 206 | print("image_features_7",image_feature.size()) 207 | image_feature = image_feature.flatten(1, 2).transpose(0, 1) 208 | print("image_features_8",image_feature.size()) 209 | else: 210 | image_feature = image_feature.permute(0, 2, 1, 3, 4).contiguous() 211 | image_feature = image_feature.flatten(0, 3) 212 | image_feature = torch.cat((base_image_feature, image_feature), dim=0) 213 | print("image_features_9",image_feature.size()) 214 | raise() 215 | else: 216 | #print("++++++++++++++++++++++") 217 | #raise() 218 | image_feature = image_feature[0] 219 | if 'unpad' in mm_patch_merge_type: 220 | image_feature = torch.cat(( 221 | image_feature, 222 | self.model.image_newline[None].to(image_feature.device) 223 | ), dim=0) 224 | new_image_features.append(image_feature) 225 | image_features = new_image_features 226 | else: 227 | raise ValueError(f"Unexpected mm_patch_merge_type: {self.config.mm_patch_merge_type}") 228 | else: 229 | #print("images",images.size()) 230 | image_features = self.encode_images(images).to(self.device) 231 | #print("image_features",image_features.size()) 232 | 233 | # TODO: image start / end is not implemented here to support pretraining. 234 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 235 | raise NotImplementedError 236 | 237 | # Let's just add dummy tensors if they do not exist, 238 | # it is a headache to deal with None all the time. 239 | # But it is not ideal, and if you have a better idea, 240 | # please open an issue / submit a PR, thanks. 
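        # In the loop below, each IMAGE_TOKEN_INDEX placeholder is expanded in place: the surrounding text chunks are embedded with embed_tokens, the projected image features are spliced in between them, and the matching label positions are filled with IGNORE_INDEX so image tokens never contribute to the loss; afterwards every sequence is re-padded to the longest one in the batch.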
241 | _labels = labels 242 | _position_ids = position_ids 243 | _attention_mask = attention_mask 244 | if attention_mask is None: 245 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 246 | else: 247 | attention_mask = attention_mask.bool() 248 | if position_ids is None: 249 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) 250 | if labels is None: 251 | labels = torch.full_like(input_ids, IGNORE_INDEX) 252 | 253 | # remove the padding using attention_mask -- FIXME 254 | _input_ids = input_ids 255 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)] 256 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] 257 | 258 | new_input_embeds = [] 259 | new_labels = [] 260 | cur_image_idx = 0 261 | #print("input_ids",len(input_ids)) 262 | for batch_idx, cur_input_ids in enumerate(input_ids): 263 | num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() 264 | if num_images == 0: 265 | cur_image_features = image_features[cur_image_idx] 266 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) 267 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) 268 | new_input_embeds.append(cur_input_embeds) 269 | new_labels.append(labels[batch_idx]) 270 | cur_image_idx += 1 271 | continue 272 | 273 | image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]] 274 | cur_input_ids_noim = [] 275 | cur_labels = labels[batch_idx] 276 | cur_labels_noim = [] 277 | for i in range(len(image_token_indices) - 1): 278 | cur_input_ids_noim.append(cur_input_ids[image_token_indices[i]+1:image_token_indices[i+1]]) 279 | cur_labels_noim.append(cur_labels[image_token_indices[i]+1:image_token_indices[i+1]]) 280 | split_sizes = [x.shape[0] for x in cur_labels_noim] 281 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) 282 | #print("cur_input_embeds",cur_input_embeds.size()) 283 | cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) 284 | #print("cur_input_embeds_no_im",len(cur_input_embeds_no_im)) 285 | cur_new_input_embeds = [] 286 | cur_new_labels = [] 287 | 288 | for i in range(num_images + 1): 289 | cur_new_input_embeds.append(cur_input_embeds_no_im[i]) 290 | cur_new_labels.append(cur_labels_noim[i]) 291 | if i < num_images: 292 | cur_image_features = image_features[cur_image_idx] 293 | cur_image_idx += 1 294 | cur_new_input_embeds.append(cur_image_features) 295 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype)) 296 | 297 | cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds] 298 | 299 | cur_new_input_embeds = torch.cat(cur_new_input_embeds) 300 | cur_new_labels = torch.cat(cur_new_labels) 301 | 302 | #print("cur_new_input_embeds",cur_new_input_embeds.size()) 303 | #print("cur_new_labels",cur_new_labels.size()) 304 | 305 | new_input_embeds.append(cur_new_input_embeds) 306 | new_labels.append(cur_new_labels) 307 | #raise() 308 | 309 | # Truncate sequences to max length as image embeddings can make the sequence longer 310 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) 311 | if tokenizer_model_max_length is not None: 312 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] 313 | new_labels = [x[:tokenizer_model_max_length] for x 
in new_labels] 314 | 315 | # Combine them 316 | max_len = max(x.shape[0] for x in new_input_embeds) 317 | #print("max_len",max_len) 318 | batch_size = len(new_input_embeds) 319 | 320 | new_input_embeds_padded = [] 321 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device) 322 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) 323 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) 324 | 325 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): 326 | cur_len = cur_new_embed.shape[0] 327 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": 328 | new_input_embeds_padded.append(torch.cat(( 329 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device), 330 | cur_new_embed 331 | ), dim=0)) 332 | if cur_len > 0: 333 | new_labels_padded[i, -cur_len:] = cur_new_labels 334 | attention_mask[i, -cur_len:] = True 335 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 336 | else: 337 | new_input_embeds_padded.append(torch.cat(( 338 | cur_new_embed, 339 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device) 340 | ), dim=0)) 341 | if cur_len > 0: 342 | new_labels_padded[i, :cur_len] = cur_new_labels 343 | attention_mask[i, :cur_len] = True 344 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device) 345 | 346 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) 347 | 348 | if _labels is None: 349 | new_labels = None 350 | else: 351 | new_labels = new_labels_padded 352 | 353 | if _attention_mask is None: 354 | attention_mask = None 355 | else: 356 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype) 357 | 358 | if _position_ids is None: 359 | position_ids = None 360 | 361 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels 362 | 363 | def initialize_vision_tokenizer(self, model_args, tokenizer): 364 | if model_args.mm_use_im_patch_token: 365 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 366 | self.resize_token_embeddings(len(tokenizer)) 367 | 368 | if model_args.mm_use_im_start_end: 369 | num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 370 | self.resize_token_embeddings(len(tokenizer)) 371 | 372 | if num_new_tokens > 0: 373 | input_embeddings = self.get_input_embeddings().weight.data 374 | output_embeddings = self.get_output_embeddings().weight.data 375 | 376 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 377 | dim=0, keepdim=True) 378 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 379 | dim=0, keepdim=True) 380 | 381 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 382 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 383 | 384 | if model_args.tune_mm_mlp_adapter: 385 | for p in self.get_input_embeddings().parameters(): 386 | p.requires_grad = True 387 | for p in self.get_output_embeddings().parameters(): 388 | p.requires_grad = False 389 | 390 | if model_args.pretrain_mm_mlp_adapter: 391 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') 392 | embed_tokens_weight = 
mm_projector_weights['model.embed_tokens.weight'] 393 | assert num_new_tokens == 2 394 | if input_embeddings.shape == embed_tokens_weight.shape: 395 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] 396 | elif embed_tokens_weight.shape[0] == num_new_tokens: 397 | input_embeddings[-num_new_tokens:] = embed_tokens_weight 398 | else: 399 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Number of new tokens: {num_new_tokens}.") 400 | elif model_args.mm_use_im_patch_token: 401 | if model_args.tune_mm_mlp_adapter: 402 | for p in self.get_input_embeddings().parameters(): 403 | p.requires_grad = False 404 | for p in self.get_output_embeddings().parameters(): 405 | p.requires_grad = False 406 | -------------------------------------------------------------------------------- /evalalign/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m evalalign.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from evalalign.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /evalalign/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from
.clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | use_s2 = getattr(vision_tower_cfg, 'use_s2', False) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 10 | if use_s2: 11 | print("---------CLIPVisionTowerS2---------") 12 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 13 | else: 14 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 15 | 16 | raise ValueError(f'Unknown vision tower: {vision_tower}') 17 | -------------------------------------------------------------------------------- /evalalign/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 20 | self.load_model() 21 | else: 22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | 24 | def load_model(self, device_map=None): 25 | if self.is_loaded: 26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 27 | return 28 | 29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 30 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 31 | self.vision_tower.requires_grad_(False) 32 | 33 | self.is_loaded = True 34 | 35 | def feature_select(self, image_forward_outs): 36 | image_features = image_forward_outs.hidden_states[self.select_layer] 37 | if self.select_feature == 'patch': 38 | image_features = image_features[:, 1:] 39 | elif self.select_feature == 'cls_patch': 40 | image_features = image_features 41 | else: 42 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 43 | return image_features 44 | 45 | @torch.no_grad() 46 | def forward(self, images): 47 | if type(images) is list: 48 | image_features = [] 49 | for image in images: 50 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 51 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 52 | image_features.append(image_feature) 53 | else: 54 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 55 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 56 | 57 | return image_features 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.device 70 | 71 | @property 72 | def config(self): 73 | if 
self.is_loaded: 74 | return self.vision_tower.config 75 | else: 76 | return self.cfg_only 77 | 78 | @property 79 | def hidden_size(self): 80 | return self.config.hidden_size 81 | 82 | @property 83 | def num_patches_per_side(self): 84 | return self.config.image_size // self.config.patch_size 85 | 86 | @property 87 | def num_patches(self): 88 | return (self.config.image_size // self.config.patch_size) ** 2 89 | 90 | 91 | 92 | class CLIPVisionTowerS2(CLIPVisionTower): 93 | def __init__(self, vision_tower, args, delay_load=False): 94 | super().__init__(vision_tower, args, delay_load) 95 | 96 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008') 97 | self.s2_scales = list(map(int, self.s2_scales.split(','))) 98 | self.s2_scales.sort() 99 | self.s2_split_size = self.s2_scales[0] # 336 100 | self.s2_image_size = self.s2_scales[-1] #1008 101 | 102 | self.image_processor.size['shortest_edge'] = self.s2_image_size 103 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 104 | 105 | try: 106 | from evalalign.model.s2wrapper import forward as multiscale_forward 107 | except ImportError: 108 | raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') 109 | self.multiscale_forward = multiscale_forward 110 | 111 | if not delay_load: 112 | self.load_model() 113 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 114 | self.load_model() 115 | else: 116 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 117 | 118 | 119 | # change resize/crop size in preprocessing to the largest image size in s2_scale 120 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): 121 | self.image_processor.size['shortest_edge'] = self.s2_image_size 122 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 123 | 124 | def load_model(self, device_map=None): 125 | if self.is_loaded: 126 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 127 | return 128 | 129 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 130 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 131 | self.vision_tower.requires_grad_(False) 132 | 133 | self.is_loaded = True 134 | 135 | @torch.no_grad() 136 | def forward_feature(self, images): 137 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 138 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 139 | return image_features 140 | 141 | @torch.no_grad() 142 | def forward(self, images): 143 | if type(images) is list: 144 | image_features = [] 145 | for image in images: 146 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 147 | image_features.append(image_feature) 148 | else: 149 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 150 | 151 | return image_features 152 | 153 | @property 154 | def hidden_size(self): 155 | return self.config.hidden_size * len(self.s2_scales) 156 | -------------------------------------------------------------------------------- /evalalign/model/multimodal_projector/builder.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | class mmprojector(nn.Module): 33 | def __init__(self, config): 34 | super().__init__() 35 | 36 | self.linear0 = nn.Linear(config.mm_hidden_size, config.hidden_size) 37 | self.gelu = nn.GELU() 38 | self.linear2 = nn.Linear(config.hidden_size, config.hidden_size) 39 | 40 | def forward(self, x): 41 | return self.linear2(self.gelu(self.linear0(x))) 42 | 43 | def build_vision_projector(config, delay_load=False, **kwargs): 44 | projector_type = getattr(config, 'mm_projector_type', 'linear') 45 | 46 | if projector_type == 'linear': 47 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 48 | 49 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 50 | if mlp_gelu_match: 51 | mlp_depth = int(mlp_gelu_match.group(1)) 52 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 53 | for _ in range(1, mlp_depth): 54 | modules.append(nn.GELU()) 55 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 56 | return nn.Sequential(*modules) 57 | 58 | if projector_type == 'identity': 59 | return IdentityMap() 60 | 61 | raise ValueError(f'Unknown projector type: {projector_type}') 62 | -------------------------------------------------------------------------------- /evalalign/model/s2wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import * 2 | from .utils import * -------------------------------------------------------------------------------- /evalalign/model/s2wrapper/core.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from .utils import split_chessboard, merge_chessboard, batched_forward 7 | 8 | def forward(model, input, scales=None, img_sizes=None, max_split_size=None, resize_output_to_idx=0, num_prefix_token=0, 9 | output_shape='bnc', split_forward=False): 10 | 11 | assert input.dim() == 4, "Input image must be in the shape of BxCxHxW." 12 | assert input.shape[2] == input.shape[3], "Currently only square images are supported." 13 | assert output_shape in ['bnc', 'bchw'], "Output shape should be either BxNxC (e.g., ViT) or BxCxHxW (e.g., ConvNet)." 14 | assert output_shape == 'bnc' or num_prefix_token == 0, "For ConvNet there shouldn't be any prefix token." 15 | 16 | b, c, input_size, _ = input.shape 17 | 18 | # image size for each scale 19 | assert scales is not None or img_sizes is not None, "Please assign either scales or img_sizes." 20 | img_sizes = img_sizes or [int(input_size * scale) for scale in scales] 21 | 22 | # prepare multiscale inputs 23 | max_split_size = max_split_size or input_size # The maximum size of each split of image. 
Set as the input size by default 24 | num_splits = [math.ceil(size / max_split_size) for size in img_sizes] # number of splits each scale 25 | input_multiscale = [] 26 | for size, num_split in zip(img_sizes, num_splits): 27 | x = F.interpolate(input.to(torch.float32), size=size, mode='bicubic').to(input.dtype) 28 | #print("x",x.size()) 29 | x = split_chessboard(x, num_split=num_split) 30 | #print("x_1",x.size()) 31 | input_multiscale.append(x) 32 | 33 | # run feedforward on each scale 34 | outs_multiscale = [batched_forward(model, x, b) if split_forward else model(x) for x in input_multiscale] 35 | #print("outs_multiscale",len(outs_multiscale),outs_multiscale[0].size(),outs_multiscale[1].size(),outs_multiscale[2].size()) 36 | if num_prefix_token > 0: 37 | outs_prefix_multiscale = [out[:, :num_prefix_token] for out in outs_multiscale] 38 | outs_multiscale = [out[:, num_prefix_token:] for out in outs_multiscale] 39 | if output_shape == 'bnc': 40 | outs_multiscale = [rearrange(out, 'b (h w) c -> b c h w', h=int(out.shape[1] ** 0.5), w=int(out.shape[1] ** 0.5)) 41 | for out in outs_multiscale] 42 | 43 | #print("outs_multiscale_1",len(outs_multiscale),outs_multiscale[0].size(),outs_multiscale[1].size(),outs_multiscale[2].size()) 44 | # merge outputs of different splits for each scale separately 45 | outs_multiscale = [merge_chessboard(out, num_split=num_split) for num_split, out in zip(num_splits, outs_multiscale)] 46 | #print("outs_multiscale_2",len(outs_multiscale),outs_multiscale[0].size(),outs_multiscale[1].size(),outs_multiscale[2].size()) 47 | # interpolate outputs from different scales and concat together 48 | output_size = outs_multiscale[resize_output_to_idx].shape[-2] 49 | #print("output_size",output_size) 50 | out = torch.cat([F.interpolate(outs_multiscale[i].to(torch.float32), size=output_size, 51 | mode='area').to(outs_multiscale[i].dtype).unsqueeze(1) 52 | for i in range(len(outs_multiscale))], dim=1) 53 | #print("out",out.size()) 54 | #raise() 55 | if output_shape == 'bnc': 56 | out = rearrange(out, 'b k c h w -> b (k h w) c') 57 | #out = rearrange(out, 'b c h w -> b (h w) c') 58 | if num_prefix_token > 0: 59 | # take the mean of prefix tokens from different splits for each scale 60 | outs_prefix_multiscale = [torch.stack(out.split(b, dim=0), dim=0).mean(dim=0) for out in outs_prefix_multiscale] 61 | out_prefix_multiscale = torch.cat(outs_prefix_multiscale, dim=-1) 62 | out = torch.cat([out_prefix_multiscale, out], dim=1) 63 | 64 | return out 65 | -------------------------------------------------------------------------------- /evalalign/model/s2wrapper/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def split_chessboard(x, num_split): 4 | """ 5 | x: b * c * h * w 6 | Dividing x into num_split**2 sub-squares and concatenating all the sub-squares along the batch dimension 7 | """ 8 | B, C, H, W = x.shape 9 | assert H % num_split == 0 and W % num_split == 0 10 | h, w = H // num_split, W // num_split 11 | x_split = torch.cat([x[:, :, i*h:(i+1)*h, j*w:(j+1)*w] for i in range(num_split) for j in range(num_split)], dim=0) 12 | return x_split 13 | 14 | def merge_chessboard(x, num_split): 15 | """ 16 | x: b * c * h * w 17 | Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square.
18 | (inverse of split_chessboard) 19 | """ 20 | B, C, H, W = x.shape 21 | assert B % (num_split**2) == 0 22 | b = B // (num_split**2) 23 | x_merge = torch.cat([torch.cat([x[(i*num_split + j)*b:(i*num_split + j + 1)*b] for j in range(num_split)], dim=-1) 24 | for i in range(num_split)], dim=-2) 25 | return x_merge 26 | 27 | def batched_forward(model, x, batch_size=-1): 28 | if batch_size == -1: 29 | return model(x) 30 | else: 31 | x_batched = x.split(batch_size) 32 | outs = [model(x) for x in x_batched] 33 | return torch.cat(outs, dim=0) 34 | 35 | -------------------------------------------------------------------------------- /evalalign/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /evalalign/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from evalalign.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "evalalign" 7 | version = "1.0.0" 8 | description = "EVALALIGN: Supervised Fine-Tuning Multimodal LLMs with Human-Aligned Data for Evaluating Text-to-Image Models" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | 16 | dependencies = [ 17 | "torch==2.1.2", "torchvision==0.16.2", "openpyxl", 18 | "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid", 19 | "accelerate==0.21.0", "peft", "bitsandbytes", 20 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 21 | "gradio==4.16.0", "gradio_client==0.8.1", 22 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 23 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 24 | ] 25 | 26 | [project.optional-dependencies] 27 | train = ["deepspeed==0.12.6", "ninja", "wandb"] 28 | build = ["build", "twine"] 29 | 30 | [project.urls] 31 | "Bug Tracker" = "https://github.com/SAIS-FUXI/EvalAlign/issues" 32 | 33 | [tool.setuptools.packages.find] 34 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 35 | 36 | [tool.wheel] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | -------------------------------------------------------------------------------- /scripts/inference_alignment.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python evalalign/eval/test_alignment.py \ 2 | --model-path Fudan-FUXI/evalalign-v1.0-13b \ 3 | --images-dir ./IF-I-XL-v1.0 \ 4 | --output-dir ./results_alignment -------------------------------------------------------------------------------- /scripts/inference_faithfulness.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python evalalign/eval/test_faithfulness.py \ 2 | --model-path Fudan-FUXI/evalalign-v1.0-13b \ 3 | --images-dir ./PixArt-XL-2-1024-MS \ 4 | --output-dir ./results_faithfulness --------------------------------------------------------------------------------
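For reference, split_chessboard and merge_chessboard in evalalign/model/s2wrapper/utils.py are documented above as inverses of each other, and the multiscale forward in evalalign/model/s2wrapper/core.py relies on that round trip when it tiles each scaled image and later reassembles the per-tile features. The standalone sketch below is illustrative only, not a file in this repository, and assumes the evalalign package and its torch dependency are installed.

import torch

from evalalign.model.s2wrapper.utils import merge_chessboard, split_chessboard

# Toy tensor: B x C x H x W, with H and W divisible by num_split.
x = torch.randn(2, 3, 8, 8)

# split_chessboard stacks the num_split**2 sub-squares along the batch dimension:
# (2, 3, 8, 8) -> (2 * 4, 3, 4, 4).
tiles = split_chessboard(x, num_split=2)

# merge_chessboard reassembles the sub-squares into the original spatial layout.
restored = merge_chessboard(tiles, num_split=2)

assert restored.shape == x.shape
assert torch.equal(restored, x)  # exact round trip; only slicing and concatenation are involved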
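build_vision_projector in evalalign/model/multimodal_projector/builder.py selects the projector purely from the mm_projector_type string ('linear', 'mlpNx_gelu', or 'identity'). The sketch below is illustrative only and not a repository file; the SimpleNamespace config and its sizes are made-up stand-ins for values that normally come from the pretrained model config, and the import assumes the package and its dependencies are installed.

from types import SimpleNamespace

from evalalign.model.multimodal_projector.builder import build_vision_projector

# Hypothetical sizes, used only for this example.
cfg = SimpleNamespace(
    mm_hidden_size=1024,            # vision feature width (stand-in value)
    hidden_size=4096,               # language model hidden width (stand-in value)
    mm_projector_type="mlp2x_gelu",
)

projector = build_vision_projector(cfg)
# "mlp2x_gelu" matches the mlp(\d+)x_gelu pattern with depth 2, so the result is
# nn.Sequential(Linear(1024, 4096), GELU(), Linear(4096, 4096)).
print(projector)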