├── .gitignore ├── LICENSE ├── README.md ├── README_zh.md ├── asset ├── WeChat.png └── method.png ├── docs ├── Data.md ├── Evaluation.md └── Model_zoo.md ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── data │ ├── __init__.py │ └── data.py ├── eval │ ├── evaluate_grounding.py │ └── run.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── language_model │ │ └── llava_llama.py │ ├── llava_arch.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip_encoder.py │ │ ├── convnext_encoder.py │ │ ├── lknet_encoder.py │ │ ├── siglip_encoder.py │ │ └── unireplknet │ │ │ ├── __init__.py │ │ │ └── unireplknet_encoder.py │ └── multimodal_projector │ │ └── builder.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ └── test_message.py ├── train │ ├── llava_trainer.py │ └── train.py └── utils.py ├── pyproject.toml └── scripts ├── eval-lmms.sh ├── evaluation.sh ├── refcoco.sh ├── stage_1.sh ├── stage_2.sh ├── stage_3.sh ├── zero2.json ├── zero3.json └── zero3_offload.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # Images 156 | images/ 157 | 158 | *.tar -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# ConvLLaVA: Hierarchical Backbones as Visual Encoder for Large Multimodal Models

4 | 5 | [Chunjiang Ge](https://john-ge.github.io/), [Sijie Cheng](https://adacheng.github.io/), Ziming Wang, Jiale Yuan, Yuan Gao 6 | 7 | Jun Song, Shiji Song, [Gao Huang](https://www.gaohuang.net/), Bo Zheng 8 | 9 |


34 | 35 | [ English | 中文 ] 36 | 37 | ## Abstract 38 | 39 | High-resolution Large Multimodal Models (LMMs) face the challenges of excessive visual tokens and quadratic visual complexity. Current high-resolution LMMs address the quadratic complexity but still generate excessive visual tokens. **However, the redundancy in visual tokens is the key problem, as it leads to substantially more compute.** To mitigate this, we propose ConvLLaVA, which employs ConvNeXt, a hierarchical backbone, as the visual encoder of the LMM in place of the Vision Transformer (ViT). **ConvLLaVA compresses high-resolution images into information-rich visual features, effectively avoiding the generation of excessive visual tokens.** To enhance the capabilities of ConvLLaVA, we propose two critical optimizations. 40 | 41 | - Since the low-resolution pretrained ConvNeXt underperforms when directly applied to high-resolution inputs, we **update** it to bridge the gap. 42 | - Furthermore, since ConvNeXt's original compression ratio is insufficient for much higher-resolution inputs, we train a **successive stage** to further compress the visual tokens, effectively reducing redundancy. 43 | 44 | **These optimizations enable ConvLLaVA to support inputs of 1536x1536 resolution while generating only 576 visual tokens, accommodating images of arbitrary aspect ratios.** [Experimental results](#model-zoo) demonstrate that our method achieves competitive performance with state-of-the-art models on mainstream benchmarks. 45 | 46 |
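As a back-of-the-envelope sketch of the token-count arithmetic (our illustration, not code from this repository): assuming the visual tokens form a square grid, the released models correspond to an overall spatial downsampling factor of 64 after the added stage, so a resolution `R` yields `(R / 64)**2` tokens — 144, 256, and 576 for 768, 1024, and 1536, matching the Model Zoo tables below.

```python
# Illustration only (not repository code): visual-token counts implied by an
# overall 64x spatial downsampling, assuming a square grid of output tokens.
def visual_tokens(resolution: int, total_stride: int = 64) -> int:
    side = resolution // total_stride
    return side * side

for res in (768, 1024, 1536):
    print(res, visual_tokens(res))  # -> 768 144, 1024 256, 1536 576
```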
![method](asset/method.png)

*Comparison between LLaVA and ConvLLaVA.*
52 | 53 | ## Release :loudspeaker: 54 | 55 | - **2024/05/25**: Checkpoints are released. 56 | - **2024/04/17**: Our code is released. 57 | 58 | [![Collaborations](https://img.shields.io/badge/Welcome-Collaborations-b31b1b.svg)](mailto:gecj20@mails.tsinghua.edu.cn) 59 | If you are interested in Large Multimodal Models or you have great ideas, please feel free to email with me: [Chunjiang Ge](mailto:gecj20@mails.tsinghua.edu.cn). 60 | 61 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE) 62 | **Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations. 63 | 64 | ## Contents 65 | - [Abstract](#abstract) 66 | - [Release :loudspeaker:](#release-loudspeaker) 67 | - [Contents](#contents) 68 | - [TODO](#todo) 69 | - [Install](#install) 70 | - [Model Zoo](#model-zoo) 71 | - [Dataset](#dataset) 72 | - [Train](#train) 73 | - [Evaluation](#evaluation) 74 | - [Citation](#citation) 75 | - [Acknowledgement](#acknowledgement) 76 | 77 | ## TODO 78 | 79 | - [ ] Add [LMMs-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) supports. 80 | - [ ] Add [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) supports. 81 | - [ ] Add [xtuner](https://github.com/InternLM/xtuner) supports. 82 | - [x] Release weights. 83 | - [ ] Release inference code. 84 | 85 | ## Install 86 | 87 | 1. Clone this repository and navigate to ConvLLaVA folder 88 | ```bash 89 | git clone https://github.com/alibaba/conv-llava 90 | cd conv-llava 91 | ``` 92 | 93 | 1. Install Package 94 | ```bash 95 | conda create -n convllava python=3.11 -y 96 | conda activate convllava 97 | pip install --upgrade pip # enable PEP 660 support 98 | pip install -e . 99 | ``` 100 | 101 | 3. Install additional packages for training cases 102 | ```bash 103 | pip install -e ".[train]" 104 | pip install flash-attn --no-build-isolation 105 | ``` 106 | 107 | ## Model Zoo 108 | 109 | The performance on mainstream benchmarks are shown below: 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 |
| Method | Resolution | Visual Tokens | LLM | MME | MMB | SEED | RealWorldQA | MMMU | MMVet | Text | Doc | POPE |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| ConvLLaVA | 768 | 144 | 7B | 1541 | 68 | 68.8 | 55.9 | 36.3 | 44.8 | 59.1 | 44.8 | 87.3 |
| ConvLLaVA | 1024 | 256 | 7B | 1553 | 68.8 | 69.3 | 58.8 | 35.1 | 44.4 | 62.5 | 48.5 | 87.7 |
| ConvLLaVA | 1536 | 576 | 7B | 1575 | 68.7 | 70.2 | 59.9 | 35.8 | 45.9 | 65.8 | 59 | 87.3 |

| Method | Resolution | Visual Tokens | LLM | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val | RefCOCOg test | Avg |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| ConvLLaVA | 768 | 144 | 7B | 84.5 | 89.0 | 79.2 | 77.7 | 84.9 | 69.7 | 79.8 | 79.7 | 80.6 |
| ConvLLaVA | 1024 | 256 | 7B | 85.5 | 89.6 | 78.8 | 79.3 | 86.1 | 70.3 | 80.6 | 81.2 | 81.4 |
| ConvLLaVA | 1536 | 576 | 7B | 86.5 | 90.6 | 80.5 | 80.0 | 86.8 | 71.5 | 82.0 | 82.4 | 82.3 |
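The grounding scores above are typically computed as the fraction of predictions whose box matches the ground truth with an IoU of at least 0.5 (a hedged description of the standard protocol; the repository's own evaluation lives in `llava/eval/evaluate_grounding.py` and `scripts/refcoco.sh`). A minimal, self-contained sketch of that check, with hypothetical example boxes:

```python
# Hedged illustration (not the repository's evaluation code): referring-expression
# grounding is usually scored as the fraction of predicted boxes whose IoU with
# the ground-truth box is at least 0.5.
def iou(box_a, box_b):
    """Boxes are (x1, y1, x2, y2) in pixel coordinates."""
    ax1, ay1, ax2, ay2 = box_a
    bx1, by1, bx2, by2 = box_b
    inter_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = inter_w * inter_h
    union = (ax2 - ax1) * (ay2 - ay1) + (bx2 - bx1) * (by2 - by1) - inter
    return inter / union if union > 0 else 0.0

pred, gt = (12, 20, 98, 110), (10, 18, 100, 112)  # hypothetical boxes
print(iou(pred, gt) >= 0.5)  # True -> counted as a correct grounding
```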
243 | 244 | Please check out our [Model Zoo](https://github.com/alibaba/conv-llava/blob/main/docs/Model_zoo.md) for all public ConvLLaVA checkpoints and instructions on how to use the weights. 245 | 246 | ## Dataset 247 | 248 | The data we use is introduced in [Data.md](https://github.com/alibaba/conv-llava/blob/main/docs/Data.md). 249 | 250 | ## Train 251 | 252 | We use the following hyperparameters for training ConvLLaVA. 253 | 254 | | Hyperparameters | Stage 1 | Stage 2 | Stage 3 | 255 | | --------------- | ------- | ------- | ------- | 256 | | Learning Rate | 3e-4 | 2e-5 | 2e-5 | 257 | | Batch Size | 256 | 256 | 128 | 258 | | Epochs | 1 | 1 | 1 | 259 | | Warmup Ratio | 0.03 | 0.03 | 0.03 | 260 | | Weight Decay | 0 | 0 | 0 | 261 | | Optimizer | AdamW | AdamW | AdamW | 262 | 263 | The training scripts are in the [scripts](https://github.com/alibaba/conv-llava/tree/main/scripts) folder: 264 | 265 | - Projector Initialization: [stage1](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_1.sh) 266 | - Vision Language Pretraining: [stage2](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_2.sh) 267 | - Instruction Tuning: [stage3](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_3.sh) 268 | 269 | ## Evaluation 270 | 271 | We currently support [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) for evaluating our model. See [Evaluation.md](https://github.com/alibaba/conv-llava/blob/main/docs/Evaluation.md) for more details. 272 | 273 | ## Citation 274 | 275 | If you find ConvLLaVA useful for your research and applications, please cite using this BibTeX: 276 | 277 | ```bibtex 278 | @misc{ge2024convllava, 279 | title={ConvLLaVA: Hierarchical Backbones as Visual 280 | Encoder for Large Multimodal Models}, 281 | author={Chunjiang Ge and Sijie Cheng and Ziming Wang and Jiale Yuan and Yuan Gao and Jun Song and Shiji Song and Gao Huang and Bo Zheng}, 282 | archivePrefix={arXiv}, 283 | primaryClass={cs.CV}, 284 | year={2024}, 285 | eprint={2405.15738}, 286 | } 287 | ``` 288 | 289 | ## Acknowledgement 290 | 291 | - [Vicuna](https://github.com/lm-sys/FastChat): the codebase LLaVA is built upon, and our base model Vicuna, which has amazing language capabilities! 292 | - [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon. 293 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 |

# ConvLLaVA: Hierarchical Backbones as Visual Encoder for Large Multimodal Models

4 | 5 | [Chunjiang Ge](https://john-ge.github.io/), [Sijie Cheng](https://adacheng.github.io/), Ziming Wang, Jiale Yuan, Yuan Gao 6 | 7 | Jun Song, Shiji Song, [Gao Huang](https://www.gaohuang.net/), Bo Zheng 8 | 9 |


34 | 35 | [ English | 中文 ] 36 | 37 | ## 摘要 :bulb: 38 | 39 | 高分辨率多模态大模型(LMM)面临视觉token过多和视觉平方复杂度的挑战。当前的高分辨率LMM通常能够解决二次复杂度问题,却会生成过量的视觉token。**然而,过多的视觉token才是更关键的问题,因为它会导致更显著的计算开销。** 为了解决这个问题,我们提出了ConvLLaVA,它采用层次化的主干网络ConvNeXt作为LMM的视觉编码器,以替代Vision Transformer(ViT)。**ConvLLaVA将高分辨率图像压缩成富含信息的视觉特征,有效避免了生成过多的视觉token。** 为了增强ConvLLaVA的能力,我们提出了两项关键优化措施。 40 | 41 | - 由于低分辨率预训练的ConvNeXt在直接应用于高分辨率时表现不佳,**我们更新它以弥合这一差距。** 42 | - 此外,由于ConvNeXt原有的压缩比对于更高分辨率的输入来说不足,**我们训练了一个新的stage,以进一步压缩视觉token**,有效减少冗余。 43 | 44 | **这些优化使得ConvLLaVA能够支持1536x1536分辨率的输入,同时仅生成576个视觉token,并适应任意宽高比的图像。** [实验结果](#model-zoo)显示,我们的方法在主流基准测试上与最先进的模型相比取得了竞争性的性能。 45 | 46 |
![method](asset/method.png)

*LLaVA 和 ConvLLaVA 结构上的对比*
52 | 53 | 54 | [![Collaborations](https://img.shields.io/badge/Welcome-Collaborations-b31b1b.svg)](mailto:gecj20@mails.tsinghua.edu.cn) 55 | 如果你对多模态大模型感兴趣,或者你有很好的想法,请你联系我:[Chunjiang Ge](mailto:gecj20@mails.tsinghua.edu.cn). 56 | 57 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE) 58 | **Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations. 59 | 60 | ## 内容 61 | - [摘要 :bulb:](#摘要-bulb) 62 | - [内容](#内容) 63 | - [计划](#计划) 64 | - [安装](#安装) 65 | - [模型库](#模型库) 66 | - [数据集](#数据集) 67 | - [训练](#训练) 68 | - [评测](#评测) 69 | - [引用](#引用) 70 | - [致谢](#致谢) 71 | 72 | ## 计划 73 | 74 | - [ ] Add [LMMs-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) supports. 75 | - [ ] Add [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) supports. 76 | - [ ] Add [xtuner](https://github.com/InternLM/xtuner) supports. 77 | - [x] Release weights. 78 | - [ ] Release inference code. 79 | 80 | ## 安装 81 | 82 | 1. Clone this repository and navigate to ConvLLaVA folder 83 | ```bash 84 | git clone https://github.com/alibaba/conv-llava 85 | cd conv-llava 86 | ``` 87 | 88 | 1. Install Package 89 | ```bash 90 | conda create -n convllava python=3.11 -y 91 | conda activate convllava 92 | pip install --upgrade pip # enable PEP 660 support 93 | pip install -e . 94 | ``` 95 | 96 | 3. Install additional packages for training cases 97 | ```bash 98 | pip install -e ".[train]" 99 | pip install flash-attn --no-build-isolation 100 | ``` 101 | 102 | ## 模型库 103 | 104 | 我们的模型的在一些测试基准上的性能如下: 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 |
| Method | Resolution | Visual Tokens | LLM | MME | MMB | SEED | RealWorldQA | MMMU | MMVet | Text | Doc | POPE |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| ConvLLaVA | 768 | 144 | 7B | 1541 | 68 | 68.8 | 55.9 | 36.3 | 44.8 | 59.1 | 44.8 | 87.3 |
| ConvLLaVA | 1024 | 256 | 7B | 1553 | 68.8 | 69.3 | 58.8 | 35.1 | 44.4 | 62.5 | 48.5 | 87.7 |
| ConvLLaVA | 1536 | 576 | 7B | 1575 | 68.7 | 70.2 | 59.9 | 35.8 | 45.9 | 65.8 | 59 | 87.3 |

| Method | Resolution | Visual Tokens | LLM | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val | RefCOCOg test | Avg |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| ConvLLaVA | 768 | 144 | 7B | 84.5 | 89.0 | 79.2 | 77.7 | 84.9 | 69.7 | 79.8 | 79.7 | 80.6 |
| ConvLLaVA | 1024 | 256 | 7B | 85.5 | 89.6 | 78.8 | 79.3 | 86.1 | 70.3 | 80.6 | 81.2 | 81.4 |
| ConvLLaVA | 1536 | 576 | 7B | 86.5 | 90.6 | 80.5 | 80.0 | 86.8 | 71.5 | 82.0 | 82.4 | 82.3 |
238 | 239 | 我们的 [Model Zoo](https://github.com/alibaba/conv-llava/blob/main/docs/Model_zoo.md) 中包含了主要的权重和下载方式,并有说明如何使用这些权重。 240 | 241 | ## 数据集 242 | 243 | 我们实验用到的数据在 [Data.md](https://github.com/alibaba/conv-llava/blob/main/docs/Data.md) 中有介绍。 244 | 245 | ## 训练 246 | 247 | 训练的超参数如下: 248 | 249 | | Hyperparameters | Stage 1 | Stage 2 | Stage 3 | 250 | | --------------- | ------- | ------- | ------- | 251 | | Learning Rate | 3e-4 | 2e-5 | 2e-5 | 252 | | Batch Size | 256 | 256 | 128 | 253 | | Epochs | 1 | 1 | 1 | 254 | | Warmup Ratio | 0.03 | 0.03 | 0.03 | 255 | | Weight Decay | 0 | 0 | 0 | 256 | | Optimizer | AdamW | AdamW | AdamW | 257 | 258 | 训练脚本在文件夹 [scripts](https://github.com/alibaba/conv-llava/tree/main/scripts) 中: 259 | 260 | - Projector Initialzation: [stage1](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_1.sh) 261 | - Vision Language Pretraining: [stage2](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_2.sh) 262 | - Instruction Tuning: [stage3](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_3.sh) 263 | 264 | ## 评测 265 | 266 | 我们目前支持 [VLMEVALKIT](https://github.com/open-compass/VLMEvalKit) 和 [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) 来测试模型。请看 [Evaluation.md](https://github.com/alibaba/conv-llava/blob/main/docs/Evaluation.md) 了解更多细节. 267 | 268 | ## 引用 269 | 270 | 如果你认为我们的工作有所帮助,请你通过下面的 BibTeX 来引用我们的工作: 271 | 272 | ```bibtex 273 | @misc{ge2024convllava, 274 | title={ConvLLaVA: Hierarchical Backbones as Visual 275 | Encoder for Large Multimodal Models}, 276 | author={Chunjiang Ge, Sijie Cheng, Ziming Wang, Jiale Yuan, Yuan Gao, Jun Song, Shiji Song, Gao Huang, Bo Zheng}, 277 | archivePrefix={arXiv}, 278 | primaryClass={cs.CV} 279 | year={2024} 280 | eprint={2045.15738}, 281 | } 282 | ``` 283 | 284 | ## 致谢 285 | 286 | - [Vicuna](https://github.com/lm-sys/FastChat): the codebase LLaVA built upon, and our base model Vicuna-13B that has the amazing language capabilities! 287 | - [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon. 288 | -------------------------------------------------------------------------------- /asset/WeChat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/asset/WeChat.png -------------------------------------------------------------------------------- /asset/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/asset/method.png -------------------------------------------------------------------------------- /docs/Data.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | We use the following hyperparameters for training ConvLLaVA. 4 | 5 | | Hyperparameters | Stage 1 | Stage 2 | Stage 3 | 6 | | --------------- | ------- | ------- | ------- | 7 | | Learning Rate | 3e-4 | 2e-5 | 2e-5 | 8 | | Batch Size | 256 | 256 | 128 | 9 | | Epochs | 1 | 1 | 1 | 10 | | Warmup Ratio | 0.03 | 0.03 | 0.03 | 11 | | Weight Decay | 0 | 0 | 0 | 12 | | Optimizer | AdamW | AdamW | AdamW | 13 | 14 | ## Projector Initialzation 15 | 16 | We use captions from ShareGPT4V-PT, ShareGPT4V, ALLAVA. 17 | 18 | ## Vision Language Pretraining 19 | 20 | We use ShareGPT4V-PT, ShareGPT4V, ALLAVA and a part of VFLAN. 21 | 22 | ## Instrcution Tuning 23 | 24 | We use LLaVA-1.5 sft 665k dataset. 
We would update the results when LLaVA-NExT released. 25 | 26 | ## Prepare Images 27 | 28 | First, download all images and instrcution files. 29 | 30 | - ALLaVA: [images](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V) 31 | - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip) 32 | - LLaVA: [llava](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) 33 | - WebData: [images](https://drive.google.com/drive/folders/1tCUQ-sq6vdshZVkF0ZeF3K4eztkXJgax?usp=sharing). Only for academic usage. 34 | - SAM: [images](https://ai.meta.com/datasets/segment-anything-downloads/). We only use 000000~000050.tar for now. If you find it is slow for you to donnload in China, please refer to [opendatalab](https://opendatalab.com/OpenDataLab/SA-1B) to download it. 35 | - GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip) 36 | - OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing). We save all files as `.jpg` 37 | - TextVQA: [trainvalimages](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip) 38 | - VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip) 39 | - vflan: [vflan](https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k) 40 | 41 | Then, organize the data as follows: 42 | 43 | ```none 44 | ShareGPT4V 45 | ├── ... 46 | ├── data 47 | │ ├── allava 48 | │ │ ├── allava_laion 49 | │ │ │ ├── images 50 | │ │ │ ├── ALLaVA-Caption-LAION-4V.json 51 | │ │ │ ├── ALLaVA-Instruct-LAION-4V.json 52 | │ │ ├── allava_vflan 53 | │ │ │ ├── ALLaVA-Caption-VFLAN-4V.json 54 | │ │ │ ├── ALLaVA-Instruct-VFLAN-4V.json 55 | │ ├── coco 56 | │ │ ├── train2017 57 | │ ├── llava 58 | │ │ ├── llava_v1_5_mix665k.json 59 | │ ├── sam 60 | │ │ ├── images 61 | │ ├── gqa 62 | │ │ ├── images 63 | │ ├── ocr_vqa 64 | │ │ ├── images 65 | │ ├── textvqa 66 | │ │ ├── train_images 67 | │ ├── vg 68 | │ │ ├── VG_100K 69 | │ │ ├── VG_100K_2 70 | │ ├── vflan 71 | │ │ ├── images_191task_1k 72 | │ │ ├── annotation_191-task_1k.json 73 | │ ├── sharegpt4v 74 | │ │ ├── share-captioner_coco_lcs_sam_1246k_1107.json 75 | │ │ ├── sharegpt4v_instruct_gpt4-vision_cap100k.json 76 | │ ├── share_textvqa 77 | │ │ ├── images 78 | │ ├── web-celebrity 79 | │ │ ├── images 80 | │ ├── web-landmark 81 | │ │ ├── images 82 | │ ├── wikiart 83 | │ │ ├── images 84 | ├── ... 85 | ``` 86 | 87 | If you find download ocrvqa images slow. You could refer to this [issue](https://github.com/haotian-liu/LLaVA/issues/931). 88 | Use multiprocessing to speed up: 89 | 90 | ```python 91 | import concurrent.futures 92 | def download_image(k): 93 | ext = os.path.splitext(data[k]['imageURL'])[1] 94 | outputFile = 'images/%s%s' % (k, ext) 95 | 96 | # Only download the image if it doesn't exist 97 | if not os.path.exists(outputFile): 98 | ureq.urlretrieve(data[k]['imageURL'], outputFile) 99 | 100 | 101 | if download == 1: 102 | # Create the directory if it doesn't exist 103 | if not os.path.exists('./images'): 104 | os.mkdir('./images') 105 | 106 | # Create a thread pool and download the images in parallel 107 | with concurrent.futures.ThreadPoolExecutor() as executor: 108 | executor.map(download_image, data.keys()) 109 | ``` 110 | 111 | For ocrvqa, some git images should be transfered to jpg. 
You could follow bwloe code: 112 | 113 | ```python 114 | import os 115 | from PIL import Image 116 | 117 | def convert_gif_to_jpg(folder_path): 118 | for filename in os.listdir(folder_path): 119 | if filename.endswith('.gif'): 120 | file_path = os.path.join(folder_path, filename) 121 | with Image.open(file_path) as img: 122 | jpg_filename = os.path.splitext(filename)[0] + '.jpg' 123 | jpg_path = os.path.join(folder_path, jpg_filename) 124 | img.convert('RGB').save(jpg_path, 'JPEG', quality=95) 125 | print(f'Converted {filename} to {jpg_filename}') 126 | 127 | folder_path = 'path_to_your_folder' 128 | convert_gif_to_jpg(folder_path) 129 | ``` 130 | 131 | ## Data Configuration 132 | 133 | You could modify the file [data.py](conv-llava/llava/data/data_blending.py) to add the datasets. Replace with the true path: 134 | 135 | ```python 136 | def build_sharegpt4v(tokenizer, data_args): 137 | data_path = 'path_to_sharegpt4v_pt.json' 138 | image_folder = 'folder_to_sharegpt4v_pt' 139 | dataset = SampleDataset(data_path, tokenizer, data_args, 140 | image_folder) 141 | return dataset 142 | ``` 143 | -------------------------------------------------------------------------------- /docs/Evaluation.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | ## VLMEvalKit 4 | 5 | We use VLMEVALKIT as the evaluation tools. Please refer to [QuickStart](https://github.com/open-compass/VLMEvalKit/blob/main/Quickstart.md) for installation. 6 | 7 | We evaluate the models with scripts in [evaluation.sh](scripts/evaluation.sh). You could modify the parameters for evaluating different benchmarks. 8 | 9 | You should use the file [run.py](conv-llava/llava/eval/run.py) to replace with the original run file to evaluate the model. 10 | 11 | ```bash 12 | eval_dataset="MMVet" # need openai api key 13 | eval_dataset="MME MMBench_DEV_EN" 14 | 15 | # set the llava-path to the actual path of your convllava checkpoint 16 | ``` 17 | 18 | We would contribute the VLMEVALKIT to support our model soon. 19 | 20 | ## lmms-eval 21 | 22 | If you want to use lmms-eval to evaluate the model. You need to first install the package: 23 | 24 | ```bash 25 | git clone https://github.com/EvolvingLMMs-Lab/lmms-eval 26 | cd lmms-eval 27 | pip install -e . 28 | ``` 29 | 30 | You should use the file [eval-lmms.sh](conv-llava/llava/eval/eval-lmms.sh) to evaluate the model. You could modify the parameters for evaluating different benchmarks. 31 | 32 | 33 | ## RefCOCO 34 | 35 | If you are interested in RefCOCO, we provide the code in [refcoco.sh](scripts/refcoco.sh). -------------------------------------------------------------------------------- /docs/Model_zoo.md: -------------------------------------------------------------------------------- 1 | # Model Zoo 2 | 3 | ## Performance 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 |
| Method | Resolution | Visual Tokens | LLM | MME | MMB | SEED | RealWorldQA | MMMU | MMVet | Text | Doc | POPE |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| ConvLLaVA | 768 | 144 | 7B | 1541 | 68 | 68.8 | 55.9 | 36.3 | 44.8 | 59.1 | 44.8 | 87.3 |
| ConvLLaVA | 1024 | 256 | 7B | 1553 | 68.8 | 69.3 | 58.8 | 35.1 | 44.4 | 62.5 | 48.5 | 87.7 |
| ConvLLaVA | 1536 | 576 | 7B | 1575 | 68.7 | 70.2 | 59.9 | 35.8 | 45.9 | 65.8 | 59 | 87.3 |

| Method | Resolution | Visual Tokens | LLM | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val | RefCOCOg test | Avg |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| ConvLLaVA | 768 | 144 | 7B | 84.5 | 89.0 | 79.2 | 77.7 | 84.9 | 69.7 | 79.8 | 79.7 | 80.6 |
| ConvLLaVA | 1024 | 256 | 7B | 85.5 | 89.6 | 78.8 | 79.3 | 86.1 | 70.3 | 80.6 | 81.2 | 81.4 |
| ConvLLaVA | 1536 | 576 | 7B | 86.5 | 90.6 | 80.5 | 80.0 | 86.8 | 71.5 | 82.0 | 82.4 | 82.3 |
137 | 138 | ## Download 139 | 140 | We release checkpoints after vision-language pretraining and visual instruction tuning. You can use the sft model directly, or finetune the vision-language pretraining checkpoints on your own data. 141 | 142 | | model | Huggingface | ModelScope | WiseModel | 143 | | :------------: | :------------: | :------------: | :------------: | 144 | | ConvLLaVA-768 | [pretrain](https://huggingface.co/ConvLLaVA/ConvLLaVA-pretrain-768), [sft](https://huggingface.co/ConvLLaVA/ConvLLaVA-sft-768) | [pretrain](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-pretrain-768/summary), [sft](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-sft-768/summary) | [pretrain](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-pretrain-768/intro), [sft](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-sft-768/intro) | 145 | | ConvLLaVA-1024 | [pretrain](https://huggingface.co/ConvLLaVA/ConvLLaVA-pretrain-1024), [sft](https://huggingface.co/ConvLLaVA/ConvLLaVA-sft-1024) | [pretrain](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1024/summary), [sft](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-sft-1024/summary) | [pretrain](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1024/intro), [sft](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-sft-1024/intro) | 146 | | ConvLLaVA-1536 | [pretrain](https://huggingface.co/ConvLLaVA/ConvLLaVA-pretrain-1536), [sft](https://huggingface.co/ConvLLaVA/ConvLLaVA-sft-1536) | [pretrain](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1536/summary), [sft](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-sft-1536/summary) | [pretrain](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1536/intro), [sft](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-sft-1536/intro) | 147 | 148 | The **pretrain** checkpoints above are taken after the second stage, **vision-language pretraining**. The **sft** checkpoints are taken after the third stage, **instruction tuning**. 149 | 150 | ## Usage of the scripts 151 | 152 | The training scripts for the three stages are listed below: 153 | 154 | - Projector Initialization: [stage1](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_1.sh) 155 | - Vision Language Pretraining: [stage2](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_2.sh) 156 | - Instruction Tuning: [stage3](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_3.sh) 157 | 158 | ### Customize training 159 | 160 | If you want to customize your model, you can directly load the **second stage pretrained visual encoder and LLM** for instruction tuning. It takes about 5 hours to train the 768-resolution model with LLaVA-Instruct-665k on a single node with 8 A800 GPUs. 161 | 162 | ### Training from scratch 163 | 164 | If you want to train from scratch, you can download our processed ConvNeXt model (modified from the LAION ConvNeXt). Then follow the three stage training scripts to train the model.
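For convenience, here is a minimal sketch of fetching these weights with `huggingface_hub`. It is an illustration, not part of the ConvLLaVA codebase: the repository ids come from the links in this document, while the local directory names are arbitrary placeholders.

```python
# Minimal sketch (assumption, not ConvLLaVA code): fetch the processed ConvNeXt
# encoder and a released checkpoint from the Hugging Face Hub.
from huggingface_hub import snapshot_download

# Repository ids taken from the links in this document; local_dir is arbitrary.
convnext_dir = snapshot_download(
    repo_id="ConvLLaVA/LAION-CLIP-ConvNeXt-Large-512",
    local_dir="checkpoints/LAION-CLIP-ConvNeXt-Large-512",
)
sft_dir = snapshot_download(
    repo_id="ConvLLaVA/ConvLLaVA-sft-768",
    local_dir="checkpoints/ConvLLaVA-sft-768",
)

# For training from scratch, remember to update config.json (image_size) and
# preprocessor_config.json (size, crop_size) in the ConvNeXt folder to your
# target resolution, as described in the following paragraph.
print(convnext_dir, sft_dir)
```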
165 | 166 | ConvNeXt: [huggingface](https://huggingface.co/ConvLLaVA/LAION-CLIP-ConvNeXt-Large-512), [modelscope](https://modelscope.cn/models/ConvLLaVA/LAION-CLIP-ConvNeXt-Large-512/summary) 167 | 168 | You need to modify the config from the folder to the resolution you want to train your model on: 169 | 170 | - config.json: image_size 171 | - preprocessor_config: size, crop_size 172 | 173 | Then load that weights and start training. 174 | -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /llava/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | import base64 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | 9 | class SeparatorStyle(Enum): 10 | """Different separator style.""" 11 | SINGLE = auto() 12 | TWO = auto() 13 | MPT = auto() 14 | PLAIN = auto() 15 | LLAMA_2 = auto() 16 | 17 | 18 | @dataclasses.dataclass 19 | class Conversation: 20 | """A class that keeps all conversation history.""" 21 | system: str 22 | roles: List[str] 23 | messages: List[List[str]] 24 | offset: int 25 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 26 | sep: str = "###" 27 | sep2: str = None 28 | version: str = "Unknown" 29 | 30 | skip_next: bool = False 31 | 32 | def get_prompt(self): 33 | messages = self.messages 34 | if len(messages) > 0 and type(messages[0][1]) is tuple: 35 | messages = self.messages.copy() 36 | init_role, init_msg = messages[0].copy() 37 | init_msg = init_msg[0].replace("", "").strip() 38 | if 'mmtag' in self.version: 39 | messages[0] = (init_role, init_msg) 40 | messages.insert(0, (self.roles[0], "")) 41 | messages.insert(1, (self.roles[1], "Received.")) 42 | else: 43 | messages[0] = (init_role, "\n" + init_msg) 44 | 45 | if self.sep_style == SeparatorStyle.SINGLE: 46 | ret = self.system + self.sep 47 | for role, message in messages: 48 | if message: 49 | if type(message) is tuple: 50 | message, _, _ = message 51 | ret += role + ": " + message + self.sep 52 | else: 53 | ret += role + ":" 54 | elif self.sep_style == SeparatorStyle.TWO: 55 | seps = [self.sep, self.sep2] 56 | ret = self.system + seps[0] 57 | for i, (role, message) in enumerate(messages): 58 | if message: 59 | if type(message) is tuple: 60 | message, _, _ = message 61 | ret += role + ": " + message + seps[i % 2] 62 | else: 63 | ret += role + ":" 64 | elif self.sep_style == SeparatorStyle.MPT: 65 | ret = self.system + self.sep 66 | for role, message in messages: 67 | if message: 68 | if type(message) is tuple: 69 | message, _, _ = message 70 | ret += role + message + self.sep 71 | else: 72 | ret += role 73 | elif self.sep_style == SeparatorStyle.LLAMA_2: 74 | def wrap_sys( 75 | msg): return 
f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg 76 | 77 | def wrap_inst(msg): return f"[INST] {msg} [/INST]" 78 | ret = "" 79 | 80 | for i, (role, message) in enumerate(messages): 81 | if i == 0: 82 | assert message, "first message should not be none" 83 | assert role == self.roles[0], "first message should come from user" 84 | if message: 85 | if type(message) is tuple: 86 | message, _, _ = message 87 | if i == 0: 88 | message = wrap_sys(self.system) + message 89 | if i % 2 == 0: 90 | message = wrap_inst(message) 91 | ret += self.sep + message 92 | else: 93 | ret += " " + message + " " + self.sep2 94 | else: 95 | ret += "" 96 | ret = ret.lstrip(self.sep) 97 | elif self.sep_style == SeparatorStyle.PLAIN: 98 | seps = [self.sep, self.sep2] 99 | ret = self.system 100 | for i, (role, message) in enumerate(messages): 101 | if message: 102 | if type(message) is tuple: 103 | message, _, _ = message 104 | ret += message + seps[i % 2] 105 | else: 106 | ret += "" 107 | else: 108 | raise ValueError(f"Invalid style: {self.sep_style}") 109 | 110 | return ret 111 | 112 | def append_message(self, role, message): 113 | self.messages.append([role, message]) 114 | 115 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672): 116 | if image_process_mode == "Pad": 117 | def expand2square(pil_img, background_color=(122, 116, 104)): 118 | width, height = pil_img.size 119 | if width == height: 120 | return pil_img 121 | elif width > height: 122 | result = Image.new( 123 | pil_img.mode, (width, width), background_color) 124 | result.paste(pil_img, (0, (width - height) // 2)) 125 | return result 126 | else: 127 | result = Image.new( 128 | pil_img.mode, (height, height), background_color) 129 | result.paste(pil_img, ((height - width) // 2, 0)) 130 | return result 131 | image = expand2square(image) 132 | elif image_process_mode in ["Default", "Crop"]: 133 | pass 134 | elif image_process_mode == "Resize": 135 | image = image.resize((336, 336)) 136 | else: 137 | raise ValueError( 138 | f"Invalid image_process_mode: {image_process_mode}") 139 | if max(image.size) > max_len: 140 | max_hw, min_hw = max(image.size), min(image.size) 141 | aspect_ratio = max_hw / min_hw 142 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 143 | longest_edge = int(shortest_edge * aspect_ratio) 144 | W, H = image.size 145 | if H > W: 146 | H, W = longest_edge, shortest_edge 147 | else: 148 | H, W = shortest_edge, longest_edge 149 | image = image.resize((W, H)) 150 | if return_pil: 151 | return image 152 | else: 153 | buffered = BytesIO() 154 | image.save(buffered, format=image_format) 155 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 156 | return img_b64_str 157 | 158 | def get_images(self, return_pil=False): 159 | images = [] 160 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 161 | if i % 2 == 0: 162 | if type(msg) is tuple: 163 | msg, image, image_process_mode = msg 164 | image = self.process_image( 165 | image, image_process_mode, return_pil=return_pil) 166 | images.append(image) 167 | return images 168 | 169 | def to_gradio_chatbot(self): 170 | ret = [] 171 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 172 | if i % 2 == 0: 173 | if type(msg) is tuple: 174 | msg, image, image_process_mode = msg 175 | img_b64_str = self.process_image( 176 | image, "Default", return_pil=False, 177 | image_format='JPEG') 178 | img_str = f'user upload image' 179 | msg = img_str + msg.replace('', '').strip() 180 | ret.append([msg, 
None]) 181 | else: 182 | ret.append([msg, None]) 183 | else: 184 | ret[-1][-1] = msg 185 | return ret 186 | 187 | def copy(self): 188 | return Conversation( 189 | system=self.system, 190 | roles=self.roles, 191 | messages=[[x, y] for x, y in self.messages], 192 | offset=self.offset, 193 | sep_style=self.sep_style, 194 | sep=self.sep, 195 | sep2=self.sep2, 196 | version=self.version) 197 | 198 | def dict(self): 199 | if len(self.get_images()) > 0: 200 | return { 201 | "system": self.system, 202 | "roles": self.roles, 203 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 204 | "offset": self.offset, 205 | "sep": self.sep, 206 | "sep2": self.sep2, 207 | } 208 | return { 209 | "system": self.system, 210 | "roles": self.roles, 211 | "messages": self.messages, 212 | "offset": self.offset, 213 | "sep": self.sep, 214 | "sep2": self.sep2, 215 | } 216 | 217 | 218 | conv_vicuna_v0 = Conversation( 219 | system="A chat between a curious human and an artificial intelligence assistant. " 220 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 221 | roles=("Human", "Assistant"), 222 | messages=( 223 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"), 224 | ("Assistant", 225 | "Renewable energy sources are those that can be replenished naturally in a relatively " 226 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. " 227 | "Non-renewable energy sources, on the other hand, are finite and will eventually be " 228 | "depleted, such as coal, oil, and natural gas. Here are some key differences between " 229 | "renewable and non-renewable energy sources:\n" 230 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable " 231 | "energy sources are finite and will eventually run out.\n" 232 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact " 233 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, " 234 | "and other negative effects.\n" 235 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically " 236 | "have lower operational costs than non-renewable sources.\n" 237 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote " 238 | "locations than non-renewable sources.\n" 239 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different " 240 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n" 241 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while " 242 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n") 243 | ), 244 | offset=2, 245 | sep_style=SeparatorStyle.SINGLE, 246 | sep="###", 247 | ) 248 | 249 | conv_vicuna_v1 = Conversation( 250 | system="A chat between a curious user and an artificial intelligence assistant. " 251 | "The assistant gives helpful, detailed, and polite answers to the user's questions.", 252 | roles=("USER", "ASSISTANT"), 253 | version="v1", 254 | messages=(), 255 | offset=0, 256 | sep_style=SeparatorStyle.TWO, 257 | sep=" ", 258 | sep2="", 259 | ) 260 | 261 | conv_llama_2 = Conversation( 262 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. 
Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. 263 | 264 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""", 265 | roles=("USER", "ASSISTANT"), 266 | version="llama_v2", 267 | messages=(), 268 | offset=0, 269 | sep_style=SeparatorStyle.LLAMA_2, 270 | sep="", 271 | sep2="", 272 | ) 273 | 274 | conv_llava_llama_2 = Conversation( 275 | system="You are a helpful language and vision assistant. " 276 | "You are able to understand the visual content that the user provides, " 277 | "and assist the user with a variety of tasks using natural language.", 278 | roles=("USER", "ASSISTANT"), 279 | version="llama_v2", 280 | messages=(), 281 | offset=0, 282 | sep_style=SeparatorStyle.LLAMA_2, 283 | sep="", 284 | sep2="", 285 | ) 286 | 287 | conv_mpt = Conversation( 288 | system="""<|im_start|>system 289 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""", 290 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), 291 | version="mpt", 292 | messages=(), 293 | offset=0, 294 | sep_style=SeparatorStyle.MPT, 295 | sep="<|im_end|>", 296 | ) 297 | 298 | conv_llava_plain = Conversation( 299 | system="", 300 | roles=("", ""), 301 | messages=( 302 | ), 303 | offset=0, 304 | sep_style=SeparatorStyle.PLAIN, 305 | sep="\n", 306 | ) 307 | 308 | conv_llava_v0 = Conversation( 309 | system="A chat between a curious human and an artificial intelligence assistant. " 310 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 311 | roles=("Human", "Assistant"), 312 | messages=( 313 | ), 314 | offset=0, 315 | sep_style=SeparatorStyle.SINGLE, 316 | sep="###", 317 | ) 318 | 319 | conv_llava_v0_mmtag = Conversation( 320 | system="A chat between a curious user and an artificial intelligence assistant. " 321 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 322 | "The visual content will be provided with the following format: visual content.", 323 | roles=("Human", "Assistant"), 324 | messages=( 325 | ), 326 | offset=0, 327 | sep_style=SeparatorStyle.SINGLE, 328 | sep="###", 329 | version="v0_mmtag", 330 | ) 331 | 332 | conv_llava_v1 = Conversation( 333 | system="A chat between a curious human and an artificial intelligence assistant. " 334 | "The assistant gives helpful, detailed, and polite answers to the human's questions.", 335 | roles=("USER", "ASSISTANT"), 336 | version="v1", 337 | messages=(), 338 | offset=0, 339 | sep_style=SeparatorStyle.TWO, 340 | sep=" ", 341 | sep2="", 342 | ) 343 | 344 | conv_llava_v1_mmtag = Conversation( 345 | system="A chat between a curious user and an artificial intelligence assistant. " 346 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language." 
347 | "The visual content will be provided with the following format: visual content.", 348 | roles=("USER", "ASSISTANT"), 349 | messages=(), 350 | offset=0, 351 | sep_style=SeparatorStyle.TWO, 352 | sep=" ", 353 | sep2="", 354 | version="v1_mmtag", 355 | ) 356 | 357 | default_conversation = conv_vicuna_v1 358 | conv_templates = { 359 | "default": conv_vicuna_v0, 360 | "v1": conv_vicuna_v1, 361 | "vicuna_v1": conv_vicuna_v1, 362 | "llama_2": conv_llama_2, 363 | 364 | "plain": conv_llava_plain, 365 | "v0_plain": conv_llava_plain, 366 | "llava_v0": conv_llava_v0, 367 | "v0_mmtag": conv_llava_v0_mmtag, 368 | "llava_v1": conv_llava_v1, 369 | "v1_mmtag": conv_llava_v1_mmtag, 370 | "llava_llama_2": conv_llava_llama_2, 371 | 372 | "mpt": conv_mpt 373 | } 374 | 375 | 376 | if __name__ == "__main__": 377 | print(default_conversation.get_prompt()) 378 | -------------------------------------------------------------------------------- /llava/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/data/__init__.py -------------------------------------------------------------------------------- /llava/eval/evaluate_grounding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import re 7 | import time 8 | from functools import partial 9 | 10 | import torch 11 | from torchvision.transforms.functional import InterpolationMode 12 | import torchvision.transforms as T 13 | from PIL import Image 14 | from torchvision.ops.boxes import box_area 15 | from tqdm import tqdm 16 | from llava.mm_utils import tokenizer_image_token, process_images 17 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 18 | from llava.conversation import conv_templates 19 | from llava.model.builder import load_pretrained_model 20 | 21 | 22 | def expand2square(pil_img, background_color): 23 | width, height = pil_img.size 24 | if width == height: 25 | return pil_img 26 | elif width > height: 27 | result = Image.new(pil_img.mode, (width, width), background_color) 28 | result.paste(pil_img, (0, (width - height) // 2)) 29 | return result 30 | else: 31 | result = Image.new(pil_img.mode, (height, height), background_color) 32 | result.paste(pil_img, ((height - width) // 2, 0)) 33 | return result 34 | 35 | 36 | def build_transform(is_train, input_size, pad2square=False): 37 | if is_train: 38 | transform = T.Compose([ 39 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 40 | T.RandomResizedCrop(input_size, scale=(0.8, 1.0), ratio=(3. / 4., 4. 
/ 3.), 41 | interpolation=InterpolationMode.BICUBIC), 42 | T.ToTensor(), 43 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 44 | ]) 45 | else: 46 | if pad2square is False: 47 | transform = T.Compose([ 48 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 49 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 50 | T.ToTensor(), 51 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 52 | ]) 53 | else: 54 | transform = T.Compose([ 55 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 56 | T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in (0.485, 0.456, 0.406)))), 57 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), 58 | T.ToTensor(), 59 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)) 60 | ]) 61 | 62 | return transform 63 | 64 | 65 | ds_collections = { 66 | 'refcoco_val': 'data/refcoco/refcoco_val.jsonl', 67 | 'refcoco_testA': 'data/refcoco/refcoco_testA.jsonl', 68 | 'refcoco_testB': 'data/refcoco/refcoco_testB.jsonl', 69 | 'refcoco+_val': 'data/refcoco/refcoco+_val.jsonl', 70 | 'refcoco+_testA': 'data/refcoco/refcoco+_testA.jsonl', 71 | 'refcoco+_testB': 'data/refcoco/refcoco+_testB.jsonl', 72 | 'refcocog_val': 'data/refcoco/refcocog_val.jsonl', 73 | 'refcocog_test': 'data/refcoco/refcocog_test.jsonl', 74 | } 75 | 76 | 77 | def reserve_square_bbox(box, w, h): 78 | if w == h: 79 | return box 80 | box = box.tolist()[0] 81 | if w > h: 82 | x1, y1, x2, y2 = box 83 | y1 -= (w - h) // 2 84 | y2 -= (w - h) // 2 85 | box = [[x1, y1, x2, y2]] 86 | return torch.tensor(box).reshape(1, 4) 87 | else: 88 | x1, y1, x2, y2 = box 89 | x1 -= (h - w) // 2 90 | x2 -= (h - w) // 2 91 | box = [[x1, y1, x2, y2]] 92 | return torch.tensor(box).reshape(1, 4) 93 | 94 | 95 | class ModelConfig: 96 | def __init__(self, image_aspect_ratio=None): 97 | self.image_aspect_ratio = image_aspect_ratio 98 | 99 | def box_iou(boxes1, boxes2): 100 | area1 = box_area(boxes1) 101 | area2 = box_area(boxes2) 102 | 103 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 104 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 105 | 106 | wh = (rb - lt).clamp(min=0) # [N,M,2] 107 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 108 | 109 | union = area1[:, None] + area2 - inter 110 | 111 | iou = inter / union 112 | return iou, union 113 | 114 | def collate_fn(batch, tokenizer=None): 115 | input_ids, image_tensors, bbox, hw, image_path, text = zip(*batch) 116 | input_ids = torch.stack(input_ids, dim=0) 117 | image_tensors = torch.stack(image_tensors, dim=0) 118 | return input_ids, image_tensors, bbox, hw, image_path, text 119 | 120 | 121 | class RefCOCODataset(torch.utils.data.Dataset): 122 | 123 | def __init__(self, test, prompt, input_size=224, pad2square=False, image_processor=None, model_cfg=None, tokenizer=None): 124 | self.datas = open(test).readlines() 125 | self.prompt = prompt 126 | self.transform = build_transform(is_train=False, input_size=input_size, pad2square=pad2square) 127 | self.image_processor = image_processor 128 | self.model_config = model_cfg 129 | self.tokenizer = tokenizer 130 | 131 | def __len__(self): 132 | return len(self.datas) 133 | 134 | def __getitem__(self, idx): 135 | data = json.loads(self.datas[idx].strip()) 136 | image_path = data['image'] 137 | text = data['sent'] 138 | bbox = data['bbox'] 139 | 140 | w, h = data['width'], data['height'] 141 | image = os.path.join('/mnt/thuair/gcjtcl/InternVL', image_path) 142 |
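        # Per-sample preprocessing (mirrors the LLaVA-style pipeline used elsewhere in this repo):
        #   1. load the RGB image and preprocess it with process_images(); because the script
        #      builds model_cfg with image_aspect_ratio="pad", the image is first padded to a
        #      square with expand2square and then resized by the image processor;
        #   2. fill the grounding prompt with the referring expression and prepend the <image>
        #      placeholder token;
        #   3. wrap the prompt in the "v1" conversation template and tokenize it with
        #      tokenizer_image_token(), which splices IMAGE_TOKEN_INDEX into the input ids.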
143 | image = Image.open(image).convert('RGB') 144 | # pixel_values = self.transform(image).unsqueeze(0) 145 | pixel_values = process_images([image], self.image_processor, self.model_config)[0] 146 | prompt = self.prompt.format(text) 147 | prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt 148 | conv = conv_templates["v1"].copy() 149 | conv.append_message(conv.roles[0], prompt) 150 | conv.append_message(conv.roles[1], None) 151 | prompt = conv.get_prompt() 152 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') 153 | return input_ids, pixel_values, bbox, (h, w), image_path, text 154 | 155 | 156 | class InferenceSampler(torch.utils.data.sampler.Sampler): 157 | 158 | def __init__(self, size): 159 | self._size = int(size) 160 | assert size > 0 161 | self._rank = torch.distributed.get_rank() 162 | self._world_size = torch.distributed.get_world_size() 163 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 164 | 165 | @staticmethod 166 | def _get_local_indices(total_size, world_size, rank): 167 | shard_size = total_size // world_size 168 | left = total_size % world_size 169 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 170 | 171 | begin = sum(shard_sizes[:rank]) 172 | end = min(sum(shard_sizes[:rank + 1]), total_size) 173 | return range(begin, end) 174 | 175 | def __iter__(self): 176 | yield from self._local_indices 177 | 178 | def __len__(self): 179 | return len(self._local_indices) 180 | 181 | 182 | def evaluate_chat_model(): 183 | print('prompt:', prompt) 184 | random.seed(args.seed) 185 | summaries = [] 186 | 187 | for ds_name in args.datasets: 188 | dataset = RefCOCODataset( 189 | test=os.path.join("/mnt/thuair/gcjtcl/InternVL", ds_collections[ds_name]), 190 | prompt=prompt, 191 | input_size=image_size, 192 | pad2square=pad2square, 193 | image_processor=image_processor, 194 | model_cfg=model_cfg, 195 | tokenizer=tokenizer 196 | ) 197 | dataloader = torch.utils.data.DataLoader( 198 | dataset=dataset, 199 | sampler=InferenceSampler(len(dataset)), 200 | batch_size=args.batch_size, 201 | num_workers=args.num_workers, 202 | pin_memory=True, 203 | drop_last=False, 204 | collate_fn=partial(collate_fn), 205 | ) 206 | 207 | outputs = [] 208 | for _, (questions, pixel_values, bboxes, hws, image_path, text) in enumerate(tqdm(dataloader)): 209 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 210 | output_ids = model.generate( 211 | questions.to(device='cuda', non_blocking=True), 212 | images=pixel_values.to(dtype=torch.float16), 213 | do_sample=False, 214 | temperature=0, 215 | max_new_tokens=100) 216 | 217 | pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 218 | answers = [pred] 219 | 220 | for bbox, hw, answer in zip(bboxes, hws, answers): 221 | outputs.append({ 222 | 'image_path': image_path, 223 | 'text': text, 224 | 'answer': answer, 225 | 'gt_bbox': bbox, 226 | 'hw': hw, 227 | }) 228 | 229 | torch.distributed.barrier() 230 | 231 | world_size = torch.distributed.get_world_size() 232 | merged_outputs = [None for _ in range(world_size)] 233 | torch.distributed.all_gather_object(merged_outputs, outputs) 234 | 235 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 236 | 237 | if torch.distributed.get_rank() == 0: 238 | print(f'Evaluating {ds_name} ...') 239 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime()) 240 | results_file = f'{ds_name}_{time_prefix}.json' 241 | results_file = os.path.join(args.out_dir, results_file) 242 | 
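            # Scoring (rank 0 only): each raw answer is parsed with PATTERN, which expects a box
            # of the form [x1, y1, x2, y2]; unparseable answers fall back to (0, 0, 0, 0).
            # Predicted boxes are treated as 0-1000 normalized coordinates on the padded square
            # image, so they are divided by 1000, scaled by max(h, w), and mapped back to the
            # original (unpadded) frame with reserve_square_bbox(). A prediction counts as
            # correct when its IoU with the ground-truth box is at least 0.5 (Precision@1).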
json.dump(merged_outputs, open(results_file, 'w')) 243 | # with open("/mnt/thuair/gcjtcl/InternVL/internvl_chat/conv768/refcocog_val_240419174233.json", 'r') as f: 244 | # merged_outputs = json.load(f) 245 | 246 | correct = total_cnt = 0 247 | for i, output in enumerate(merged_outputs): 248 | predict_bbox = re.findall(PATTERN, output['answer']) 249 | try: 250 | predict_bbox = (float(predict_bbox[0][0]), float(predict_bbox[0][1]), float(predict_bbox[0][2]), 251 | float(predict_bbox[0][3])) 252 | except Exception: 253 | predict_bbox = (0., 0., 0., 0.) 254 | target_bbox = torch.tensor(output['gt_bbox'], 255 | dtype=torch.float32).view(-1, 4) 256 | predict_bbox = torch.tensor(predict_bbox, 257 | dtype=torch.float32).view(-1, 4) 258 | if predict_bbox.sum() >= 4: 259 | predict_bbox = predict_bbox / 1000 260 | predict_bbox *= max(output['hw']) 261 | w, h = output['hw'][1], output['hw'][0] 262 | predict_bbox = reserve_square_bbox(predict_bbox, w, h) 263 | # print(predict_bbox) 264 | # predict_bbox[:, 0::2] *= output['hw'][1] 265 | # predict_bbox[:, 1::2] *= output['hw'][0] 266 | iou, _ = box_iou(predict_bbox, target_bbox) 267 | iou = iou.item() 268 | total_cnt += 1 269 | if iou >= 0.5: 270 | correct += 1 271 | 272 | print(f'Evaluating {ds_name} ...') 273 | print(f'Precision @ 1: {correct / total_cnt} \n') 274 | summaries.append([args.checkpoint, ds_name, f'Precision @ 1: {correct / total_cnt} \n']) 275 | 276 | torch.distributed.barrier() 277 | 278 | out_path = '_'.join(args.checkpoint.split('/')[-2:]) 279 | writer = open(os.path.join(args.out_dir, f'{out_path}.txt'), 'a') 280 | print(f"write results to file {os.path.join(args.out_dir, f'{out_path}.txt')}") 281 | for summary in summaries: 282 | print(summary) 283 | writer.write(f'{summary}\n') 284 | writer.close() 285 | 286 | 287 | if __name__ == '__main__': 288 | 289 | parser = argparse.ArgumentParser() 290 | parser.add_argument('--checkpoint', type=str, default='') 291 | parser.add_argument('--datasets', type=str, default='refcoco_val,refcoco_testA,refcoco_testB,' 292 | 'refcoco+_val,refcoco+_testA,refcoco+_testB,' 293 | 'refcocog_val,refcocog_test') 294 | parser.add_argument('--batch-size', type=int, default=1) 295 | parser.add_argument('--num-workers', type=int, default=1) 296 | parser.add_argument('--num-beams', type=int, default=5) 297 | parser.add_argument('--out-dir', type=str, default='results') 298 | parser.add_argument('--sample', type=bool, default=False) 299 | parser.add_argument('--temperature', type=float, default=0.0) 300 | parser.add_argument('--seed', type=int, default=0) 301 | args = parser.parse_args() 302 | 303 | 304 | args.datasets = args.datasets.split(',') 305 | print('datasets:', args.datasets) 306 | assert args.batch_size == 1, 'Only batch size 1 is supported' 307 | 308 | torch.distributed.init_process_group( 309 | backend='nccl', 310 | world_size=int(os.getenv('WORLD_SIZE', '1')), 311 | rank=int(os.getenv('RANK', '0')), 312 | ) 313 | rank = torch.distributed.get_rank() 314 | if rank == 0: 315 | if not os.path.exists(args.out_dir): 316 | os.makedirs(args.out_dir) 317 | 318 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 319 | torch.cuda.set_device(int(os.getenv('RANK', 0))) 320 | print(f"rank: {int(os.getenv('RANK', 0))}") 321 | device = torch.device(int(os.getenv('RANK', 0))) 322 | 323 | # tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint) 324 | PATTERN = re.compile(r'\[*\[(.*?),(.*?),(.*?),(.*?)\]\]*') 325 | 326 | # Create a model_cfg object 327 | model_cfg = ModelConfig(image_aspect_ratio="pad") 328 | 329 | 330 | # device
= 331 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.checkpoint, None, "llava", device='cpu', device_map='cpu') 332 | model = model.to(device).eval() 333 | vision_tower = model.get_vision_tower().to(device) 334 | model.get_model().mm_projector.to(device) 335 | # model.cuda() 336 | # for p in model.parameters(): 337 | # p.cuda() 338 | image_size = 336 339 | pad2square = True 340 | prompt = 'Please provide the bounding box coordinate of the region this sentence describes: {}.' 341 | 342 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 343 | if total_params > 30: 344 | args.num_beams = 1 345 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 346 | else: 347 | print(f'[test] total_params: {total_params}B') 348 | print(f'[test] image_size: {image_size}') 349 | print(f'[test] pad2square: {pad2square}') 350 | print(f'[test] template: v1') 351 | 352 | evaluate_chat_model() 353 | -------------------------------------------------------------------------------- /llava/eval/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | from vlmeval.smp import * 4 | from vlmeval.evaluate import * 5 | from vlmeval.inference import infer_data_job 6 | from vlmeval.config import supported_VLM 7 | from vlmeval.utils import dataset_URLs, DATASET_TYPE, abbr2full, MMMU_result_transfer 8 | from functools import partial 9 | from vlmeval.vlm import LLaVA 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('--data', type=str, nargs='+', required=True) 15 | parser.add_argument('--model', type=str, nargs='+', required=True) 16 | parser.add_argument('--work-dir', type=str, default='.', help='select the output directory') 17 | parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer']) 18 | parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling') 19 | parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs') 20 | parser.add_argument('--judge', type=str, default=None) 21 | parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. 
') 22 | parser.add_argument('--verbose', action='store_true') 23 | parser.add_argument("--llava-path", type=str, 24 | default='liuhaotian/llava-v1.5-7b') 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def main(): 30 | logger = get_logger('RUN') 31 | 32 | args = parse_args() 33 | assert len(args.data), "--data should be a list of data files" 34 | 35 | supported_VLM.update( 36 | {"llava_v1.5_7b": partial(LLaVA, model_pth=args.llava_path)}) 37 | 38 | if args.retry is not None: 39 | for k, v in supported_VLM.items(): 40 | if hasattr(v, 'keywords') and 'retry' in v.keywords: 41 | v.keywords['retry'] = args.retry 42 | supported_VLM[k] = v 43 | if hasattr(v, 'keywords') and 'verbose' in v.keywords: 44 | v.keywords['verbose'] = args.verbose 45 | supported_VLM[k] = v 46 | 47 | rank, world_size = get_rank_and_world_size() 48 | if world_size > 1: 49 | torch.cuda.set_device(rank) 50 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=10800)) 51 | 52 | for _, model_name in enumerate(args.model): 53 | model = None 54 | 55 | pred_root = osp.join(args.work_dir, model_name) 56 | os.makedirs(pred_root, exist_ok=True) 57 | 58 | for _, dataset_name in enumerate(args.data): 59 | custom_flag = False 60 | 61 | if dataset_name not in dataset_URLs: 62 | dataset_name = abbr2full(dataset_name) 63 | 64 | if dataset_name not in dataset_URLs: 65 | logger.warning(f'Dataset {dataset_name} is not officially supported. ') 66 | file_path = osp.join(LMUDataRoot(), f'{dataset_name}.tsv') 67 | if not osp.exists(file_path): 68 | logger.error(f'Cannot find the local dataset {dataset_name}. ') 69 | continue 70 | else: 71 | custom_flag = True 72 | 73 | result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx' 74 | 75 | if model is None: 76 | model = model_name # which is only a name 77 | 78 | model = infer_data_job( 79 | model, 80 | work_dir=pred_root, 81 | model_name=model_name, 82 | dataset_name=dataset_name, 83 | verbose=args.verbose, 84 | api_nproc=args.nproc, 85 | ignore_failed=args.ignore) 86 | 87 | if rank == 0: 88 | if dataset_name in ['MMMU_TEST']: 89 | result_json = MMMU_result_transfer(result_file) 90 | logger.info(f'Transfer MMMU_TEST result to json for official evaluation, json file saved in {result_json}') # noqa: E501 91 | 92 | if dataset_name in ['MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMMU_TEST']: 93 | if not MMBenchOfficialServer(): 94 | logger.error( 95 | f'Can not evaluate {dataset_name} on non-official servers, ' 96 | 'will skip the evaluation. 
' 97 | ) 98 | continue 99 | 100 | judge_kwargs = { 101 | 'nproc': args.nproc, 102 | 'verbose': args.verbose, 103 | } 104 | if args.retry is not None: 105 | judge_kwargs['retry'] = args.retry 106 | if args.judge is not None: 107 | judge_kwargs['model'] = args.judge 108 | else: 109 | if DATASET_TYPE(dataset_name) in ['multi-choice', 'Y/N']: 110 | judge_kwargs['model'] = 'chatgpt-0613' 111 | elif listinstr(['MMVet', 'MathVista', 'LLaVABench'], dataset_name): 112 | judge_kwargs['model'] = 'gpt-4-turbo' 113 | if 'OPENAI_API_KEY_JUDGE' in os.environ and len(os.environ['OPENAI_API_KEY_JUDGE']): 114 | judge_kwargs['key'] = os.environ['OPENAI_API_KEY_JUDGE'] 115 | if 'OPENAI_API_BASE_JUDGE' in os.environ and len(os.environ['OPENAI_API_BASE_JUDGE']): 116 | judge_kwargs['api_base'] = os.environ['OPENAI_API_BASE_JUDGE'] 117 | 118 | if rank == 0 and args.mode == 'all': 119 | if DATASET_TYPE(dataset_name) == 'multi-choice': 120 | dataset_name = 'default' if custom_flag else dataset_name 121 | multiple_choice_eval( 122 | result_file, 123 | dataset=dataset_name, 124 | **judge_kwargs) 125 | 126 | elif DATASET_TYPE(dataset_name) == 'Y/N': 127 | YOrN_eval( 128 | result_file, 129 | dataset=dataset_name, 130 | **judge_kwargs) 131 | 132 | elif DATASET_TYPE(dataset_name) == 'Caption': 133 | COCO_eval(result_file) 134 | elif dataset_name == 'MMVet': 135 | MMVet_eval(result_file, **judge_kwargs) 136 | elif dataset_name == 'OCRBench': 137 | OCRBench_eval(result_file) 138 | elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA', 'DocVQA', 'InfoVQA'], dataset_name): 139 | VQAEval(result_file, dataset_name) 140 | elif listinstr(['MathVista'], dataset_name): 141 | MathVista_eval(result_file, **judge_kwargs) 142 | elif listinstr(['LLaVABench'], dataset_name): 143 | LLaVABench_eval(result_file, **judge_kwargs) 144 | else: 145 | logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ') 146 | 147 | 148 | if __name__ == '__main__': 149 | load_env() 150 | main() 151 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from llava.constants import IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 
22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 
110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | new_images.append(image) 178 | else: 179 | return image_processor(images, return_tensors='pt')['pixel_values'] 180 | if all(x.shape == new_images[0].shape for x in new_images): 181 | new_images = torch.stack(new_images, dim=0) 182 | return new_images 183 | 184 | 185 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 186 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 187 | 188 | def insert_separator(X, sep): 189 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 190 | 191 | input_ids = [] 192 | offset = 0 193 | # prompt_chunks[0] 是第一个chunk,如果图在最前面,len(promt_chunks[0]) =0 194 | # 如果不是0而且是[bos]说明可能是交错数据 195 | # 正常的image text pair会是offset=0 196 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == 
tokenizer.bos_token_id: 197 | offset = 1 198 | input_ids.append(prompt_chunks[0][0]) 199 | 200 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 201 | input_ids.extend(x[offset:]) 202 | 203 | if return_tensors is not None: 204 | if return_tensors == 'pt': 205 | return torch.tensor(input_ids, dtype=torch.long) 206 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 207 | return input_ids 208 | 209 | 210 | def get_model_name_from_path(model_path): 211 | model_path = model_path.strip("/") 212 | model_paths = model_path.split("/") 213 | if model_paths[-1].startswith('checkpoint-'): 214 | return model_paths[-2] + "_" + model_paths[-1] 215 | else: 216 | return model_paths[-1] 217 | 218 | class KeywordsStoppingCriteria(StoppingCriteria): 219 | def __init__(self, keywords, tokenizer, input_ids): 220 | self.keywords = keywords 221 | self.keyword_ids = [] 222 | self.max_keyword_len = 0 223 | for keyword in keywords: 224 | cur_keyword_ids = tokenizer(keyword).input_ids 225 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 226 | cur_keyword_ids = cur_keyword_ids[1:] 227 | if len(cur_keyword_ids) > self.max_keyword_len: 228 | self.max_keyword_len = len(cur_keyword_ids) 229 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 230 | self.tokenizer = tokenizer 231 | self.start_len = input_ids.shape[1] 232 | 233 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 234 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 235 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 236 | for keyword_id in self.keyword_ids: 237 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 238 | if torch.equal(truncated_output_ids, keyword_id): 239 | return True 240 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 241 | for keyword in self.keywords: 242 | if keyword in outputs: 243 | return True 244 | return False 245 | 246 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 247 | outputs = [] 248 | for i in range(output_ids.shape[0]): 249 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 250 | return all(outputs) 251 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
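# load_pretrained_model() below is the single entry point used by the eval and serve
# scripts in this repo: it builds the tokenizer and LlavaLlamaForCausalLM from `model_path`
# (optionally quantized to 8-bit or 4-bit via bitsandbytes), registers the extra image
# tokens when the model config requests them, loads the vision tower in fp16 and exposes
# its image processor, and returns (tokenizer, model, image_processor, context_len),
# where context_len falls back to 2048 if config.max_sequence_length is absent.
#
# A minimal usage sketch (the checkpoint path is a placeholder, not part of this repo):
#
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       "path/to/conv-llava-checkpoint", None, "llava")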
14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | import subprocess 20 | 21 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 22 | import torch 23 | from llava.model import * 24 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | 26 | 27 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs): 28 | kwargs = {"device_map": device_map, **kwargs} 29 | 30 | if device != "cuda": 31 | kwargs['device_map'] = {"": device} 32 | 33 | if load_8bit: 34 | kwargs['load_in_8bit'] = True 35 | elif load_4bit: 36 | kwargs['load_in_4bit'] = True 37 | kwargs['quantization_config'] = BitsAndBytesConfig( 38 | load_in_4bit=True, 39 | bnb_4bit_compute_dtype=torch.float16, 40 | bnb_4bit_use_double_quant=True, 41 | bnb_4bit_quant_type='nf4' 42 | ) 43 | else: 44 | kwargs['torch_dtype'] = torch.float16 45 | 46 | tokenizer = AutoTokenizer.from_pretrained( 47 | model_path, use_fast=False) 48 | model = LlavaLlamaForCausalLM.from_pretrained( 49 | model_path, low_cpu_mem_usage=True, **kwargs) 50 | 51 | image_processor = None 52 | 53 | mm_use_im_start_end = getattr( 54 | model.config, "mm_use_im_start_end", False) 55 | mm_use_im_patch_token = getattr( 56 | model.config, "mm_use_im_patch_token", True) 57 | if mm_use_im_patch_token: 58 | tokenizer.add_tokens( 59 | [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 60 | if mm_use_im_start_end: 61 | tokenizer.add_tokens( 62 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 63 | model.resize_token_embeddings(len(tokenizer)) 64 | 65 | vision_tower = model.get_vision_tower() 66 | if not vision_tower.is_loaded: 67 | vision_tower.load_model() 68 | print(model.get_vision_tower().is_loaded) 69 | vision_tower.to(device=device, dtype=torch.float16) 70 | image_processor = vision_tower.image_processor 71 | 72 | if hasattr(model.config, "max_sequence_length"): 73 | context_len = model.config.max_sequence_length 74 | else: 75 | context_len = 2048 76 | 77 | return tokenizer, model, image_processor, context_len 78 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
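# This module wires the LLaVA multimodal pieces into a Llama backbone: LlavaLlamaModel
# mixes LlavaMetaModel (which holds the vision tower and mm_projector) into LlamaModel,
# and LlavaLlamaForCausalLM overrides forward()/generate() so that, when `images` are
# given, prepare_inputs_labels_for_multimodal() replaces the image placeholder token with
# projected visual features inside inputs_embeds before the language model runs. The
# config/model pair is registered with AutoConfig/AutoModelForCausalLM under the
# "llava_llama" model type at the bottom of the file.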
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def print_trainable_parameters(self): 58 | for i, layer in enumerate(self.model.layers): 59 | print(f"LLM Layer {i}") 60 | is_trainable = any( 61 | param.requires_grad for param in layer.parameters()) 62 | print(f"LLM Layer {i} is Trainable: {is_trainable}") 63 | for name, param in self.named_parameters(): 64 | # if param.requires_grad: 65 | print(name) 66 | 67 | def forward( 68 | self, 69 | input_ids: torch.LongTensor = None, 70 | attention_mask: Optional[torch.Tensor] = None, 71 | position_ids: Optional[torch.LongTensor] = None, 72 | past_key_values: Optional[List[torch.FloatTensor]] = None, 73 | inputs_embeds: Optional[torch.FloatTensor] = None, 74 | labels: Optional[torch.LongTensor] = None, 75 | use_cache: Optional[bool] = None, 76 | output_attentions: Optional[bool] = None, 77 | output_hidden_states: Optional[bool] = None, 78 | images: Optional[torch.FloatTensor] = None, 79 | image_sizes: Optional[List[List[int]]] = None, 80 | return_dict: Optional[bool] = None, 81 | cache_position: Optional[torch.LongTensor] = None, 82 | ) -> Union[Tuple, CausalLMOutputWithPast]: 83 | 84 | if inputs_embeds is None: 85 | ( 86 | input_ids, 87 | position_ids, 88 | attention_mask, 89 | past_key_values, 90 | inputs_embeds, 91 | labels 92 | ) = self.prepare_inputs_labels_for_multimodal( 93 | input_ids, 94 | position_ids, 95 | attention_mask, 96 | past_key_values, 97 | labels, 98 | images, 99 | image_sizes 100 | ) 101 | 102 | return super().forward( 103 | input_ids=input_ids, 104 | attention_mask=attention_mask, 105 | position_ids=position_ids, 106 | past_key_values=past_key_values, 107 | inputs_embeds=inputs_embeds, 108 | labels=labels, 109 | use_cache=use_cache, 110 | output_attentions=output_attentions, 111 | output_hidden_states=output_hidden_states, 112 | return_dict=return_dict 113 | ) 114 | 115 | @torch.no_grad() 116 | def generate( 117 | self, 118 | inputs: Optional[torch.Tensor] = None, 119 | images: Optional[torch.Tensor] = None, 120 | image_sizes: Optional[torch.Tensor] = None, 121 | **kwargs, 122 | ) -> Union[GenerateOutput, torch.LongTensor]: 123 | position_ids = kwargs.pop("position_ids", None) 124 | attention_mask = kwargs.pop("attention_mask", None) 125 | if "inputs_embeds" in kwargs: 126 | raise NotImplementedError("`inputs_embeds` 
is not supported") 127 | 128 | if images is not None: 129 | ( 130 | inputs, 131 | position_ids, 132 | attention_mask, 133 | _, 134 | inputs_embeds, 135 | _ 136 | ) = self.prepare_inputs_labels_for_multimodal( 137 | inputs, 138 | position_ids, 139 | attention_mask, 140 | None, 141 | None, 142 | images, 143 | image_sizes=image_sizes 144 | ) 145 | else: 146 | inputs_embeds = self.get_model().embed_tokens(inputs) 147 | 148 | return super().generate( 149 | position_ids=position_ids, 150 | attention_mask=attention_mask, 151 | inputs_embeds=inputs_embeds, 152 | **kwargs 153 | ) 154 | 155 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 156 | inputs_embeds=None, **kwargs): 157 | images = kwargs.pop("images", None) 158 | image_sizes = kwargs.pop("image_sizes", None) 159 | inputs = super().prepare_inputs_for_generation( 160 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 161 | ) 162 | if images is not None: 163 | inputs['images'] = images 164 | if image_sizes is not None: 165 | inputs['image_sizes'] = image_sizes 166 | return inputs 167 | 168 | AutoConfig.register("llava_llama", LlavaConfig) 169 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 170 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | from .convnext_encoder import ConvNeXtCLIPVisionTower 4 | from .siglip_encoder import SiglipVisionTower 5 | from .lknet_encoder import LKNetCLIPVisionTower 6 | 7 | 8 | def build_vision_tower(vision_tower_cfg, **kwargs): 9 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 10 | print(f"now we are building vision tower, the model is {vision_tower}") 11 | if 'siglip' in vision_tower: 12 | print(f'building SiglipVisionTower') 13 | return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 14 | if vision_tower.startswith("openai") or 'clip-vit' in vision_tower: 15 | print(f'building CLIPVisionTower') 16 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 17 | if 'convnext' in vision_tower: 18 | print(f'building ConvNeXtCLIPVisionTower') 19 | return ConvNeXtCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 20 | if 'lknet' in vision_tower.lower(): 21 | print(f'building lknet') 22 | return LKNetCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 23 | return ConvNeXtCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 24 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr( 16 | args, 'mm_vision_select_feature', 'patch') 17 | 18 | if not delay_load: 19 | self.load_model() 20 | else: 21 | self.cfg_only = CLIPVisionConfig.from_pretrained( 22 | self.vision_tower_name) 23 | 24 | def load_model(self): 25 
| self.image_processor = CLIPImageProcessor.from_pretrained( 26 | self.vision_tower_name) 27 | self.vision_tower = CLIPVisionModel.from_pretrained( 28 | self.vision_tower_name) 29 | self.vision_tower.requires_grad_(False) 30 | 31 | self.is_loaded = True 32 | 33 | def print_trainable_parameters(self): 34 | for i, layer in enumerate(self.vision_tower.vision_model.encoder.layers): 35 | is_trainable = any( 36 | param.requires_grad for param in layer.parameters()) 37 | print(f"ViT Layer {i} is trainable: {is_trainable}") 38 | 39 | def feature_select(self, image_forward_outs): 40 | image_features = image_forward_outs.hidden_states[self.select_layer] 41 | if self.select_feature == 'patch': 42 | image_features = image_features[:, 1:] 43 | elif self.select_feature == 'cls_patch': 44 | image_features = image_features 45 | else: 46 | raise ValueError( 47 | f'Unexpected select feature: {self.select_feature}') 48 | return image_features 49 | 50 | def forward(self, images): 51 | if type(images) is list: 52 | image_features = [] 53 | for image in images: 54 | image_forward_out = self.vision_tower(image.to( 55 | device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 56 | image_feature = self.feature_select( 57 | image_forward_out).to(image.dtype) 58 | image_features.append(image_feature) 59 | else: 60 | image_forward_outs = self.vision_tower( 61 | images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 62 | image_features = self.feature_select( 63 | image_forward_outs).to(images.dtype) 64 | 65 | return image_features 66 | 67 | @property 68 | def dummy_feature(self): 69 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 70 | 71 | @property 72 | def dtype(self): 73 | return self.vision_tower.dtype 74 | 75 | @property 76 | def device(self): 77 | return self.vision_tower.device 78 | 79 | @property 80 | def config(self): 81 | if self.is_loaded: 82 | return self.vision_tower.config 83 | else: 84 | return self.cfg_only 85 | 86 | @property 87 | def hidden_size(self): 88 | return self.config.hidden_size 89 | 90 | @property 91 | def num_patches_per_side(self): 92 | return self.config.image_size // self.config.patch_size 93 | 94 | @property 95 | def num_patches(self): 96 | return (self.config.image_size // self.config.patch_size) ** 2 -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/convnext_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | from transformers import CLIPImageProcessor 5 | from transformers import ConvNextModel, ConvNextConfig 6 | from transformers.models.convnext.modeling_convnext import ConvNextStage 7 | 8 | 9 | class ConvNeXtCLIPVisionTower(nn.Module): 10 | def __init__(self, vision_tower, args, delay_load=False): 11 | super().__init__() 12 | 13 | self.is_loaded = False 14 | 15 | self.vision_tower_name = vision_tower 16 | self.select_layer = args.mm_vision_select_layer 17 | self.update_resolution = getattr( 18 | args, 'mm_vision_resolution', 256) 19 | self.vision_add_five_stage = getattr(args, 'vision_add_five_stage', 0) 20 | self.vision_five_stage_width = getattr(args, 'vision_five_stage_width', 1536) 21 | 22 | if not delay_load: 23 | self.load_model() 24 | else: 25 | print(f"deloy_load vision tower is: {self.vision_tower_name}") 26 | self.cfg_only = ConvNextConfig.from_pretrained( 27 | self.vision_tower_name) 28 | 29 | def load_model(self): 30 | 
print(f"entering load model, load {self.vision_tower_name}") 31 | self.image_processor = CLIPImageProcessor.from_pretrained( 32 | self.vision_tower_name) 33 | self.vision_tower = ConvNextModel.from_pretrained( 34 | self.vision_tower_name) 35 | self.vision_tower.requires_grad_(False) 36 | self.is_loaded = True 37 | 38 | if self.select_layer == -2: 39 | self.select_layer = -1 40 | self.vision_tower.encoder.stages[-1].layers.pop(-1) 41 | print( 42 | f'Last block removed, select layer changed to {self.select_layer}') 43 | 44 | if self.update_resolution > 256: 45 | self.set_crop_size(self.update_resolution) 46 | print( 47 | f'Crop size changed to {self.update_resolution}x{self.update_resolution}') 48 | 49 | if self.vision_add_five_stage != 0: 50 | self.add_stage(self.vision_add_five_stage, self.vision_five_stage_width) 51 | print( 52 | f'Added stage with width {self.vision_five_stage_width}') 53 | 54 | def forward(self, images): 55 | if type(images) is list: 56 | image_features = [] 57 | for image in images: 58 | # Get the embeddings of the image 59 | embedding_output = self.vision_tower.embeddings(image.unsqueeze(0)) 60 | 61 | # Get the image features 62 | image_feature = self.vision_tower.encoder(embedding_output, 63 | output_hidden_states=True, 64 | return_dict=True) 65 | image_feature = image_feature.hidden_states[-1].permute(0, 2, 3, 1) 66 | image_feature = image_feature.reshape(image_features.shape[0], -1, image_features.shape[3]).to(images.dtype) 67 | 68 | image_features.append(image_feature) 69 | else: 70 | embedding_output = self.vision_tower.embeddings(images) 71 | image_features = self.vision_tower.encoder(embedding_output, 72 | output_hidden_states=True, 73 | return_dict=True) 74 | image_features = image_features.hidden_states[-1].permute(0, 2, 3, 1) 75 | image_features = image_features.reshape(image_features.shape[0], -1, image_features.shape[3]).to(images.dtype) 76 | 77 | return image_features 78 | 79 | def make_layers_trainable_after_stage(self, stage_index, layer_index=0): 80 | for i, stage in enumerate(self.vision_tower.encoder.stages): 81 | if i == stage_index: 82 | if layer_index == 0: 83 | stage.downsampling_layer.requires_grad_(True) 84 | for idx, layer in enumerate(stage.layers): 85 | if idx >= layer_index: 86 | for param in layer.parameters(): 87 | param.requires_grad = True 88 | if i > stage_index: 89 | stage.downsampling_layer.requires_grad_(True) 90 | for layer in stage.layers: 91 | for param in layer.parameters(): 92 | param.requires_grad = True 93 | 94 | def set_crop_size(self, new_size): 95 | size_dict = {'height': new_size, 'width': new_size} 96 | self.image_processor.crop_size = size_dict 97 | self.image_processor.size = {"shortest_edge": new_size} 98 | self.vision_tower.config.image_size = new_size 99 | 100 | def add_stage(self, depths=3, hidden_dims=3072): 101 | self.vision_tower.encoder.stages.append(ConvNextStage(self.config, self.hidden_size, hidden_dims, depth=depths)) 102 | self.vision_tower.config.depths.append(depths) 103 | self.vision_tower.config.hidden_sizes.append(hidden_dims) 104 | self.vision_tower.config.stage_names.append('stage5') 105 | self.vision_tower.config.out_features = ['stage5'] 106 | self.vision_tower.config.out_indices = [5] 107 | self.vision_tower.config.num_stages += 1 108 | self.vision_tower.config._name_or_path = '' 109 | 110 | def save_config(self, path): 111 | self.vision_tower.config.save_pretrained(path) 112 | 113 | @property 114 | def dummy_feature(self): 115 | return torch.zeros(1, self.hidden_size, device=self.device, 
dtype=self.dtype) 116 | 117 | @property 118 | def dtype(self): 119 | return self.vision_tower.dtype 120 | 121 | @property 122 | def device(self): 123 | return self.vision_tower.device 124 | 125 | @property 126 | def config(self): 127 | if self.is_loaded: 128 | return self.vision_tower.config 129 | else: 130 | return self.cfg_only 131 | 132 | @property 133 | def hidden_size(self): 134 | return self.config.hidden_sizes[-1] 135 | 136 | @property 137 | def num_patches_per_side(self): 138 | return self.config.image_size // self.config.patch_size 139 | 140 | @property 141 | def num_patches(self): 142 | return (self.config.image_size // 32) ** 2 143 | 144 | @property 145 | def crop_size(self): 146 | return self.image_processor.crop_size 147 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/lknet_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | from transformers import CLIPImageProcessor 5 | from transformers import PretrainedConfig 6 | import os 7 | from .unireplknet.unireplknet_encoder import unireplknet_l_plus 8 | 9 | 10 | class LKNetConfig(PretrainedConfig): 11 | model_type = "lknet" 12 | 13 | def __init__( 14 | self, 15 | in_chans=3, 16 | image_size=256, 17 | num_classes=1000, 18 | depths=(3, 3, 27, 3, 6), 19 | dims=(192, 384, 768, 1536, 3072), 20 | drop_path_rate=0.0, 21 | layer_scale_init_value=1e-6, 22 | head_init_scale=1.0, 23 | kernel_sizes=None, 24 | deploy=False, 25 | with_cp=True, 26 | init_cfg=None, 27 | attempt_use_lk_impl=True, 28 | use_sync_bn=False, 29 | **kwargs, 30 | ): 31 | super().__init__(**kwargs) 32 | self.in_chans = in_chans 33 | self.image_size = image_size 34 | self.num_classes = num_classes 35 | self.depths = depths 36 | self.dims = dims 37 | self.drop_path_rate = drop_path_rate 38 | self.layer_scale_init_value = layer_scale_init_value 39 | self.head_init_scale = head_init_scale 40 | self.kernel_sizes = kernel_sizes 41 | self.deploy = deploy 42 | self.with_cp = with_cp 43 | self.init_cfg = init_cfg 44 | self.attempt_use_lk_impl = attempt_use_lk_impl 45 | self.use_sync_bn = use_sync_bn 46 | 47 | 48 | unireplknet_l_plus_config = { 49 | "depths": (3, 3, 27, 3, 6), 50 | "kernel_sizes": ( 51 | (3, 3, 3), 52 | (13, 13, 13), 53 | ( 54 | 13, 55 | 3, 56 | 3, 57 | 13, 58 | 3, 59 | 3, 60 | 13, 61 | 3, 62 | 3, 63 | 13, 64 | 3, 65 | 3, 66 | 13, 67 | 3, 68 | 3, 69 | 13, 70 | 3, 71 | 3, 72 | 13, 73 | 3, 74 | 3, 75 | 13, 76 | 3, 77 | 3, 78 | 13, 79 | 3, 80 | 3, 81 | ), 82 | (13, 13, 13), 83 | (3, 3, 3, 3, 3, 3), 84 | ), 85 | } 86 | 87 | 88 | class LKNetCLIPVisionTower(nn.Module): 89 | def __init__(self, vision_tower, args, delay_load=False): 90 | super().__init__() 91 | 92 | self.is_loaded = False 93 | 94 | self.vision_tower_name = vision_tower 95 | self.select_layer = args.mm_vision_select_layer 96 | self.update_resolution = getattr(args, "mm_vision_resolution", 256) 97 | self.cfg_only = LKNetConfig.from_dict(unireplknet_l_plus_config) 98 | 99 | if not delay_load: 100 | self.load_model() 101 | 102 | def load_model(self): 103 | print(f"entering load model, load {self.vision_tower_name}") 104 | self.image_processor = CLIPImageProcessor.from_pretrained( 105 | os.path.dirname(self.vision_tower_name) 106 | ) 107 | self.vision_tower = unireplknet_l_plus() 108 | self.vision_tower.config = LKNetConfig.from_dict(unireplknet_l_plus_config) 109 | ckpt = torch.load(self.vision_tower_name) 110 | del ckpt["norm.weight"] 111 
| del ckpt["norm.bias"] 112 | missing_keys, unexpected_keys = self.vision_tower.load_state_dict( 113 | ckpt, strict=False 114 | ) 115 | print("Loaded CLIP Pretrained Models") 116 | print( 117 | f"missing keys are {missing_keys}\n unexpected keys are {unexpected_keys}" 118 | ) 119 | 120 | self.vision_tower.requires_grad_(False) 121 | self.is_loaded = True 122 | 123 | if self.update_resolution > 256: 124 | self.set_crop_size(self.update_resolution) 125 | print( 126 | f"Crop size changed to {self.update_resolution}x{self.update_resolution}" 127 | ) 128 | self.make_layers_trainable_after_stage(4) 129 | 130 | def forward(self, images): 131 | if type(images) is list: 132 | image_features = [] 133 | for image in images: 134 | x = image 135 | for stage_idx in range(5): 136 | x = self.vision_tower.downsample_layers[stage_idx](x) 137 | x = self.vision_tower.stages[stage_idx](x) 138 | image_features = x.permute(0, 2, 3, 1) 139 | image_features = image_features.reshape(x.shape[0], -1, x.shape[1]).to(image.dtype) 140 | image_features.append(x) 141 | else: 142 | x = images 143 | for stage_idx in range(5): 144 | x = self.vision_tower.downsample_layers[stage_idx](x) 145 | x = self.vision_tower.stages[stage_idx](x) 146 | image_features = x.permute(0, 2, 3, 1) 147 | image_features = image_features.reshape(x.shape[0], -1, x.shape[1]).to(images.dtype) 148 | return image_features 149 | 150 | def make_layers_trainable_after_stage(self, stage_index): 151 | for i, stage in enumerate(self.vision_tower.stages): 152 | if i >= stage_index: 153 | for param in stage.parameters(): 154 | param.requires_grad = True 155 | for i, stage in enumerate(self.vision_tower.downsample_layers): 156 | if i >= stage_index: 157 | for param in stage.parameters(): 158 | param.requires_grad = True 159 | self.print_trainable_parameters() 160 | 161 | def print_trainable_parameters(self): 162 | print("Trainable status of each stage:") 163 | for i, stage in enumerate(self.vision_tower.stages): 164 | trainable = all(param.requires_grad for param in stage.parameters()) 165 | print(f"Stage {i}: {'Trainable' if trainable else 'Not Trainable'}") 166 | 167 | print("\nTrainable status of each downsampling layer:") 168 | for i, downsample_layer in enumerate(self.vision_tower.downsample_layers): 169 | trainable = all(param.requires_grad for param in downsample_layer.parameters()) 170 | print(f"Downsampling Layer {i}: {'Trainable' if trainable else 'Not Trainable'}") 171 | 172 | def set_crop_size(self, new_size): 173 | size_dict = {"height": new_size, "width": new_size} 174 | self.image_processor.crop_size = size_dict 175 | self.image_processor.size = {"shortest_edge": new_size} 176 | self.vision_tower.config.image_size = new_size 177 | self.config.image_size = new_size 178 | 179 | def save_config(self, path): 180 | self.config.save_pretrained(path) 181 | 182 | @property 183 | def dummy_feature(self): 184 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 185 | 186 | @property 187 | def dtype(self): 188 | return self.vision_tower.downsample_layers[0][0].weight.dtype 189 | 190 | @property 191 | def device(self): 192 | return self.vision_tower.downsample_layers[0][0].weight.device 193 | 194 | @property 195 | def config(self): 196 | return self.cfg_only 197 | 198 | @property 199 | def hidden_size(self): 200 | return self.config.dims[-1] 201 | 202 | @property 203 | def num_patches_per_side(self): 204 | return self.config.image_size // 2 ** (len(self.config.depths) + 1) 205 | 206 | @property 207 | def num_patches(self): 208 | 
return (self.config.image_size // 2 ** (len(self.config.depths) + 1)) ** 2 209 | 210 | @property 211 | def crop_size(self): 212 | return self.image_processor.crop_size 213 | 214 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/siglip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import argparse 4 | 5 | from transformers import SiglipVisionConfig, SiglipVisionModel, SiglipImageProcessor 6 | 7 | 8 | class SiglipVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_name = vision_tower 15 | self.select_layer = args.mm_vision_select_layer 16 | self.select_feature = getattr( 17 | args, 'mm_vision_select_feature', 'patch') 18 | 19 | if not delay_load: 20 | self.load_model() 21 | else: 22 | self.cfg_only = SiglipVisionConfig.from_pretrained( 23 | self.vision_tower_name) 24 | 25 | def get_input_embeddings(self) -> nn.Module: 26 | return self.vision_tower.embeddings.patch_embedding 27 | 28 | def load_model(self): 29 | self.image_processor = SiglipImageProcessor.from_pretrained( 30 | self.vision_tower_name) 31 | self.vision_tower = SiglipVisionModel.from_pretrained( 32 | self.vision_tower_name) 33 | self.vision_tower.requires_grad_(False) 34 | 35 | self.is_loaded = True 36 | 37 | def print_trainable_parameters(self): 38 | for i, layer in enumerate(self.vision_tower.vision_model.encoder.layers): 39 | is_trainable = any( 40 | param.requires_grad for param in layer.parameters()) 41 | print(f"ViT Layer {i} is trainable: {is_trainable}") 42 | 43 | def feature_select(self, image_forward_outs): 44 | image_features = image_forward_outs.hidden_states[self.select_layer] 45 | if self.select_feature == 'patch': 46 | image_features = image_features 47 | elif self.select_feature == 'cls_patch': 48 | image_features = image_features 49 | else: 50 | raise ValueError( 51 | f'Unexpected select feature: {self.select_feature}') 52 | return image_features 53 | 54 | def forward(self, images): 55 | if type(images) is list: 56 | image_features = [] 57 | for image in images: 58 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 59 | output_hidden_states=True) 60 | image_feature = self.feature_select( 61 | image_forward_out).to(image.dtype) 62 | image_features.append(image_feature) 63 | else: 64 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), 65 | output_hidden_states=True) 66 | image_features = self.feature_select( 67 | image_forward_outs).to(images.dtype) 68 | 69 | return image_features 70 | 71 | @property 72 | def dummy_feature(self): 73 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 74 | 75 | @property 76 | def dtype(self): 77 | return list(self.vision_tower.parameters())[0].dtype 78 | 79 | @property 80 | def device(self): 81 | return list(self.vision_tower.parameters())[0].device 82 | 83 | @property 84 | def config(self): 85 | if self.is_loaded: 86 | return self.vision_tower.config 87 | else: 88 | return self.cfg_only 89 | 90 | @property 91 | def hidden_size(self): 92 | return self.config.hidden_size 93 | 94 | @property 95 | def num_patches_per_side(self): 96 | return self.config.image_size // self.config.patch_size 97 | 98 | @property 99 | def num_patches(self): 100 | return (self.config.image_size // self.config.patch_size) 
** 2 101 | 102 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/unireplknet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/model/multimodal_encoder/unireplknet/__init__.py -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | 29 | def forward(self, x): 30 | x = self.pre_norm(x) 31 | return x + self.proj(x) 32 | 33 | 34 | def build_vision_projector(config, delay_load=False, **kwargs): 35 | projector_type = getattr(config, 'mm_projector_type', 'linear') 36 | 37 | if projector_type == 'linear': 38 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 39 | 40 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 41 | if mlp_gelu_match: 42 | mlp_depth = int(mlp_gelu_match.group(1)) 43 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 44 | for _ in range(1, mlp_depth): 45 | modules.append(nn.GELU()) 46 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 47 | return nn.Sequential(*modules) 48 | 49 | if projector_type == 'identity': 50 | return IdentityMap() 51 | 52 | raise ValueError(f'Unknown projector type: {projector_type}') 53 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = 
get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if 'llama-2' in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "v1" in model_name.lower(): 37 | conv_mode = "llava_v1" 38 | elif "mpt" in model_name.lower(): 39 | conv_mode = "mpt" 40 | else: 41 | conv_mode = "llava_v0" 42 | 43 | if args.conv_mode is not None and conv_mode != args.conv_mode: 44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 45 | else: 46 | args.conv_mode = conv_mode 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | if "mpt" in model_name.lower(): 50 | roles = ('user', 'assistant') 51 | else: 52 | roles = conv.roles 53 | 54 | image = load_image(args.image_file) 55 | # Similar operation in model_worker.py 56 | image_tensor = process_images([image], image_processor, model.config) 57 | if type(image_tensor) is list: 58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 59 | else: 60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 61 | 62 | while True: 63 | try: 64 | inp = input(f"{roles[0]}: ") 65 | except EOFError: 66 | inp = "" 67 | if not inp: 68 | print("exit...") 69 | break 70 | 71 | print(f"{roles[1]}: ", end="") 72 | 73 | if image is not None: 74 | # first message 75 | if model.config.mm_use_im_start_end: 76 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 77 | else: 78 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 79 | conv.append_message(conv.roles[0], inp) 80 | image = None 81 | else: 82 | # later messages 83 | conv.append_message(conv.roles[0], inp) 84 | conv.append_message(conv.roles[1], None) 85 | prompt = conv.get_prompt() 86 | 87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 89 | keywords = [stop_str] 90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 92 | 93 | with torch.inference_mode(): 94 | output_ids = model.generate( 95 | input_ids, 96 | images=image_tensor, 97 | do_sample=True if args.temperature > 0 else False, 98 | temperature=args.temperature, 99 | max_new_tokens=args.max_new_tokens, 100 | streamer=streamer, 101 | use_cache=True, 102 | stopping_criteria=[stopping_criteria]) 103 | 104 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 105 | conv.messages[-1][-1] = outputs 106 | 107 | if args.debug: 108 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 109 | 110 | 111 | if __name__ == "__main__": 112 | parser = argparse.ArgumentParser() 113 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 114 | parser.add_argument("--model-base", type=str, default=None) 115 | parser.add_argument("--image-file", type=str, required=True) 116 | parser.add_argument("--device", type=str, default="cuda") 117 | parser.add_argument("--conv-mode", type=str, default=None) 118 | parser.add_argument("--temperature", type=float, default=0.2) 119 | parser.add_argument("--max-new-tokens", type=int, default=512) 120 | parser.add_argument("--load-8bit", action="store_true") 121 | parser.add_argument("--load-4bit", 
action="store_true") 122 | parser.add_argument("--debug", action="store_true") 123 | args = parser.parse_args() 124 | main(args) 125 | -------------------------------------------------------------------------------- /llava/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 4 | """ 5 | import argparse 6 | import asyncio 7 | import dataclasses 8 | from enum import Enum, auto 9 | import json 10 | import logging 11 | import time 12 | from typing import List, Union 13 | import threading 14 | 15 | from fastapi import FastAPI, Request 16 | from fastapi.responses import StreamingResponse 17 | import numpy as np 18 | import requests 19 | import uvicorn 20 | 21 | from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION 22 | from llava.utils import build_logger, server_error_msg 23 | 24 | 25 | logger = build_logger("controller", "controller.log") 26 | 27 | 28 | class DispatchMethod(Enum): 29 | LOTTERY = auto() 30 | SHORTEST_QUEUE = auto() 31 | 32 | @classmethod 33 | def from_str(cls, name): 34 | if name == "lottery": 35 | return cls.LOTTERY 36 | elif name == "shortest_queue": 37 | return cls.SHORTEST_QUEUE 38 | else: 39 | raise ValueError(f"Invalid dispatch method") 40 | 41 | 42 | @dataclasses.dataclass 43 | class WorkerInfo: 44 | model_names: List[str] 45 | speed: int 46 | queue_length: int 47 | check_heart_beat: bool 48 | last_heart_beat: str 49 | 50 | 51 | def heart_beat_controller(controller): 52 | while True: 53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 54 | controller.remove_stable_workers_by_expiration() 55 | 56 | 57 | class Controller: 58 | def __init__(self, dispatch_method: str): 59 | # Dict[str -> WorkerInfo] 60 | self.worker_info = {} 61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 62 | 63 | self.heart_beat_thread = threading.Thread( 64 | target=heart_beat_controller, args=(self,)) 65 | self.heart_beat_thread.start() 66 | 67 | logger.info("Init controller") 68 | 69 | def register_worker(self, worker_name: str, check_heart_beat: bool, 70 | worker_status: dict): 71 | if worker_name not in self.worker_info: 72 | logger.info(f"Register a new worker: {worker_name}") 73 | else: 74 | logger.info(f"Register an existing worker: {worker_name}") 75 | 76 | if not worker_status: 77 | worker_status = self.get_worker_status(worker_name) 78 | if not worker_status: 79 | return False 80 | 81 | self.worker_info[worker_name] = WorkerInfo( 82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], 83 | check_heart_beat, time.time()) 84 | 85 | logger.info(f"Register done: {worker_name}, {worker_status}") 86 | return True 87 | 88 | def get_worker_status(self, worker_name: str): 89 | try: 90 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 91 | except requests.exceptions.RequestException as e: 92 | logger.error(f"Get status fails: {worker_name}, {e}") 93 | return None 94 | 95 | if r.status_code != 200: 96 | logger.error(f"Get status fails: {worker_name}, {r}") 97 | return None 98 | 99 | return r.json() 100 | 101 | def remove_worker(self, worker_name: str): 102 | del self.worker_info[worker_name] 103 | 104 | def refresh_all_workers(self): 105 | old_info = dict(self.worker_info) 106 | self.worker_info = {} 107 | 108 | for w_name, w_info in old_info.items(): 109 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 110 | logger.info(f"Remove stale worker: {w_name}") 111 | 112 | 
def list_models(self): 113 | model_names = set() 114 | 115 | for w_name, w_info in self.worker_info.items(): 116 | model_names.update(w_info.model_names) 117 | 118 | return list(model_names) 119 | 120 | def get_worker_address(self, model_name: str): 121 | if self.dispatch_method == DispatchMethod.LOTTERY: 122 | worker_names = [] 123 | worker_speeds = [] 124 | for w_name, w_info in self.worker_info.items(): 125 | if model_name in w_info.model_names: 126 | worker_names.append(w_name) 127 | worker_speeds.append(w_info.speed) 128 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 129 | norm = np.sum(worker_speeds) 130 | if norm < 1e-4: 131 | return "" 132 | worker_speeds = worker_speeds / norm 133 | if True: # Directly return address 134 | pt = np.random.choice(np.arange(len(worker_names)), 135 | p=worker_speeds) 136 | worker_name = worker_names[pt] 137 | return worker_name 138 | 139 | # Check status before returning 140 | while True: 141 | pt = np.random.choice(np.arange(len(worker_names)), 142 | p=worker_speeds) 143 | worker_name = worker_names[pt] 144 | 145 | if self.get_worker_status(worker_name): 146 | break 147 | else: 148 | self.remove_worker(worker_name) 149 | worker_speeds[pt] = 0 150 | norm = np.sum(worker_speeds) 151 | if norm < 1e-4: 152 | return "" 153 | worker_speeds = worker_speeds / norm 154 | continue 155 | return worker_name 156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 157 | worker_names = [] 158 | worker_qlen = [] 159 | for w_name, w_info in self.worker_info.items(): 160 | if model_name in w_info.model_names: 161 | worker_names.append(w_name) 162 | worker_qlen.append(w_info.queue_length / w_info.speed) 163 | if len(worker_names) == 0: 164 | return "" 165 | min_index = np.argmin(worker_qlen) 166 | w_name = worker_names[min_index] 167 | self.worker_info[w_name].queue_length += 1 168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 169 | return w_name 170 | else: 171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 172 | 173 | def receive_heart_beat(self, worker_name: str, queue_length: int): 174 | if worker_name not in self.worker_info: 175 | logger.info(f"Receive unknown heart beat. {worker_name}") 176 | return False 177 | 178 | self.worker_info[worker_name].queue_length = queue_length 179 | self.worker_info[worker_name].last_heart_beat = time.time() 180 | logger.info(f"Receive heart beat. 
{worker_name}") 181 | return True 182 | 183 | def remove_stable_workers_by_expiration(self): 184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 185 | to_delete = [] 186 | for worker_name, w_info in self.worker_info.items(): 187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 188 | to_delete.append(worker_name) 189 | 190 | for worker_name in to_delete: 191 | self.remove_worker(worker_name) 192 | 193 | def worker_api_generate_stream(self, params): 194 | worker_addr = self.get_worker_address(params["model"]) 195 | if not worker_addr: 196 | logger.info(f"no worker: {params['model']}") 197 | ret = { 198 | "text": server_error_msg, 199 | "error_code": 2, 200 | } 201 | yield json.dumps(ret).encode() + b"\0" 202 | 203 | try: 204 | response = requests.post(worker_addr + "/worker_generate_stream", 205 | json=params, stream=True, timeout=5) 206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 207 | if chunk: 208 | yield chunk + b"\0" 209 | except requests.exceptions.RequestException as e: 210 | logger.info(f"worker timeout: {worker_addr}") 211 | ret = { 212 | "text": server_error_msg, 213 | "error_code": 3, 214 | } 215 | yield json.dumps(ret).encode() + b"\0" 216 | 217 | 218 | # Let the controller act as a worker to achieve hierarchical 219 | # management. This can be used to connect isolated sub networks. 220 | def worker_api_get_status(self): 221 | model_names = set() 222 | speed = 0 223 | queue_length = 0 224 | 225 | for w_name in self.worker_info: 226 | worker_status = self.get_worker_status(w_name) 227 | if worker_status is not None: 228 | model_names.update(worker_status["model_names"]) 229 | speed += worker_status["speed"] 230 | queue_length += worker_status["queue_length"] 231 | 232 | return { 233 | "model_names": list(model_names), 234 | "speed": speed, 235 | "queue_length": queue_length, 236 | } 237 | 238 | 239 | app = FastAPI() 240 | 241 | 242 | @app.post("/register_worker") 243 | async def register_worker(request: Request): 244 | data = await request.json() 245 | controller.register_worker( 246 | data["worker_name"], data["check_heart_beat"], 247 | data.get("worker_status", None)) 248 | 249 | 250 | @app.post("/refresh_all_workers") 251 | async def refresh_all_workers(): 252 | models = controller.refresh_all_workers() 253 | 254 | 255 | @app.post("/list_models") 256 | async def list_models(): 257 | models = controller.list_models() 258 | return {"models": models} 259 | 260 | 261 | @app.post("/get_worker_address") 262 | async def get_worker_address(request: Request): 263 | data = await request.json() 264 | addr = controller.get_worker_address(data["model"]) 265 | return {"address": addr} 266 | 267 | 268 | @app.post("/receive_heart_beat") 269 | async def receive_heart_beat(request: Request): 270 | data = await request.json() 271 | exist = controller.receive_heart_beat( 272 | data["worker_name"], data["queue_length"]) 273 | return {"exist": exist} 274 | 275 | 276 | @app.post("/worker_generate_stream") 277 | async def worker_api_generate_stream(request: Request): 278 | params = await request.json() 279 | generator = controller.worker_api_generate_stream(params) 280 | return StreamingResponse(generator) 281 | 282 | 283 | @app.post("/worker_get_status") 284 | async def worker_api_get_status(request: Request): 285 | return controller.worker_api_get_status() 286 | 287 | 288 | if __name__ == "__main__": 289 | parser = argparse.ArgumentParser() 290 | parser.add_argument("--host", type=str, default="localhost") 291 | parser.add_argument("--port", 
type=int, default=21001) 292 | parser.add_argument("--dispatch-method", type=str, choices=[ 293 | "lottery", "shortest_queue"], default="shortest_queue") 294 | args = parser.parse_args() 295 | logger.info(f"args: {args}") 296 | 297 | controller = Controller(args.dispatch_method) 298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 299 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/gradio_web_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import os 5 | import time 6 | 7 | import gradio as gr 8 | import requests 9 | 10 | from llava.conversation import (default_conversation, conv_templates, 11 | SeparatorStyle) 12 | from llava.constants import LOGDIR 13 | from llava.utils import (build_logger, server_error_msg, 14 | violates_moderation, moderation_msg) 15 | import hashlib 16 | 17 | 18 | logger = build_logger("gradio_web_server", "gradio_web_server.log") 19 | 20 | headers = {"User-Agent": "LLaVA Client"} 21 | 22 | no_change_btn = gr.Button.update() 23 | enable_btn = gr.Button.update(interactive=True) 24 | disable_btn = gr.Button.update(interactive=False) 25 | 26 | priority = { 27 | "vicuna-13b": "aaaaaaa", 28 | "koala-13b": "aaaaaab", 29 | } 30 | 31 | 32 | def get_conv_log_filename(): 33 | t = datetime.datetime.now() 34 | name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") 35 | return name 36 | 37 | 38 | def get_model_list(): 39 | ret = requests.post(args.controller_url + "/refresh_all_workers") 40 | assert ret.status_code == 200 41 | ret = requests.post(args.controller_url + "/list_models") 42 | models = ret.json()["models"] 43 | models.sort(key=lambda x: priority.get(x, x)) 44 | logger.info(f"Models: {models}") 45 | return models 46 | 47 | 48 | get_window_url_params = """ 49 | function() { 50 | const params = new URLSearchParams(window.location.search); 51 | url_params = Object.fromEntries(params); 52 | console.log(url_params); 53 | return url_params; 54 | } 55 | """ 56 | 57 | 58 | def load_demo(url_params, request: gr.Request): 59 | logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") 60 | 61 | dropdown_update = gr.Dropdown.update(visible=True) 62 | if "model" in url_params: 63 | model = url_params["model"] 64 | if model in models: 65 | dropdown_update = gr.Dropdown.update( 66 | value=model, visible=True) 67 | 68 | state = default_conversation.copy() 69 | return state, dropdown_update 70 | 71 | 72 | def load_demo_refresh_model_list(request: gr.Request): 73 | logger.info(f"load_demo. 
ip: {request.client.host}") 74 | models = get_model_list() 75 | state = default_conversation.copy() 76 | dropdown_update = gr.Dropdown.update( 77 | choices=models, 78 | value=models[0] if len(models) > 0 else "" 79 | ) 80 | return state, dropdown_update 81 | 82 | 83 | def vote_last_response(state, vote_type, model_selector, request: gr.Request): 84 | with open(get_conv_log_filename(), "a") as fout: 85 | data = { 86 | "tstamp": round(time.time(), 4), 87 | "type": vote_type, 88 | "model": model_selector, 89 | "state": state.dict(), 90 | "ip": request.client.host, 91 | } 92 | fout.write(json.dumps(data) + "\n") 93 | 94 | 95 | def upvote_last_response(state, model_selector, request: gr.Request): 96 | logger.info(f"upvote. ip: {request.client.host}") 97 | vote_last_response(state, "upvote", model_selector, request) 98 | return ("",) + (disable_btn,) * 3 99 | 100 | 101 | def downvote_last_response(state, model_selector, request: gr.Request): 102 | logger.info(f"downvote. ip: {request.client.host}") 103 | vote_last_response(state, "downvote", model_selector, request) 104 | return ("",) + (disable_btn,) * 3 105 | 106 | 107 | def flag_last_response(state, model_selector, request: gr.Request): 108 | logger.info(f"flag. ip: {request.client.host}") 109 | vote_last_response(state, "flag", model_selector, request) 110 | return ("",) + (disable_btn,) * 3 111 | 112 | 113 | def regenerate(state, image_process_mode, request: gr.Request): 114 | logger.info(f"regenerate. ip: {request.client.host}") 115 | state.messages[-1][-1] = None 116 | prev_human_msg = state.messages[-2] 117 | if type(prev_human_msg[1]) in (tuple, list): 118 | prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) 119 | state.skip_next = False 120 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 121 | 122 | 123 | def clear_history(request: gr.Request): 124 | logger.info(f"clear_history. ip: {request.client.host}") 125 | state = default_conversation.copy() 126 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 127 | 128 | 129 | def add_text(state, text, image, image_process_mode, request: gr.Request): 130 | logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") 131 | if len(text) <= 0 and image is None: 132 | state.skip_next = True 133 | return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5 134 | if args.moderate: 135 | flagged = violates_moderation(text) 136 | if flagged: 137 | state.skip_next = True 138 | return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( 139 | no_change_btn,) * 5 140 | 141 | text = text[:1536] # Hard cut-off 142 | if image is not None: 143 | text = text[:1200] # Hard cut-off for images 144 | if '' not in text: 145 | # text = '' + text 146 | text = text + '\n' 147 | text = (text, image, image_process_mode) 148 | if len(state.get_images(return_pil=True)) > 0: 149 | state = default_conversation.copy() 150 | state.append_message(state.roles[0], text) 151 | state.append_message(state.roles[1], None) 152 | state.skip_next = False 153 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5 154 | 155 | 156 | def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request): 157 | logger.info(f"http_bot. 
ip: {request.client.host}") 158 | start_tstamp = time.time() 159 | model_name = model_selector 160 | 161 | if state.skip_next: 162 | # This generate call is skipped due to invalid inputs 163 | yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5 164 | return 165 | 166 | if len(state.messages) == state.offset + 2: 167 | # First round of conversation 168 | if "llava" in model_name.lower(): 169 | if 'llama-2' in model_name.lower(): 170 | template_name = "llava_llama_2" 171 | elif "v1" in model_name.lower(): 172 | if 'mmtag' in model_name.lower(): 173 | template_name = "v1_mmtag" 174 | elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower(): 175 | template_name = "v1_mmtag" 176 | else: 177 | template_name = "llava_v1" 178 | elif "mpt" in model_name.lower(): 179 | template_name = "mpt" 180 | else: 181 | if 'mmtag' in model_name.lower(): 182 | template_name = "v0_mmtag" 183 | elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower(): 184 | template_name = "v0_mmtag" 185 | else: 186 | template_name = "llava_v0" 187 | elif "mpt" in model_name: 188 | template_name = "mpt_text" 189 | elif "llama-2" in model_name: 190 | template_name = "llama_2" 191 | else: 192 | template_name = "vicuna_v1" 193 | new_state = conv_templates[template_name].copy() 194 | new_state.append_message(new_state.roles[0], state.messages[-2][1]) 195 | new_state.append_message(new_state.roles[1], None) 196 | state = new_state 197 | 198 | # Query worker address 199 | controller_url = args.controller_url 200 | ret = requests.post(controller_url + "/get_worker_address", 201 | json={"model": model_name}) 202 | worker_addr = ret.json()["address"] 203 | logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") 204 | 205 | # No available worker 206 | if worker_addr == "": 207 | state.messages[-1][-1] = server_error_msg 208 | yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) 209 | return 210 | 211 | # Construct prompt 212 | prompt = state.get_prompt() 213 | 214 | all_images = state.get_images(return_pil=True) 215 | all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] 216 | for image, hash in zip(all_images, all_image_hash): 217 | t = datetime.datetime.now() 218 | filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") 219 | if not os.path.isfile(filename): 220 | os.makedirs(os.path.dirname(filename), exist_ok=True) 221 | image.save(filename) 222 | 223 | # Make requests 224 | pload = { 225 | "model": model_name, 226 | "prompt": prompt, 227 | "temperature": float(temperature), 228 | "top_p": float(top_p), 229 | "max_new_tokens": min(int(max_new_tokens), 1536), 230 | "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2, 231 | "images": f'List of {len(state.get_images())} images: {all_image_hash}', 232 | } 233 | logger.info(f"==== request ====\n{pload}") 234 | 235 | pload['images'] = state.get_images() 236 | 237 | state.messages[-1][-1] = "▌" 238 | yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 239 | 240 | try: 241 | # Stream output 242 | response = requests.post(worker_addr + "/worker_generate_stream", 243 | headers=headers, json=pload, stream=True, timeout=10) 244 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 245 | if chunk: 246 | data = json.loads(chunk.decode()) 247 | if data["error_code"] == 0: 248 | output = data["text"][len(prompt):].strip() 249 | 
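# The trailing "▌" acts as a streaming cursor while chunks arrive; it is stripped from the final message once the worker finishes responding.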
state.messages[-1][-1] = output + "▌" 250 | yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5 251 | else: 252 | output = data["text"] + f" (error_code: {data['error_code']})" 253 | state.messages[-1][-1] = output 254 | yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) 255 | return 256 | time.sleep(0.03) 257 | except requests.exceptions.RequestException as e: 258 | state.messages[-1][-1] = server_error_msg 259 | yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn) 260 | return 261 | 262 | state.messages[-1][-1] = state.messages[-1][-1][:-1] 263 | yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5 264 | 265 | finish_tstamp = time.time() 266 | logger.info(f"{output}") 267 | 268 | with open(get_conv_log_filename(), "a") as fout: 269 | data = { 270 | "tstamp": round(finish_tstamp, 4), 271 | "type": "chat", 272 | "model": model_name, 273 | "start": round(start_tstamp, 4), 274 | "finish": round(finish_tstamp, 4), 275 | "state": state.dict(), 276 | "images": all_image_hash, 277 | "ip": request.client.host, 278 | } 279 | fout.write(json.dumps(data) + "\n") 280 | 281 | title_markdown = (""" 282 | # 🌋 LLaVA: Large Language and Vision Assistant 283 | [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)] 284 | """) 285 | 286 | tos_markdown = (""" 287 | ### Terms of use 288 | By using this service, users are required to agree to the following terms: 289 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. 290 | Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator. 291 | For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality. 292 | """) 293 | 294 | 295 | learn_more_markdown = (""" 296 | ### License 297 | The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation. 
298 | """) 299 | 300 | block_css = """ 301 | 302 | #buttons button { 303 | min-width: min(120px,100%); 304 | } 305 | 306 | """ 307 | 308 | def build_demo(embed_mode): 309 | textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) 310 | with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo: 311 | state = gr.State() 312 | 313 | if not embed_mode: 314 | gr.Markdown(title_markdown) 315 | 316 | with gr.Row(): 317 | with gr.Column(scale=3): 318 | with gr.Row(elem_id="model_selector_row"): 319 | model_selector = gr.Dropdown( 320 | choices=models, 321 | value=models[0] if len(models) > 0 else "", 322 | interactive=True, 323 | show_label=False, 324 | container=False) 325 | 326 | imagebox = gr.Image(type="pil") 327 | image_process_mode = gr.Radio( 328 | ["Crop", "Resize", "Pad", "Default"], 329 | value="Default", 330 | label="Preprocess for non-square image", visible=False) 331 | 332 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 333 | gr.Examples(examples=[ 334 | [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"], 335 | [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"], 336 | ], inputs=[imagebox, textbox]) 337 | 338 | with gr.Accordion("Parameters", open=False) as parameter_row: 339 | temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",) 340 | top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",) 341 | max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",) 342 | 343 | with gr.Column(scale=8): 344 | chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550) 345 | with gr.Row(): 346 | with gr.Column(scale=8): 347 | textbox.render() 348 | with gr.Column(scale=1, min_width=50): 349 | submit_btn = gr.Button(value="Send", variant="primary") 350 | with gr.Row(elem_id="buttons") as button_row: 351 | upvote_btn = gr.Button(value="👍 Upvote", interactive=False) 352 | downvote_btn = gr.Button(value="👎 Downvote", interactive=False) 353 | flag_btn = gr.Button(value="⚠️ Flag", interactive=False) 354 | #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) 355 | regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False) 356 | clear_btn = gr.Button(value="🗑️ Clear", interactive=False) 357 | 358 | if not embed_mode: 359 | gr.Markdown(tos_markdown) 360 | gr.Markdown(learn_more_markdown) 361 | url_params = gr.JSON(visible=False) 362 | 363 | # Register listeners 364 | btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn] 365 | upvote_btn.click( 366 | upvote_last_response, 367 | [state, model_selector], 368 | [textbox, upvote_btn, downvote_btn, flag_btn], 369 | queue=False 370 | ) 371 | downvote_btn.click( 372 | downvote_last_response, 373 | [state, model_selector], 374 | [textbox, upvote_btn, downvote_btn, flag_btn], 375 | queue=False 376 | ) 377 | flag_btn.click( 378 | flag_last_response, 379 | [state, model_selector], 380 | [textbox, upvote_btn, downvote_btn, flag_btn], 381 | queue=False 382 | ) 383 | 384 | regenerate_btn.click( 385 | regenerate, 386 | [state, image_process_mode], 387 | [state, chatbot, textbox, imagebox] + btn_list, 388 | queue=False 389 | ).then( 390 | http_bot, 391 | [state, model_selector, temperature, top_p, max_output_tokens], 392 | [state, chatbot] + btn_list 393 | ) 394 | 395 | clear_btn.click( 396 | 
clear_history, 397 | None, 398 | [state, chatbot, textbox, imagebox] + btn_list, 399 | queue=False 400 | ) 401 | 402 | textbox.submit( 403 | add_text, 404 | [state, textbox, imagebox, image_process_mode], 405 | [state, chatbot, textbox, imagebox] + btn_list, 406 | queue=False 407 | ).then( 408 | http_bot, 409 | [state, model_selector, temperature, top_p, max_output_tokens], 410 | [state, chatbot] + btn_list 411 | ) 412 | 413 | submit_btn.click( 414 | add_text, 415 | [state, textbox, imagebox, image_process_mode], 416 | [state, chatbot, textbox, imagebox] + btn_list, 417 | queue=False 418 | ).then( 419 | http_bot, 420 | [state, model_selector, temperature, top_p, max_output_tokens], 421 | [state, chatbot] + btn_list 422 | ) 423 | 424 | if args.model_list_mode == "once": 425 | demo.load( 426 | load_demo, 427 | [url_params], 428 | [state, model_selector], 429 | _js=get_window_url_params, 430 | queue=False 431 | ) 432 | elif args.model_list_mode == "reload": 433 | demo.load( 434 | load_demo_refresh_model_list, 435 | None, 436 | [state, model_selector], 437 | queue=False 438 | ) 439 | else: 440 | raise ValueError(f"Unknown model list mode: {args.model_list_mode}") 441 | 442 | return demo 443 | 444 | 445 | if __name__ == "__main__": 446 | parser = argparse.ArgumentParser() 447 | parser.add_argument("--host", type=str, default="0.0.0.0") 448 | parser.add_argument("--port", type=int) 449 | parser.add_argument("--controller-url", type=str, default="http://localhost:21001") 450 | parser.add_argument("--concurrency-count", type=int, default=10) 451 | parser.add_argument("--model-list-mode", type=str, default="once", 452 | choices=["once", "reload"]) 453 | parser.add_argument("--share", action="store_true") 454 | parser.add_argument("--moderate", action="store_true") 455 | parser.add_argument("--embed", action="store_true") 456 | args = parser.parse_args() 457 | logger.info(f"args: {args}") 458 | 459 | models = get_model_list() 460 | 461 | logger.info(args) 462 | demo = build_demo(args.embed) 463 | demo.queue( 464 | concurrency_count=args.concurrency_count, 465 | api_open=False 466 | ).launch( 467 | server_name=args.host, 468 | server_port=args.port, 469 | share=args.share 470 | ) 471 | -------------------------------------------------------------------------------- /llava/serve/model_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | A model worker executes the model. 
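It loads a pretrained checkpoint, registers itself with the controller, sends periodic heart beats, and streams generations through the /worker_generate_stream endpoint.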
3 | """ 4 | import argparse 5 | import asyncio 6 | import json 7 | import time 8 | import threading 9 | import uuid 10 | 11 | from fastapi import FastAPI, Request, BackgroundTasks 12 | from fastapi.responses import StreamingResponse 13 | import requests 14 | import torch 15 | import uvicorn 16 | from functools import partial 17 | 18 | from llava.constants import WORKER_HEART_BEAT_INTERVAL 19 | from llava.utils import (build_logger, server_error_msg, 20 | pretty_print_semaphore) 21 | from llava.model.builder import load_pretrained_model 22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria 23 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | from transformers import TextIteratorStreamer 25 | from threading import Thread 26 | 27 | 28 | GB = 1 << 30 29 | 30 | worker_id = str(uuid.uuid4())[:6] 31 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 32 | global_counter = 0 33 | 34 | model_semaphore = None 35 | 36 | 37 | def heart_beat_worker(controller): 38 | 39 | while True: 40 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 41 | controller.send_heart_beat() 42 | 43 | 44 | class ModelWorker: 45 | def __init__(self, controller_addr, worker_addr, 46 | worker_id, no_register, 47 | model_path, model_base, model_name, 48 | load_8bit, load_4bit, device): 49 | self.controller_addr = controller_addr 50 | self.worker_addr = worker_addr 51 | self.worker_id = worker_id 52 | if model_path.endswith("/"): 53 | model_path = model_path[:-1] 54 | if model_name is None: 55 | model_paths = model_path.split("/") 56 | if model_paths[-1].startswith('checkpoint-'): 57 | self.model_name = model_paths[-2] + "_" + model_paths[-1] 58 | else: 59 | self.model_name = model_paths[-1] 60 | else: 61 | self.model_name = model_name 62 | 63 | self.device = device 64 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") 65 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( 66 | model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device) 67 | self.is_multimodal = 'llava' in self.model_name.lower() 68 | 69 | if not no_register: 70 | self.register_to_controller() 71 | self.heart_beat_thread = threading.Thread( 72 | target=heart_beat_worker, args=(self,)) 73 | self.heart_beat_thread.start() 74 | 75 | def register_to_controller(self): 76 | logger.info("Register to controller") 77 | 78 | url = self.controller_addr + "/register_worker" 79 | data = { 80 | "worker_name": self.worker_addr, 81 | "check_heart_beat": True, 82 | "worker_status": self.get_status() 83 | } 84 | r = requests.post(url, json=data) 85 | assert r.status_code == 200 86 | 87 | def send_heart_beat(self): 88 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " 89 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. 
" 90 | f"global_counter: {global_counter}") 91 | 92 | url = self.controller_addr + "/receive_heart_beat" 93 | 94 | while True: 95 | try: 96 | ret = requests.post(url, json={ 97 | "worker_name": self.worker_addr, 98 | "queue_length": self.get_queue_length()}, timeout=5) 99 | exist = ret.json()["exist"] 100 | break 101 | except requests.exceptions.RequestException as e: 102 | logger.error(f"heart beat error: {e}") 103 | time.sleep(5) 104 | 105 | if not exist: 106 | self.register_to_controller() 107 | 108 | def get_queue_length(self): 109 | if model_semaphore is None: 110 | return 0 111 | else: 112 | return args.limit_model_concurrency - model_semaphore._value + (len( 113 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 114 | 115 | def get_status(self): 116 | return { 117 | "model_names": [self.model_name], 118 | "speed": 1, 119 | "queue_length": self.get_queue_length(), 120 | } 121 | 122 | @torch.inference_mode() 123 | def generate_stream(self, params): 124 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 125 | 126 | prompt = params["prompt"] 127 | ori_prompt = prompt 128 | images = params.get("images", None) 129 | num_image_tokens = 0 130 | if images is not None and len(images) > 0 and self.is_multimodal: 131 | if len(images) > 0: 132 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 133 | raise ValueError("Number of images does not match number of tokens in prompt") 134 | 135 | images = [load_image_from_base64(image) for image in images] 136 | images = process_images(images, image_processor, model.config) 137 | 138 | if type(images) is list: 139 | images = [image.to(self.model.device, dtype=torch.float16) for image in images] 140 | else: 141 | images = images.to(self.model.device, dtype=torch.float16) 142 | 143 | replace_token = DEFAULT_IMAGE_TOKEN 144 | if getattr(self.model.config, 'mm_use_im_start_end', False): 145 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN 146 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 147 | 148 | num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches 149 | else: 150 | images = None 151 | image_args = {"images": images} 152 | else: 153 | images = None 154 | image_args = {} 155 | 156 | temperature = float(params.get("temperature", 1.0)) 157 | top_p = float(params.get("top_p", 1.0)) 158 | max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 159 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 160 | stop_str = params.get("stop", None) 161 | do_sample = True if temperature > 0.001 else False 162 | 163 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device) 164 | keywords = [stop_str] 165 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 166 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 167 | 168 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) 169 | 170 | if max_new_tokens < 1: 171 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. 
Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0" 172 | return 173 | 174 | thread = Thread(target=model.generate, kwargs=dict( 175 | inputs=input_ids, 176 | do_sample=do_sample, 177 | temperature=temperature, 178 | top_p=top_p, 179 | max_new_tokens=max_new_tokens, 180 | streamer=streamer, 181 | stopping_criteria=[stopping_criteria], 182 | use_cache=True, 183 | **image_args 184 | )) 185 | thread.start() 186 | 187 | generated_text = ori_prompt 188 | for new_text in streamer: 189 | generated_text += new_text 190 | if generated_text.endswith(stop_str): 191 | generated_text = generated_text[:-len(stop_str)] 192 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 193 | 194 | def generate_stream_gate(self, params): 195 | try: 196 | for x in self.generate_stream(params): 197 | yield x 198 | except ValueError as e: 199 | print("Caught ValueError:", e) 200 | ret = { 201 | "text": server_error_msg, 202 | "error_code": 1, 203 | } 204 | yield json.dumps(ret).encode() + b"\0" 205 | except torch.cuda.CudaError as e: 206 | print("Caught torch.cuda.CudaError:", e) 207 | ret = { 208 | "text": server_error_msg, 209 | "error_code": 1, 210 | } 211 | yield json.dumps(ret).encode() + b"\0" 212 | except Exception as e: 213 | print("Caught Unknown Error", e) 214 | ret = { 215 | "text": server_error_msg, 216 | "error_code": 1, 217 | } 218 | yield json.dumps(ret).encode() + b"\0" 219 | 220 | 221 | app = FastAPI() 222 | 223 | 224 | def release_model_semaphore(fn=None): 225 | model_semaphore.release() 226 | if fn is not None: 227 | fn() 228 | 229 | 230 | @app.post("/worker_generate_stream") 231 | async def generate_stream(request: Request): 232 | global model_semaphore, global_counter 233 | global_counter += 1 234 | params = await request.json() 235 | 236 | if model_semaphore is None: 237 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 238 | await model_semaphore.acquire() 239 | worker.send_heart_beat() 240 | generator = worker.generate_stream_gate(params) 241 | background_tasks = BackgroundTasks() 242 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 243 | return StreamingResponse(generator, background=background_tasks) 244 | 245 | 246 | @app.post("/worker_get_status") 247 | async def get_status(request: Request): 248 | return worker.get_status() 249 | 250 | 251 | if __name__ == "__main__": 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument("--host", type=str, default="localhost") 254 | parser.add_argument("--port", type=int, default=21002) 255 | parser.add_argument("--worker-address", type=str, 256 | default="http://localhost:21002") 257 | parser.add_argument("--controller-address", type=str, 258 | default="http://localhost:21001") 259 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 260 | parser.add_argument("--model-base", type=str, default=None) 261 | parser.add_argument("--model-name", type=str) 262 | parser.add_argument("--device", type=str, default="cuda") 263 | parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 264 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 265 | parser.add_argument("--stream-interval", type=int, default=1) 266 | parser.add_argument("--no-register", action="store_true") 267 | parser.add_argument("--load-8bit", action="store_true") 268 | parser.add_argument("--load-4bit", 
action="store_true") 269 | args = parser.parse_args() 270 | logger.info(f"args: {args}") 271 | 272 | if args.multi_modal: 273 | logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.") 274 | 275 | worker = ModelWorker(args.controller_address, 276 | args.worker_address, 277 | worker_id, 278 | args.no_register, 279 | args.model_path, 280 | args.model_base, 281 | args.model_name, 282 | args.load_8bit, 283 | args.load_4bit, 284 | args.device) 285 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 286 | -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, 
default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from torch.utils.data import Sampler 6 | 7 | from transformers import Trainer 8 | from transformers.trainer import ( 9 | is_sagemaker_mp_enabled, 10 | get_parameter_names, 11 | has_length, 12 | ALL_LAYERNORM_LAYERS, 13 | logger, 14 | ) 15 | from typing import List, Optional 16 | 17 | 18 | def maybe_zero_3(param, ignore_status=False, name=None): 19 | from deepspeed import zero 20 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 21 | if hasattr(param, "ds_id"): 22 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 23 | if not ignore_status: 24 | print(name, 'no ignore status') 25 | with zero.GatheredParameters([param]): 26 | param = param.data.detach().cpu().clone() 27 | else: 28 | param = param.detach().cpu().clone() 29 | return param 30 | 31 | 32 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 33 | to_return = {k: t for k, t in named_params if any( 34 | key_match in k for key_match in keys_to_match)} 35 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() 36 | for k, v in to_return.items()} 37 | return to_return 38 | 39 | 40 | def split_to_even_chunks(indices, lengths, num_chunks): 41 | """ 42 | Split a list of indices into `chunks` chunks of roughly equal lengths. 43 | """ 44 | 45 | if len(indices) % num_chunks != 0: 46 | return [indices[i::num_chunks] for i in range(num_chunks)] 47 | 48 | num_indices_per_chunk = len(indices) // num_chunks 49 | 50 | chunks = [[] for _ in range(num_chunks)] 51 | chunks_lengths = [0 for _ in range(num_chunks)] 52 | for index in indices: 53 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 54 | chunks[shortest_chunk].append(index) 55 | chunks_lengths[shortest_chunk] += lengths[index] 56 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 57 | chunks_lengths[shortest_chunk] = float("inf") 58 | 59 | return chunks 60 | 61 | 62 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 63 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 64 | assert all(l != 0 for l in lengths), "Should not have zero length." 
65 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 66 | # all samples are in the same modality 67 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 68 | mm_indices, mm_lengths = zip(*[(i, l) 69 | for i, l in enumerate(lengths) if l > 0]) 70 | lang_indices, lang_lengths = zip( 71 | *[(i, -l) for i, l in enumerate(lengths) if l < 0]) 72 | 73 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices( 74 | mm_lengths, batch_size, world_size, generator=None)] 75 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices( 76 | lang_lengths, batch_size, world_size, generator=None)] 77 | megabatch_size = world_size * batch_size 78 | mm_megabatches = [mm_shuffle[i: i + megabatch_size] 79 | for i in range(0, len(mm_shuffle), megabatch_size)] 80 | lang_megabatches = [lang_shuffle[i: i + megabatch_size] 81 | for i in range(0, len(lang_shuffle), megabatch_size)] 82 | 83 | last_mm = mm_megabatches[-1] 84 | last_lang = lang_megabatches[-1] 85 | additional_batch = last_mm + last_lang 86 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 87 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 88 | megabatches = [megabatches[i] for i in megabatch_indices] 89 | 90 | if len(additional_batch) > 0: 91 | megabatches.append(sorted(additional_batch)) 92 | 93 | return [i for megabatch in megabatches for i in megabatch] 94 | 95 | 96 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 97 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 98 | indices = torch.randperm(len(lengths), generator=generator) 99 | megabatch_size = world_size * batch_size 100 | megabatches = [indices[i: i + megabatch_size].tolist() 101 | for i in range(0, len(lengths), megabatch_size)] 102 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], 103 | reverse=True) for megabatch in megabatches] 104 | megabatches = [split_to_even_chunks( 105 | megabatch, lengths, world_size) for megabatch in megabatches] 106 | 107 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 108 | 109 | 110 | class LengthGroupedSampler(Sampler): 111 | r""" 112 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 113 | keeping a bit of randomness. 
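When `group_by_modality` is True, multimodal samples (positive lengths) and language-only samples (negative lengths) are additionally kept in separate mega-batches.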
114 | """ 115 | 116 | def __init__( 117 | self, 118 | batch_size: int, 119 | world_size: int, 120 | lengths: Optional[List[int]] = None, 121 | generator=None, 122 | group_by_modality: bool = False, 123 | ): 124 | if lengths is None: 125 | raise ValueError("Lengths must be provided.") 126 | 127 | self.batch_size = batch_size 128 | self.world_size = world_size 129 | self.lengths = lengths 130 | self.generator = generator 131 | self.group_by_modality = group_by_modality 132 | 133 | def __len__(self): 134 | return len(self.lengths) 135 | 136 | def __iter__(self): 137 | if self.group_by_modality: 138 | indices = get_modality_length_grouped_indices( 139 | self.lengths, self.batch_size, self.world_size, generator=self.generator) 140 | else: 141 | indices = get_length_grouped_indices( 142 | self.lengths, self.batch_size, self.world_size, generator=self.generator) 143 | return iter(indices) 144 | 145 | 146 | class VisionLLaVATrainer(Trainer): 147 | 148 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 149 | if self.train_dataset is None or not has_length(self.train_dataset): 150 | return None 151 | 152 | if self.args.group_by_modality_length: 153 | lengths = self.train_dataset.modality_lengths 154 | return LengthGroupedSampler( 155 | self.args.train_batch_size, 156 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 157 | lengths=lengths, 158 | group_by_modality=True, 159 | ) 160 | else: 161 | return super()._get_train_sampler() 162 | 163 | def create_optimizer(self): 164 | """ 165 | Setup the optimizer. 166 | 167 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 168 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
169 | """ 170 | if is_sagemaker_mp_enabled(): 171 | return super().create_optimizer() 172 | 173 | opt_model = self.model 174 | 175 | if self.optimizer is None: 176 | decay_parameters = get_parameter_names( 177 | opt_model, ALL_LAYERNORM_LAYERS) 178 | decay_parameters = [ 179 | name for name in decay_parameters if "bias" not in name] 180 | if self.args.mm_projector_lr is not None and self.args.mm_projector_lr != 0: 181 | projector_parameters = [ 182 | name for name, _ in opt_model.named_parameters() if "mm_projector" in name] 183 | if self.args.vision_tower_lr is not None and self.args.vision_tower_lr != 0: 184 | vision_tower_parameters = [ 185 | name for name, _ in opt_model.named_parameters() if "vision_tower" in name] 186 | optimizer_grouped_parameters = [ 187 | { 188 | "params": [ 189 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad) 190 | ], 191 | "weight_decay": self.args.weight_decay, 192 | }, 193 | { 194 | "params": [ 195 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad) 196 | ], 197 | "weight_decay": self.args.weight_decay, 198 | "lr": self.args.vision_tower_lr, 199 | }, 200 | { 201 | "params": [ 202 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad) 203 | ], 204 | "weight_decay": 0.0, 205 | }, 206 | { 207 | "params": [ 208 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad) 209 | ], 210 | "weight_decay": 0.0, 211 | "lr": self.args.vision_tower_lr, 212 | }, 213 | { 214 | "params": [ 215 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 216 | ], 217 | "weight_decay": self.args.weight_decay, 218 | "lr": self.args.mm_projector_lr, 219 | }, 220 | { 221 | "params": [ 222 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 223 | ], 224 | "weight_decay": 0.0, 225 | "lr": self.args.mm_projector_lr, 226 | }, 227 | ] 228 | else: 229 | optimizer_grouped_parameters = [ 230 | { 231 | "params": [ 232 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) 233 | ], 234 | "weight_decay": self.args.weight_decay, 235 | }, 236 | { 237 | "params": [ 238 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) 239 | ], 240 | "weight_decay": 0.0, 241 | }, 242 | { 243 | "params": [ 244 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 245 | ], 246 | "weight_decay": self.args.weight_decay, 247 | "lr": self.args.mm_projector_lr, 248 | }, 249 | { 250 | "params": [ 251 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 252 | ], 253 | "weight_decay": 0.0, 254 | "lr": self.args.mm_projector_lr, 255 | }, 256 | ] 257 | else: 258 | optimizer_grouped_parameters = [ 259 | { 260 | "params": [ 261 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 262 | ], 263 | "weight_decay": self.args.weight_decay, 264 | 
}, 265 | { 266 | "params": [ 267 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 268 | ], 269 | "weight_decay": 0.0, 270 | }, 271 | ] 272 | 273 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( 274 | self.args) 275 | 276 | self.optimizer = optimizer_cls( 277 | optimizer_grouped_parameters, **optimizer_kwargs) 278 | if optimizer_cls.__name__ == "Adam8bit": 279 | import bitsandbytes 280 | 281 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 282 | 283 | skipped = 0 284 | for module in opt_model.modules(): 285 | if isinstance(module, nn.Embedding): 286 | skipped += sum({p.data_ptr(): p.numel() 287 | for p in module.parameters()}.values()) 288 | logger.info( 289 | f"skipped {module}: {skipped/2**20}M params") 290 | manager.register_module_override( 291 | module, "weight", {"optim_bits": 32}) 292 | logger.debug( 293 | f"bitsandbytes: will optimize {module} in fp32") 294 | logger.info(f"skipped: {skipped/2**20}M params") 295 | 296 | return self.optimizer 297 | 298 | def _save_checkpoint(self, model, trial, metrics=None): 299 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 300 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 301 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 302 | 303 | run_dir = self._get_output_dir(trial=trial) 304 | output_dir = os.path.join(run_dir, checkpoint_folder) 305 | 306 | # Only save Adapter 307 | keys_to_match = ['mm_projector', 'vision_resampler'] 308 | if getattr(self.args, "use_im_start_end", False): 309 | keys_to_match.extend(['embed_tokens', 'embed_in']) 310 | 311 | weight_to_save = get_mm_adapter_state_maybe_zero_3( 312 | self.model.named_parameters(), keys_to_match) 313 | 314 | if self.args.local_rank == 0 or self.args.local_rank == -1: 315 | self.model.config.save_pretrained(output_dir) 316 | torch.save(weight_to_save, os.path.join( 317 | output_dir, 'mm_projector.bin')) 318 | else: 319 | super(VisionLLaVATrainer, self)._save_checkpoint( 320 | model, trial, metrics) 321 | 322 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 323 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 324 | pass 325 | else: 326 | super(VisionLLaVATrainer, self)._save(output_dir, state_dict) 327 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "0.0.1" 8 | description = "Advanced large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.1.2", "torchvision==0.16.2", 17 | "transformers==4.39.3", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.27.2", "peft", "bitsandbytes", 19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==4.16.0", "gradio_client==0.8.1", 21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 23 | ] 24 | 25 | [project.optional-dependencies] 26 | train = ["deepspeed==0.13.1", "ninja", "wandb"] 27 | build = ["build", "twine"] 28 | 29 | [tool.setuptools.packages.find] 30 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 31 | 32 | [tool.wheel] 33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] -------------------------------------------------------------------------------- /scripts/eval-lmms.sh: -------------------------------------------------------------------------------- 1 | # you can refer to this issue: https://github.com/EvolvingLMMs-Lab/lmms-eval/issues/15 2 | 3 | export HF_HOME=/path-to-save-dir 4 | 5 | path=/path-to-your-model 6 | 7 | accelerate launch --num_processes=8 -m lmms_eval \ 8 | --model llava \ 9 | --model_args pretrained="${path}" \ 10 | --tasks mme \ 11 | --batch_size 1 \ 12 | --log_samples --log_samples_suffix convllava \ 13 | --output_path ./logs/ -------------------------------------------------------------------------------- /scripts/evaluation.sh: -------------------------------------------------------------------------------- 1 | export LMUData=path_to_your_saved_data 2 | export OMP_NUM_THREADS=1 3 | 4 | eval_dataset="MMBench_DEV_EN" 5 | 6 | llava_path=path_to_your_weights 7 | 8 | work_dir=path_to_your_work_dir 9 | gpu=2 10 | 11 | # if you want to use chatgpt to evaluate, you need to set OPENAI_API_KEY 12 | # export OPENAI_API_KEY="sk-1234" 13 | 14 | torchrun --nproc-per-node=${gpu} llava/eval/run.py \ 15 | --data ${eval_dataset} \ 16 | --model llava_v1.5_7b \ 17 | --verbose \ 18 | --work-dir ${work_dir}/vlmeval \ 19 | --llava-path=${llava_path} 20 | 21 | 22 | 
-------------------------------------------------------------------------------- /scripts/refcoco.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT=/path/to/convllava 2 | 3 | torchrun \ 4 | --nnodes=1 \ 5 | --node_rank=0 \ 6 | --master_addr=127.0.0.1 \ 7 | --nproc_per_node=8 \ 8 | --master_port=25908 \ 9 | llava/eval/evaluate_grounding.py --checkpoint ${CHECKPOINT} --out-dir output 10 | 11 | 12 | -------------------------------------------------------------------------------- /scripts/stage_1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # You may need to modify: model_name_or_path, dataset, vision_tower, output_dir 3 | 4 | deepspeed llava/train/train.py \ 5 | --deepspeed ./scripts/zero2.json \ 6 | --model_name_or_path lmsys/vicuna-7b-v1.5 \ 7 | --version v1 \ 8 | --dataset "dataset_you_want_train" \ 9 | --vision_tower path_to_original_convnext \ 10 | --mm_vision_resolution 768 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --freeze_backbone True \ 13 | --vision_add_five_stage 6 \ 14 | --vision_five_stage_width 3072 \ 15 | --mm_vision_select_layer -1 \ 16 | --mm_use_im_start_end False \ 17 | --mm_use_im_patch_token False \ 18 | --bf16 True \ 19 | --output_dir ./checkpoints/convllava/stage1 \ 20 | --num_train_epochs 1 \ 21 | --per_device_train_batch_size 32 \ 22 | --per_device_eval_batch_size 4 \ 23 | --gradient_accumulation_steps 1 \ 24 | --image_aspect_ratio pad \ 25 | --group_by_modality_length True \ 26 | --evaluation_strategy "no" \ 27 | --save_strategy "steps" \ 28 | --save_steps 24000 \ 29 | --save_total_limit 1 \ 30 | --learning_rate 3e-4 \ 31 | --weight_decay 0. \ 32 | --warmup_ratio 0.03 \ 33 | --lr_scheduler_type "cosine" \ 34 | --logging_steps 1 \ 35 | --tf32 True \ 36 | --model_max_length 2048 \ 37 | --gradient_checkpointing True \ 38 | --dataloader_num_workers 4 \ 39 | --lazy_preprocess True \ 40 | --report_to wandb 41 | -------------------------------------------------------------------------------- /scripts/stage_2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # You may need to modify: model_name_or_path, dataset, vision_tower, output_dir 3 | 4 | deepspeed llava/train/train.py \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path path_to_stage1_llm \ 7 | --version v1 \ 8 | --dataset "dataset_you_want_train" \ 9 | --vision_tower path_to_stage1_convnext\ 10 | --mm_projector_type mlp2x_gelu \ 11 | --mm_vision_resolution 768 \ 12 | --mm_vision_select_layer -1 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoints/convllava/stage2 \ 17 | --tune_vision_tower True \ 18 | --tune_vit_from_layer 2 \ 19 | --tune_entire_model False \ 20 | --num_train_epochs 1 \ 21 | --per_device_train_batch_size 32 \ 22 | --per_device_eval_batch_size 4 \ 23 | --gradient_accumulation_steps 1 \ 24 | --image_aspect_ratio pad \ 25 | --group_by_modality_length True \ 26 | --evaluation_strategy "no" \ 27 | --save_strategy "steps" \ 28 | --save_steps 24000 \ 29 | --save_total_limit 1 \ 30 | --learning_rate 1e-3 \ 31 | --weight_decay 0. 
\ 32 | --warmup_ratio 0.03 \ 33 | --lr_scheduler_type "cosine" \ 34 | --logging_steps 1 \ 35 | --tf32 True \ 36 | --model_max_length 2048 \ 37 | --gradient_checkpointing True \ 38 | --dataloader_num_workers 4 \ 39 | --lazy_preprocess True \ 40 | --report_to wandb 41 | -------------------------------------------------------------------------------- /scripts/stage_3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # You may need to modify: model_name_or_path, dataset, vision_tower, output_dir 3 | 4 | deepspeed llava/train/train.py \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path path_to_stage2_llm \ 7 | --version v1 \ 8 | --dataset "dataset_you_want_train" \ 9 | --vision_tower path_to_stage2_convnext \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --mm_vision_resolution 768 \ 12 | --mm_vision_select_layer -1 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoints/convllava/stage3 \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 16 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --image_aspect_ratio pad \ 22 | --group_by_modality_length True \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": 
{ 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } --------------------------------------------------------------------------------
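The three ZeRO configs above are selected through the --deepspeed flag in the stage scripts; zero3_offload.json additionally moves optimizer states and parameters to CPU, trading step time for GPU memory. A minimal sketch of reusing scripts/stage_3.sh with the offload config, assuming its placeholder paths (model, dataset, vision tower, output dir) have already been filled in:

# Launch stage 3 with CPU offloading by swapping only the DeepSpeed config.
bash <(sed 's#./scripts/zero3.json#./scripts/zero3_offload.json#' scripts/stage_3.sh)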