├── .gitignore
├── LICENSE
├── README.md
├── README_zh.md
├── asset
│   ├── WeChat.png
│   └── method.png
├── docs
│   ├── Data.md
│   ├── Evaluation.md
│   └── Model_zoo.md
├── llava
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── data
│   │   ├── __init__.py
│   │   └── data.py
│   ├── eval
│   │   ├── evaluate_grounding.py
│   │   └── run.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── language_model
│   │   │   └── llava_llama.py
│   │   ├── llava_arch.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   ├── clip_encoder.py
│   │   │   ├── convnext_encoder.py
│   │   │   ├── lknet_encoder.py
│   │   │   ├── siglip_encoder.py
│   │   │   └── unireplknet
│   │   │       ├── __init__.py
│   │   │       └── unireplknet_encoder.py
│   │   └── multimodal_projector
│   │       └── builder.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── extreme_ironing.jpg
│   │   │   └── waterview.jpg
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   ├── register_worker.py
│   │   └── test_message.py
│   ├── train
│   │   ├── llava_trainer.py
│   │   └── train.py
│   └── utils.py
├── pyproject.toml
└── scripts
    ├── eval-lmms.sh
    ├── evaluation.sh
    ├── refcoco.sh
    ├── stage_1.sh
    ├── stage_2.sh
    ├── stage_3.sh
    ├── zero2.json
    ├── zero3.json
    └── zero3_offload.json
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # Images
156 | images/
157 |
158 | *.tar
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConvLLaVA: Hierarchical Backbones as Visual Encoder for Large Multimodal Models
2 |
5 | [Chunjiang Ge](https://john-ge.github.io/), [Sijie Cheng](https://adacheng.github.io/), Ziming Wang, Jiale Yuan, Yuan Gao
6 |
7 | Jun Song, Shiji Song, [Gao Huang](https://www.gaohuang.net/), Bo Zheng
8 |
9 |
35 | [ English | 中文 ]
36 |
37 | ## Abstract
38 |
39 | High-resolution Large Multimodal Models (LMMs) encounter the challenges of excessive visual tokens and quadratic visual complexity. Current high-resolution LMMs address the quadratic complexity while still generating excessive visual tokens. **However, the redundancy in visual tokens is the key problem, as it leads to substantially more compute.** To mitigate this, we propose ConvLLaVA, which employs ConvNeXt, a hierarchical backbone, as the visual encoder of the LMM to replace the Vision Transformer (ViT). **ConvLLaVA compresses high-resolution images into information-rich visual features, effectively avoiding the generation of excessive visual tokens.** To enhance the capabilities of ConvLLaVA, we propose two critical optimizations.
40 |
41 | - Since the low-resolution pretrained ConvNeXt underperforms when directly applied to high-resolution inputs, we **update** it to bridge the gap.
42 | - Furthermore, since ConvNeXt's original compression ratio is insufficient for much higher resolution inputs, we train a **successive stage** to further compress the visual tokens, effectively reducing redundancy.
43 |
44 | **These optimizations enable ConvLLaVA to support inputs of 1536x1536 resolution while generating only 576 visual tokens, accommodating images of arbitrary aspect ratios.** [Experimental results](#model-zoo) demonstrate that our method achieves competitive performance with state-of-the-art models on mainstream benchmarks.
45 |
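The token counts quoted in the abstract and in the tables below follow from ConvNeXt's downsampling arithmetic: the four ConvNeXt stages give a total stride of 32, and the successive compression stage described above doubles it to 64, so a square input yields (resolution / 64)^2 visual tokens. A quick sanity check (a standalone sketch, not part of the training code):

```python
# Visual-token arithmetic: stride 32 from the four ConvNeXt stages, doubled to 64
# by the additional compression stage, gives (resolution // 64) ** 2 tokens.
for resolution in (768, 1024, 1536):
    side = resolution // 64
    print(f"{resolution} px -> {side} x {side} = {side * side} visual tokens")
# 768 px -> 144, 1024 px -> 256, 1536 px -> 576
```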
46 |
47 | ![method](asset/method.png)
48 |
49 |
50 | Comparison between LLaVA and ConvLLaVA.
51 |
52 |
53 | ## Release :loudspeaker:
54 |
55 | - **2024/05/25**: Checkpoints are released.
56 | - **2024/04/17**: Our code is released.
57 |
59 | If you are interested in Large Multimodal Models or have great ideas, please feel free to email me: [Chunjiang Ge](mailto:gecj20@mails.tsinghua.edu.cn).
60 |
62 | **Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.
63 |
64 | ## Contents
65 | - [Abstract](#abstract)
66 | - [Release :loudspeaker:](#release-loudspeaker)
67 | - [Contents](#contents)
68 | - [TODO](#todo)
69 | - [Install](#install)
70 | - [Model Zoo](#model-zoo)
71 | - [Dataset](#dataset)
72 | - [Train](#train)
73 | - [Evaluation](#evaluation)
74 | - [Citation](#citation)
75 | - [Acknowledgement](#acknowledgement)
76 |
77 | ## TODO
78 |
79 | - [ ] Add [LMMs-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) support.
80 | - [ ] Add [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) support.
81 | - [ ] Add [xtuner](https://github.com/InternLM/xtuner) support.
82 | - [x] Release weights.
83 | - [ ] Release inference code.
84 |
85 | ## Install
86 |
87 | 1. Clone this repository and navigate to ConvLLaVA folder
88 | ```bash
89 | git clone https://github.com/alibaba/conv-llava
90 | cd conv-llava
91 | ```
92 |
93 | 2. Install Package
94 | ```bash
95 | conda create -n convllava python=3.11 -y
96 | conda activate convllava
97 | pip install --upgrade pip # enable PEP 660 support
98 | pip install -e .
99 | ```
100 |
101 | 3. Install additional packages for training cases
102 | ```bash
103 | pip install -e ".[train]"
104 | pip install flash-attn --no-build-isolation
105 | ```
106 |
107 | ## Model Zoo
108 |
109 | The performance on mainstream benchmarks is shown below:
110 |
111 |
112 | | Method | Resolution | Visual Tokens | LLM | MME | MMB | SEED | RealWorldQA | MMMU | MMVet | Text | Doc | POPE |
113 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
114 | | ConvLLaVA | 768 | 144 | 7B | 1541 | 68 | 68.8 | 55.9 | 36.3 | 44.8 | 59.1 | 44.8 | 87.3 |
115 | | ConvLLaVA | 1024 | 256 | 7B | 1553 | 68.8 | 69.3 | 58.8 | 35.1 | 44.4 | 62.5 | 48.5 | 87.7 |
116 | | ConvLLaVA | 1536 | 576 | 7B | 1575 | 68.7 | 70.2 | 59.9 | 35.8 | 45.9 | 65.8 | 59 | 87.3 |
117 |
118 | | Method | Resolution | Visual Tokens | LLM | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val | RefCOCOg test | Avg |
119 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
120 | | ConvLLaVA | 768 | 144 | 7B | 84.5 | 89.0 | 79.2 | 77.7 | 84.9 | 69.7 | 79.8 | 79.7 | 80.6 |
121 | | ConvLLaVA | 1024 | 256 | 7B | 85.5 | 89.6 | 78.8 | 79.3 | 86.1 | 70.3 | 80.6 | 81.2 | 81.4 |
122 | | ConvLLaVA | 1536 | 576 | 7B | 86.5 | 90.6 | 80.5 | 80.0 | 86.8 | 71.5 | 82.0 | 82.4 | 82.3 |
123 |
244 | Please check out our [Model Zoo](https://github.com/alibaba/conv-llava/blob/main/docs/Model_zoo.md) for all public ConvLLaVA checkpoints and instructions on how to use the weights.
245 |
246 | ## Dataset
247 |
248 | The data we use is introduced in [Data.md](https://github.com/alibaba/conv-llava/blob/main/docs/Data.md).
249 |
250 | ## Train
251 |
252 | We use the following hyperparameters for training ConvLLaVA.
253 |
254 | | Hyperparameters | Stage 1 | Stage 2 | Stage 3 |
255 | | --------------- | ------- | ------- | ------- |
256 | | Learning Rate | 3e-4 | 2e-5 | 2e-5 |
257 | | Batch Size | 256 | 256 | 128 |
258 | | Epochs | 1 | 1 | 1 |
259 | | Warmup Ratio | 0.03 | 0.03 | 0.03 |
260 | | Weight Decay | 0 | 0 | 0 |
261 | | Optimizer | AdamW | AdamW | AdamW |
262 |
263 | The training scripts are in the [scripts](https://github.com/alibaba/conv-llava/tree/main/scripts) folder and should be run in order (see the example below):
264 |
265 | - Projector Initialization: [stage1](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_1.sh)
266 | - Vision Language Pretraining: [stage2](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_2.sh)
267 | - Instruction Tuning: [stage3](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_3.sh)
268 |
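A typical way to run the three stages in order (after setting the data and checkpoint paths used inside each script):

```bash
# Stage 1: projector initialization
bash scripts/stage_1.sh
# Stage 2: vision-language pretraining
bash scripts/stage_2.sh
# Stage 3: instruction tuning
bash scripts/stage_3.sh
```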
269 | ## Evaluation
270 |
271 | We support [VLMEVALKIT](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) to evaluate our model now. See [Evaluation.md](https://github.com/alibaba/conv-llava/blob/main/docs/Evaluation.md) for more details.
272 |
273 | ## Citation
274 |
275 | If you find ConvLLaVA useful for your research and applications, please cite using this BibTeX:
276 |
277 | ```bibtex
278 | @misc{ge2024convllava,
279 |       title={ConvLLaVA: Hierarchical Backbones as Visual Encoder for Large Multimodal Models},
280 |       author={Chunjiang Ge and Sijie Cheng and Ziming Wang and Jiale Yuan and Yuan Gao and Jun Song and Shiji Song and Gao Huang and Bo Zheng},
281 |       year={2024},
282 |       eprint={2405.15738},
283 |       archivePrefix={arXiv},
284 |       primaryClass={cs.CV}
285 | }
287 | ```
288 |
289 | ## Acknowledgement
290 |
291 | - [Vicuna](https://github.com/lm-sys/FastChat): the codebase LLaVA is built upon, and our base model Vicuna-7B, which has amazing language capabilities!
292 | - [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon.
293 |
--------------------------------------------------------------------------------
/README_zh.md:
--------------------------------------------------------------------------------
1 | # ConvLLaVA: Hierarchical Backbones as Visual Encoder for Large Multimodal Models
2 |
5 | [Chunjiang Ge](https://john-ge.github.io/), [Sijie Cheng](https://adacheng.github.io/), Ziming Wang, Jiale Yuan, Yuan Gao
6 |
7 | Jun Song, Shiji Song, [Gao Huang](https://www.gaohuang.net/), Bo Zheng
8 |
9 |
35 | [ English | 中文 ]
36 |
37 | ## 摘要 :bulb:
38 |
39 | 高分辨率多模态大模型(LMM)面临视觉token过多和视觉平方复杂度的挑战。当前的高分辨率LMM通常能够解决二次复杂度问题,却会生成过量的视觉token。**然而,过多的视觉token才是更关键的问题,因为它会导致更显著的计算开销。** 为了解决这个问题,我们提出了ConvLLaVA,它采用层次化的主干网络ConvNeXt作为LMM的视觉编码器,以替代Vision Transformer(ViT)。**ConvLLaVA将高分辨率图像压缩成富含信息的视觉特征,有效避免了生成过多的视觉token。** 为了增强ConvLLaVA的能力,我们提出了两项关键优化措施。
40 |
41 | - 由于低分辨率预训练的ConvNeXt在直接应用于高分辨率时表现不佳,**我们更新它以弥合这一差距。**
42 | - 此外,由于ConvNeXt原有的压缩比对于更高分辨率的输入来说不足,**我们训练了一个新的stage,以进一步压缩视觉token**,有效减少冗余。
43 |
44 | **这些优化使得ConvLLaVA能够支持1536x1536分辨率的输入,同时仅生成576个视觉token,并适应任意宽高比的图像。** [实验结果](#model-zoo)显示,我们的方法在主流基准测试上与最先进的模型相比取得了竞争性的性能。
45 |
46 |
47 | ![method](asset/method.png)
48 |
49 |
50 | LLaVA 和 ConvLLaVA 结构上的对比
51 |
52 |
53 |
55 | 如果你对多模态大模型感兴趣,或者你有很好的想法,请你联系我:[Chunjiang Ge](mailto:gecj20@mails.tsinghua.edu.cn).
56 |
58 | **Usage and License Notices**: This project utilizes certain datasets and checkpoints that are subject to their respective original licenses. Users must comply with all terms and conditions of these original licenses, including but not limited to the [OpenAI Terms of Use](https://openai.com/policies/terms-of-use) for the dataset and the specific licenses for base language models for checkpoints trained using the dataset (e.g. [Llama community license](https://ai.meta.com/llama/license/) for LLaMA-2 and Vicuna-v1.5). This project does not impose any additional constraints beyond those stipulated in the original licenses. Furthermore, users are reminded to ensure that their use of the dataset and checkpoints is in compliance with all applicable laws and regulations.
59 |
60 | ## 内容
61 | - [摘要 :bulb:](#摘要-bulb)
62 | - [内容](#内容)
63 | - [计划](#计划)
64 | - [安装](#安装)
65 | - [模型库](#模型库)
66 | - [数据集](#数据集)
67 | - [训练](#训练)
68 | - [评测](#评测)
69 | - [引用](#引用)
70 | - [致谢](#致谢)
71 |
72 | ## 计划
73 |
74 | - [ ] Add [LMMs-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) support.
75 | - [ ] Add [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) support.
76 | - [ ] Add [xtuner](https://github.com/InternLM/xtuner) support.
77 | - [x] Release weights.
78 | - [ ] Release inference code.
79 |
80 | ## 安装
81 |
82 | 1. Clone this repository and navigate to ConvLLaVA folder
83 | ```bash
84 | git clone https://github.com/alibaba/conv-llava
85 | cd conv-llava
86 | ```
87 |
88 | 2. Install Package
89 | ```bash
90 | conda create -n convllava python=3.11 -y
91 | conda activate convllava
92 | pip install --upgrade pip # enable PEP 660 support
93 | pip install -e .
94 | ```
95 |
96 | 3. Install additional packages for training cases
97 | ```bash
98 | pip install -e ".[train]"
99 | pip install flash-attn --no-build-isolation
100 | ```
101 |
102 | ## 模型库
103 |
104 | 我们的模型在一些测试基准上的性能如下:
105 |
106 | | Method | Resolution | Visual Tokens | LLM | MME | MMB | SEED | RealWorldQA | MMMU | MMVet | Text | Doc | POPE |
107 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
108 | | ConvLLaVA | 768 | 144 | 7B | 1541 | 68 | 68.8 | 55.9 | 36.3 | 44.8 | 59.1 | 44.8 | 87.3 |
109 | | ConvLLaVA | 1024 | 256 | 7B | 1553 | 68.8 | 69.3 | 58.8 | 35.1 | 44.4 | 62.5 | 48.5 | 87.7 |
110 | | ConvLLaVA | 1536 | 576 | 7B | 1575 | 68.7 | 70.2 | 59.9 | 35.8 | 45.9 | 65.8 | 59 | 87.3 |
111 |
112 | | Method | Resolution | Visual Tokens | LLM | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val | RefCOCOg test | Avg |
113 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
114 | | ConvLLaVA | 768 | 144 | 7B | 84.5 | 89.0 | 79.2 | 77.7 | 84.9 | 69.7 | 79.8 | 79.7 | 80.6 |
115 | | ConvLLaVA | 1024 | 256 | 7B | 85.5 | 89.6 | 78.8 | 79.3 | 86.1 | 70.3 | 80.6 | 81.2 | 81.4 |
116 | | ConvLLaVA | 1536 | 576 | 7B | 86.5 | 90.6 | 80.5 | 80.0 | 86.8 | 71.5 | 82.0 | 82.4 | 82.3 |
117 |
239 | 我们的 [Model Zoo](https://github.com/alibaba/conv-llava/blob/main/docs/Model_zoo.md) 中包含了主要的权重和下载方式,并有说明如何使用这些权重。
240 |
241 | ## 数据集
242 |
243 | 我们实验用到的数据在 [Data.md](https://github.com/alibaba/conv-llava/blob/main/docs/Data.md) 中有介绍。
244 |
245 | ## 训练
246 |
247 | 训练的超参数如下:
248 |
249 | | Hyperparameters | Stage 1 | Stage 2 | Stage 3 |
250 | | --------------- | ------- | ------- | ------- |
251 | | Learning Rate | 3e-4 | 2e-5 | 2e-5 |
252 | | Batch Size | 256 | 256 | 128 |
253 | | Epochs | 1 | 1 | 1 |
254 | | Warmup Ratio | 0.03 | 0.03 | 0.03 |
255 | | Weight Decay | 0 | 0 | 0 |
256 | | Optimizer | AdamW | AdamW | AdamW |
257 |
258 | 训练脚本在文件夹 [scripts](https://github.com/alibaba/conv-llava/tree/main/scripts) 中:
259 |
260 | - Projector Initialization: [stage1](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_1.sh)
261 | - Vision Language Pretraining: [stage2](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_2.sh)
262 | - Instruction Tuning: [stage3](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_3.sh)
263 |
264 | ## 评测
265 |
266 | 我们目前支持 [VLMEVALKIT](https://github.com/open-compass/VLMEvalKit) 和 [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) 来测试模型。请看 [Evaluation.md](https://github.com/alibaba/conv-llava/blob/main/docs/Evaluation.md) 了解更多细节.
267 |
268 | ## 引用
269 |
270 | 如果你认为我们的工作有所帮助,请你通过下面的 BibTeX 来引用我们的工作:
271 |
272 | ```bibtex
273 | @misc{ge2024convllava,
274 |       title={ConvLLaVA: Hierarchical Backbones as Visual Encoder for Large Multimodal Models},
275 |       author={Chunjiang Ge and Sijie Cheng and Ziming Wang and Jiale Yuan and Yuan Gao and Jun Song and Shiji Song and Gao Huang and Bo Zheng},
276 |       year={2024},
277 |       eprint={2405.15738},
278 |       archivePrefix={arXiv},
279 |       primaryClass={cs.CV}
280 | }
282 | ```
283 |
284 | ## 致谢
285 |
286 | - [Vicuna](https://github.com/lm-sys/FastChat): the codebase LLaVA is built upon, and our base model Vicuna-7B, which has amazing language capabilities!
287 | - [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon.
288 |
--------------------------------------------------------------------------------
/asset/WeChat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/asset/WeChat.png
--------------------------------------------------------------------------------
/asset/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/asset/method.png
--------------------------------------------------------------------------------
/docs/Data.md:
--------------------------------------------------------------------------------
1 | # Data
2 |
3 | We use the following hyperparameters for training ConvLLaVA.
4 |
5 | | Hyperparameters | Stage 1 | Stage 2 | Stage 3 |
6 | | --------------- | ------- | ------- | ------- |
7 | | Learning Rate | 3e-4 | 2e-5 | 2e-5 |
8 | | Batch Size | 256 | 256 | 128 |
9 | | Epochs | 1 | 1 | 1 |
10 | | Warmup Ratio | 0.03 | 0.03 | 0.03 |
11 | | Weight Decay | 0 | 0 | 0 |
12 | | Optimizer | AdamW | AdamW | AdamW |
13 |
14 | ## Projector Initialization
15 |
16 | We use captions from ShareGPT4V-PT, ShareGPT4V, ALLAVA.
17 |
18 | ## Vision Language Pretraining
19 |
20 | We use ShareGPT4V-PT, ShareGPT4V, ALLAVA and a part of VFLAN.
21 |
22 | ## Instruction Tuning
23 |
24 | We use the LLaVA-1.5 SFT 665k dataset. We will update the results when the LLaVA-NeXT data is released.
25 |
26 | ## Prepare Images
27 |
28 | First, download all images and instruction files.
29 |
30 | - ALLaVA: [images](https://huggingface.co/datasets/FreedomIntelligence/ALLaVA-4V)
31 | - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip)
32 | - LLaVA: [llava](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K)
33 | - WebData: [images](https://drive.google.com/drive/folders/1tCUQ-sq6vdshZVkF0ZeF3K4eztkXJgax?usp=sharing). Only for academic usage.
34 | - SAM: [images](https://ai.meta.com/datasets/segment-anything-downloads/). We only use 000000~000050.tar for now. If the download is slow in China, please refer to [opendatalab](https://opendatalab.com/OpenDataLab/SA-1B) to download it.
35 | - GQA: [images](https://downloads.cs.stanford.edu/nlp/data/gqa/images.zip)
36 | - OCR-VQA: [download script](https://drive.google.com/drive/folders/1_GYPY5UkUy7HIcR0zq3ZCFgeZN7BAfm_?usp=sharing). We save all files as `.jpg`
37 | - TextVQA: [trainvalimages](https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip)
38 | - VisualGenome: [part1](https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip), [part2](https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip)
39 | - vflan: [vflan](https://huggingface.co/datasets/Vision-Flan/vision-flan_191-task_1k)
40 |
41 | Then, organize the data as follows:
42 |
43 | ```none
44 | ShareGPT4V
45 | ├── ...
46 | ├── data
47 | │ ├── allava
48 | │ │ ├── allava_laion
49 | │ │ │ ├── images
50 | │ │ │ ├── ALLaVA-Caption-LAION-4V.json
51 | │ │ │ ├── ALLaVA-Instruct-LAION-4V.json
52 | │ │ ├── allava_vflan
53 | │ │ │ ├── ALLaVA-Caption-VFLAN-4V.json
54 | │ │ │ ├── ALLaVA-Instruct-VFLAN-4V.json
55 | │ ├── coco
56 | │ │ ├── train2017
57 | │ ├── llava
58 | │ │ ├── llava_v1_5_mix665k.json
59 | │ ├── sam
60 | │ │ ├── images
61 | │ ├── gqa
62 | │ │ ├── images
63 | │ ├── ocr_vqa
64 | │ │ ├── images
65 | │ ├── textvqa
66 | │ │ ├── train_images
67 | │ ├── vg
68 | │ │ ├── VG_100K
69 | │ │ ├── VG_100K_2
70 | │ ├── vflan
71 | │ │ ├── images_191task_1k
72 | │ │ ├── annotation_191-task_1k.json
73 | │ ├── sharegpt4v
74 | │ │ ├── share-captioner_coco_lcs_sam_1246k_1107.json
75 | │ │ ├── sharegpt4v_instruct_gpt4-vision_cap100k.json
76 | │ ├── share_textvqa
77 | │ │ ├── images
78 | │ ├── web-celebrity
79 | │ │ ├── images
80 | │ ├── web-landmark
81 | │ │ ├── images
82 | │ ├── wikiart
83 | │ │ ├── images
84 | ├── ...
85 | ```
86 |
87 | If you find downloading the OCR-VQA images slow, you could refer to this [issue](https://github.com/haotian-liu/LLaVA/issues/931).
88 | Use multiprocessing to speed it up:
89 |
90 | ```python
91 | import concurrent.futures
92 | import json
93 | import os
94 | import urllib.request as ureq
95 |
96 | # 'dataset.json' is the annotation file from the OCR-VQA download script above.
97 | data = json.load(open('dataset.json'))
98 | download = 1  # set to 0 to skip downloading
99 |
100 | def download_image(k):
101 |     ext = os.path.splitext(data[k]['imageURL'])[1]
102 |     outputFile = 'images/%s%s' % (k, ext)
103 |
104 |     # Only download the image if it doesn't exist
105 |     if not os.path.exists(outputFile):
106 |         ureq.urlretrieve(data[k]['imageURL'], outputFile)
107 |
108 | if download == 1:
109 |     # Create the directory if it doesn't exist
110 |     if not os.path.exists('./images'):
111 |         os.mkdir('./images')
112 |
113 |     # Create a thread pool and download the images in parallel
114 |     with concurrent.futures.ThreadPoolExecutor() as executor:
115 |         executor.map(download_image, data.keys())
116 | ```
110 |
111 | For OCR-VQA, some GIF images should be converted to JPG. You could use the code below:
112 |
113 | ```python
114 | import os
115 | from PIL import Image
116 |
117 | def convert_gif_to_jpg(folder_path):
118 | for filename in os.listdir(folder_path):
119 | if filename.endswith('.gif'):
120 | file_path = os.path.join(folder_path, filename)
121 | with Image.open(file_path) as img:
122 | jpg_filename = os.path.splitext(filename)[0] + '.jpg'
123 | jpg_path = os.path.join(folder_path, jpg_filename)
124 | img.convert('RGB').save(jpg_path, 'JPEG', quality=95)
125 | print(f'Converted {filename} to {jpg_filename}')
126 |
127 | folder_path = 'path_to_your_folder'
128 | convert_gif_to_jpg(folder_path)
129 | ```
130 |
131 | ## Data Configuration
132 |
133 | You could modify the file [data.py](llava/data/data.py) to add datasets. Replace the placeholders with the actual paths:
134 |
135 | ```python
136 | def build_sharegpt4v(tokenizer, data_args):
137 | data_path = 'path_to_sharegpt4v_pt.json'
138 | image_folder = 'folder_to_sharegpt4v_pt'
139 | dataset = SampleDataset(data_path, tokenizer, data_args,
140 | image_folder)
141 | return dataset
142 | ```
143 |
--------------------------------------------------------------------------------
/docs/Evaluation.md:
--------------------------------------------------------------------------------
1 | # Evaluation
2 |
3 | ## VLMEvalKit
4 |
5 | We use VLMEvalKit as the evaluation tool. Please refer to [QuickStart](https://github.com/open-compass/VLMEvalKit/blob/main/Quickstart.md) for installation.
6 |
7 | We evaluate the models with the script [evaluation.sh](scripts/evaluation.sh). You could modify its parameters to evaluate different benchmarks.
8 |
9 | You should replace VLMEvalKit's original run file with [run.py](llava/eval/run.py) to evaluate the model.
10 |
11 | ```bash
12 | eval_dataset="MMVet" # need openai api key
13 | eval_dataset="MME MMBench_DEV_EN"
14 |
15 | # set the llava-path to the actual path of your convllava checkpoint
16 | ```
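Once run.py is replaced as described above, a typical VLMEvalKit invocation looks like the sketch below; `convllava-768` is a placeholder model name, and the llava-path must point to your ConvLLaVA checkpoint as noted in the comment above.

```bash
# Hypothetical invocation following VLMEvalKit's CLI; adjust the model name and
# the dataset list to what you registered / want to evaluate.
python run.py --data MME MMBench_DEV_EN --model convllava-768 --verbose
```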
17 |
18 | We will contribute to VLMEvalKit to support our model soon.
19 |
20 | ## lmms-eval
21 |
22 | If you want to use lmms-eval to evaluate the model, you need to install the package first:
23 |
24 | ```bash
25 | git clone https://github.com/EvolvingLMMs-Lab/lmms-eval
26 | cd lmms-eval
27 | pip install -e .
28 | ```
29 |
30 | You should use the script [eval-lmms.sh](scripts/eval-lmms.sh) to evaluate the model. You could modify its parameters to evaluate different benchmarks.
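A rough sketch of the kind of command eval-lmms.sh wraps (flags follow lmms-eval's documented CLI; the checkpoint path is a placeholder):

```bash
accelerate launch -m lmms_eval \
    --model llava \
    --model_args pretrained="path/to/ConvLLaVA-sft-768" \
    --tasks mme,pope \
    --batch_size 1 \
    --output_path ./logs/
```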
31 |
32 |
33 | ## RefCOCO
34 |
35 | If you are interested in RefCOCO, we provide the code in [refcoco.sh](scripts/refcoco.sh).
--------------------------------------------------------------------------------
/docs/Model_zoo.md:
--------------------------------------------------------------------------------
1 | # Model Zoo
2 |
3 | ## Performance
4 |
5 |
6 | | Method | Resolution | Visual Tokens | LLM | MME | MMB | SEED | RealWorldQA | MMMU | MMVet | Text | Doc | POPE |
7 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
8 | | ConvLLaVA | 768 | 144 | 7B | 1541 | 68 | 68.8 | 55.9 | 36.3 | 44.8 | 59.1 | 44.8 | 87.3 |
9 | | ConvLLaVA | 1024 | 256 | 7B | 1553 | 68.8 | 69.3 | 58.8 | 35.1 | 44.4 | 62.5 | 48.5 | 87.7 |
10 | | ConvLLaVA | 1536 | 576 | 7B | 1575 | 68.7 | 70.2 | 59.9 | 35.8 | 45.9 | 65.8 | 59 | 87.3 |
11 |
12 | | Method | Resolution | Visual Tokens | LLM | RefCOCO val | RefCOCO test-A | RefCOCO test-B | RefCOCO+ val | RefCOCO+ test-A | RefCOCO+ test-B | RefCOCOg val | RefCOCOg test | Avg |
13 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
14 | | ConvLLaVA | 768 | 144 | 7B | 84.5 | 89.0 | 79.2 | 77.7 | 84.9 | 69.7 | 79.8 | 79.7 | 80.6 |
15 | | ConvLLaVA | 1024 | 256 | 7B | 85.5 | 89.6 | 78.8 | 79.3 | 86.1 | 70.3 | 80.6 | 81.2 | 81.4 |
16 | | ConvLLaVA | 1536 | 576 | 7B | 86.5 | 90.6 | 80.5 | 80.0 | 86.8 | 71.5 | 82.0 | 82.4 | 82.3 |
17 |
138 | ## Download
139 |
140 | We release checkpoints after vision-language pretraining and visual instruction tuning. You could directly use the SFT model, or finetune the vision-language pretraining checkpoints on your own data.
141 |
142 | | model | Huggingface | ModelScope | WiseModel |
143 | | :------------: | :---------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------: | ------------------------------------------------------------------------------------- |
144 | | ConvLLaVA-768 | [pretrain](https://huggingface.co/ConvLLaVA/ConvLLaVA-pretrain-768), [sft](https://huggingface.co/ConvLLaVA/ConvLLaVA-sft-768) | [pretrain](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-pretrain-768/summary), [sft](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-sft-768/summary) | [pretrain](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-pretrain-768/intro), [sft](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-sft-768/intro) |
145 | | ConvLLaVA-1024 | [pretrain](https://huggingface.co/ConvLLaVA/ConvLLaVA-pretrain-1024), [sft](https://huggingface.co/ConvLLaVA/ConvLLaVA-sft-1024) | [pretrain](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1024/summary), [sft](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-sft-1024/summary) | [pretrain](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1024/intro), [sft](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-sft-1024/intro) |
146 | | ConvLLaVA-1536 | [pretrain](https://huggingface.co/ConvLLaVA/ConvLLaVA-pretrain-1536), [sft](https://huggingface.co/ConvLLaVA/ConvLLaVA-sft-1536) | [pretrain](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1536/summary), [sft](https://modelscope.cn/models/ConvLLaVA/ConvLLaVA-sft-1536/summary) | [pretrain](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-pretrain-1536/intro), [sft](https://wisemodel.cn/models/ConvLLaVA/ConvLLaVA-sft-1536/intro) |
147 |
148 | The **pretrain** above means the checkpoints are after the second stage **vision-language pretraining**. The **sft** above means the checkpoints are after the third stage **instruction tuning**.
149 |
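For inference with an SFT checkpoint, a minimal loading sketch using the LLaVA-style builder bundled in this repo (`llava/model/builder.py`) might look like the following; the model path is a placeholder for a downloaded checkpoint directory.

```python
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

# Placeholder: a local directory (or hub id) holding a ConvLLaVA SFT checkpoint.
model_path = "ConvLLaVA/ConvLLaVA-sft-768"
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=get_model_name_from_path(model_path)
)
```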
150 | ## Usage of the scripts
151 |
152 | The three stages training scripts are listed below:
153 |
154 | - Projector Initialization: [stage1](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_1.sh)
155 | - Vision Language Pretraining: [stage2](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_2.sh)
156 | - Instruction Tuning: [stage3](https://github.com/alibaba/conv-llava/tree/main/scripts/stage_3.sh)
157 |
158 | ### Customize training
159 |
160 | If you want to customize your model, you can directly load the **second-stage pretrained visual encoder and LLM** for instruction tuning. It takes about 5 hours to train the 768-resolution model with LLaVA-Instruct-665k on a single node with 8 A800 GPUs.
161 |
162 | ### Training from scratch
163 |
164 | If you want to train from scratch, you could download our processed ConvNeXt model (modified from the LAION ConvNeXt). Then follow the three-stage training scripts to train the model.
165 |
166 | ConvNeXt: [huggingface](https://huggingface.co/ConvLLaVA/LAION-CLIP-ConvNeXt-Large-512), [modelscope](https://modelscope.cn/models/ConvLLaVA/LAION-CLIP-ConvNeXt-Large-512/summary)
167 |
168 | You need to modify the config files in the checkpoint folder to match the resolution you want to train your model on (see the sketch below):
169 |
170 | - config.json: image_size
171 | - preprocessor_config.json: size, crop_size
172 |
173 | Then load the weights and start training.
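A minimal sketch of that config edit, assuming the folder layout of the ConvNeXt release above (note that, depending on the processor class, `size` and `crop_size` may be ints or dicts, so adapt the assignment to the actual files):

```python
import json
import os

def set_resolution(checkpoint_dir, resolution):
    # Update config.json (image_size) and preprocessor_config.json (size, crop_size).
    for name, keys in [("config.json", ["image_size"]),
                       ("preprocessor_config.json", ["size", "crop_size"])]:
        path = os.path.join(checkpoint_dir, name)
        with open(path) as f:
            cfg = json.load(f)
        for key in keys:
            cfg[key] = resolution  # assumes the simple integer form
        with open(path, "w") as f:
            json.dump(cfg, f, indent=2)

set_resolution("LAION-CLIP-ConvNeXt-Large-512", 1536)  # e.g. train at 1536x1536
```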
174 |
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
/llava/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Tuple
4 | import base64
5 | from io import BytesIO
6 | from PIL import Image
7 |
8 |
9 | class SeparatorStyle(Enum):
10 | """Different separator style."""
11 | SINGLE = auto()
12 | TWO = auto()
13 | MPT = auto()
14 | PLAIN = auto()
15 | LLAMA_2 = auto()
16 |
17 |
18 | @dataclasses.dataclass
19 | class Conversation:
20 | """A class that keeps all conversation history."""
21 | system: str
22 | roles: List[str]
23 | messages: List[List[str]]
24 | offset: int
25 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26 | sep: str = "###"
27 | sep2: str = None
28 | version: str = "Unknown"
29 |
30 | skip_next: bool = False
31 |
32 | def get_prompt(self):
33 | messages = self.messages
34 | if len(messages) > 0 and type(messages[0][1]) is tuple:
35 | messages = self.messages.copy()
36 | init_role, init_msg = messages[0].copy()
37 |             init_msg = init_msg[0].replace("<image>", "").strip()
38 | if 'mmtag' in self.version:
39 | messages[0] = (init_role, init_msg)
40 |                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
41 | messages.insert(1, (self.roles[1], "Received."))
42 | else:
43 |                 messages[0] = (init_role, "<image>\n" + init_msg)
44 |
45 | if self.sep_style == SeparatorStyle.SINGLE:
46 | ret = self.system + self.sep
47 | for role, message in messages:
48 | if message:
49 | if type(message) is tuple:
50 | message, _, _ = message
51 | ret += role + ": " + message + self.sep
52 | else:
53 | ret += role + ":"
54 | elif self.sep_style == SeparatorStyle.TWO:
55 | seps = [self.sep, self.sep2]
56 | ret = self.system + seps[0]
57 | for i, (role, message) in enumerate(messages):
58 | if message:
59 | if type(message) is tuple:
60 | message, _, _ = message
61 | ret += role + ": " + message + seps[i % 2]
62 | else:
63 | ret += role + ":"
64 | elif self.sep_style == SeparatorStyle.MPT:
65 | ret = self.system + self.sep
66 | for role, message in messages:
67 | if message:
68 | if type(message) is tuple:
69 | message, _, _ = message
70 | ret += role + message + self.sep
71 | else:
72 | ret += role
73 | elif self.sep_style == SeparatorStyle.LLAMA_2:
74 |             def wrap_sys(msg): return f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
76 |
77 | def wrap_inst(msg): return f"[INST] {msg} [/INST]"
78 | ret = ""
79 |
80 | for i, (role, message) in enumerate(messages):
81 | if i == 0:
82 | assert message, "first message should not be none"
83 | assert role == self.roles[0], "first message should come from user"
84 | if message:
85 | if type(message) is tuple:
86 | message, _, _ = message
87 | if i == 0:
88 | message = wrap_sys(self.system) + message
89 | if i % 2 == 0:
90 | message = wrap_inst(message)
91 | ret += self.sep + message
92 | else:
93 | ret += " " + message + " " + self.sep2
94 | else:
95 | ret += ""
96 | ret = ret.lstrip(self.sep)
97 | elif self.sep_style == SeparatorStyle.PLAIN:
98 | seps = [self.sep, self.sep2]
99 | ret = self.system
100 | for i, (role, message) in enumerate(messages):
101 | if message:
102 | if type(message) is tuple:
103 | message, _, _ = message
104 | ret += message + seps[i % 2]
105 | else:
106 | ret += ""
107 | else:
108 | raise ValueError(f"Invalid style: {self.sep_style}")
109 |
110 | return ret
111 |
112 | def append_message(self, role, message):
113 | self.messages.append([role, message])
114 |
115 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
116 | if image_process_mode == "Pad":
117 | def expand2square(pil_img, background_color=(122, 116, 104)):
118 | width, height = pil_img.size
119 | if width == height:
120 | return pil_img
121 | elif width > height:
122 | result = Image.new(
123 | pil_img.mode, (width, width), background_color)
124 | result.paste(pil_img, (0, (width - height) // 2))
125 | return result
126 | else:
127 | result = Image.new(
128 | pil_img.mode, (height, height), background_color)
129 | result.paste(pil_img, ((height - width) // 2, 0))
130 | return result
131 | image = expand2square(image)
132 | elif image_process_mode in ["Default", "Crop"]:
133 | pass
134 | elif image_process_mode == "Resize":
135 | image = image.resize((336, 336))
136 | else:
137 | raise ValueError(
138 | f"Invalid image_process_mode: {image_process_mode}")
139 | if max(image.size) > max_len:
140 | max_hw, min_hw = max(image.size), min(image.size)
141 | aspect_ratio = max_hw / min_hw
142 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
143 | longest_edge = int(shortest_edge * aspect_ratio)
144 | W, H = image.size
145 | if H > W:
146 | H, W = longest_edge, shortest_edge
147 | else:
148 | H, W = shortest_edge, longest_edge
149 | image = image.resize((W, H))
150 | if return_pil:
151 | return image
152 | else:
153 | buffered = BytesIO()
154 | image.save(buffered, format=image_format)
155 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
156 | return img_b64_str
157 |
158 | def get_images(self, return_pil=False):
159 | images = []
160 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
161 | if i % 2 == 0:
162 | if type(msg) is tuple:
163 | msg, image, image_process_mode = msg
164 | image = self.process_image(
165 | image, image_process_mode, return_pil=return_pil)
166 | images.append(image)
167 | return images
168 |
169 | def to_gradio_chatbot(self):
170 | ret = []
171 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
172 | if i % 2 == 0:
173 | if type(msg) is tuple:
174 | msg, image, image_process_mode = msg
175 | img_b64_str = self.process_image(
176 | image, "Default", return_pil=False,
177 | image_format='JPEG')
178 |                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
179 |                     msg = img_str + msg.replace('<image>', '').strip()
180 | ret.append([msg, None])
181 | else:
182 | ret.append([msg, None])
183 | else:
184 | ret[-1][-1] = msg
185 | return ret
186 |
187 | def copy(self):
188 | return Conversation(
189 | system=self.system,
190 | roles=self.roles,
191 | messages=[[x, y] for x, y in self.messages],
192 | offset=self.offset,
193 | sep_style=self.sep_style,
194 | sep=self.sep,
195 | sep2=self.sep2,
196 | version=self.version)
197 |
198 | def dict(self):
199 | if len(self.get_images()) > 0:
200 | return {
201 | "system": self.system,
202 | "roles": self.roles,
203 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
204 | "offset": self.offset,
205 | "sep": self.sep,
206 | "sep2": self.sep2,
207 | }
208 | return {
209 | "system": self.system,
210 | "roles": self.roles,
211 | "messages": self.messages,
212 | "offset": self.offset,
213 | "sep": self.sep,
214 | "sep2": self.sep2,
215 | }
216 |
217 |
218 | conv_vicuna_v0 = Conversation(
219 | system="A chat between a curious human and an artificial intelligence assistant. "
220 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
221 | roles=("Human", "Assistant"),
222 | messages=(
223 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
224 | ("Assistant",
225 | "Renewable energy sources are those that can be replenished naturally in a relatively "
226 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
227 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
228 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
229 | "renewable and non-renewable energy sources:\n"
230 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
231 | "energy sources are finite and will eventually run out.\n"
232 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
233 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
234 | "and other negative effects.\n"
235 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
236 | "have lower operational costs than non-renewable sources.\n"
237 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
238 | "locations than non-renewable sources.\n"
239 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
240 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
241 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
242 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
243 | ),
244 | offset=2,
245 | sep_style=SeparatorStyle.SINGLE,
246 | sep="###",
247 | )
248 |
249 | conv_vicuna_v1 = Conversation(
250 | system="A chat between a curious user and an artificial intelligence assistant. "
251 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
252 | roles=("USER", "ASSISTANT"),
253 | version="v1",
254 | messages=(),
255 | offset=0,
256 | sep_style=SeparatorStyle.TWO,
257 | sep=" ",
258 |     sep2="</s>",
259 | )
260 |
261 | conv_llama_2 = Conversation(
262 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
263 |
264 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
265 | roles=("USER", "ASSISTANT"),
266 | version="llama_v2",
267 | messages=(),
268 | offset=0,
269 | sep_style=SeparatorStyle.LLAMA_2,
270 |     sep="<s>",
271 |     sep2="</s>",
272 | )
273 |
274 | conv_llava_llama_2 = Conversation(
275 | system="You are a helpful language and vision assistant. "
276 | "You are able to understand the visual content that the user provides, "
277 | "and assist the user with a variety of tasks using natural language.",
278 | roles=("USER", "ASSISTANT"),
279 | version="llama_v2",
280 | messages=(),
281 | offset=0,
282 | sep_style=SeparatorStyle.LLAMA_2,
283 |     sep="<s>",
284 |     sep2="</s>",
285 | )
286 |
287 | conv_mpt = Conversation(
288 | system="""<|im_start|>system
289 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
290 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
291 | version="mpt",
292 | messages=(),
293 | offset=0,
294 | sep_style=SeparatorStyle.MPT,
295 | sep="<|im_end|>",
296 | )
297 |
298 | conv_llava_plain = Conversation(
299 | system="",
300 | roles=("", ""),
301 | messages=(
302 | ),
303 | offset=0,
304 | sep_style=SeparatorStyle.PLAIN,
305 | sep="\n",
306 | )
307 |
308 | conv_llava_v0 = Conversation(
309 | system="A chat between a curious human and an artificial intelligence assistant. "
310 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
311 | roles=("Human", "Assistant"),
312 | messages=(
313 | ),
314 | offset=0,
315 | sep_style=SeparatorStyle.SINGLE,
316 | sep="###",
317 | )
318 |
319 | conv_llava_v0_mmtag = Conversation(
320 | system="A chat between a curious user and an artificial intelligence assistant. "
321 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
322 | "The visual content will be provided with the following format: visual content.",
323 | roles=("Human", "Assistant"),
324 | messages=(
325 | ),
326 | offset=0,
327 | sep_style=SeparatorStyle.SINGLE,
328 | sep="###",
329 | version="v0_mmtag",
330 | )
331 |
332 | conv_llava_v1 = Conversation(
333 | system="A chat between a curious human and an artificial intelligence assistant. "
334 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
335 | roles=("USER", "ASSISTANT"),
336 | version="v1",
337 | messages=(),
338 | offset=0,
339 | sep_style=SeparatorStyle.TWO,
340 | sep=" ",
341 | sep2="</s>",
342 | )
343 |
344 | conv_llava_v1_mmtag = Conversation(
345 | system="A chat between a curious user and an artificial intelligence assistant. "
346 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
347 | "The visual content will be provided with the following format: visual content.",
348 | roles=("USER", "ASSISTANT"),
349 | messages=(),
350 | offset=0,
351 | sep_style=SeparatorStyle.TWO,
352 | sep=" ",
353 | sep2="</s>",
354 | version="v1_mmtag",
355 | )
356 |
357 | default_conversation = conv_vicuna_v1
358 | conv_templates = {
359 | "default": conv_vicuna_v0,
360 | "v1": conv_vicuna_v1,
361 | "vicuna_v1": conv_vicuna_v1,
362 | "llama_2": conv_llama_2,
363 |
364 | "plain": conv_llava_plain,
365 | "v0_plain": conv_llava_plain,
366 | "llava_v0": conv_llava_v0,
367 | "v0_mmtag": conv_llava_v0_mmtag,
368 | "llava_v1": conv_llava_v1,
369 | "v1_mmtag": conv_llava_v1_mmtag,
370 | "llava_llama_2": conv_llava_llama_2,
371 |
372 | "mpt": conv_mpt
373 | }
374 |
375 |
376 | if __name__ == "__main__":
377 | print(default_conversation.get_prompt())
378 |
--------------------------------------------------------------------------------
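
The templates above are used by copying one entry from conv_templates, appending turns, and rendering with get_prompt(); evaluate_grounding.py below follows exactly this pattern. A minimal sketch, assuming the llava package from this repo is importable:

from llava.conversation import SeparatorStyle, conv_templates

# Copy the Vicuna-v1 template so the module-level object stays untouched.
conv = conv_templates["v1"].copy()
conv.append_message(conv.roles[0], "<image>\nDescribe this picture.")
conv.append_message(conv.roles[1], None)  # None leaves the assistant turn open for generation

prompt = conv.get_prompt()                # system text followed by "USER: ... ASSISTANT:"
assert conv.sep_style == SeparatorStyle.TWO
print(prompt)
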
/llava/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/data/__init__.py
--------------------------------------------------------------------------------
/llava/eval/evaluate_grounding.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import re
7 | import time
8 | from functools import partial
9 |
10 | import torch
11 | from torchvision.transforms.functional import InterpolationMode
12 | import torchvision.transforms as T
13 | from PIL import Image
14 | from torchvision.ops.boxes import box_area
15 | from tqdm import tqdm
16 | from llava.mm_utils import tokenizer_image_token, process_images
17 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
18 | from llava.conversation import conv_templates
19 | from llava.model.builder import load_pretrained_model
20 |
21 |
22 | def expand2square(pil_img, background_color):
23 | width, height = pil_img.size
24 | if width == height:
25 | return pil_img
26 | elif width > height:
27 | result = Image.new(pil_img.mode, (width, width), background_color)
28 | result.paste(pil_img, (0, (width - height) // 2))
29 | return result
30 | else:
31 | result = Image.new(pil_img.mode, (height, height), background_color)
32 | result.paste(pil_img, ((height - width) // 2, 0))
33 | return result
34 |
35 |
36 | def build_transform(is_train, input_size, pad2square=False):
37 | if is_train:
38 | transform = T.Compose([
39 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
40 | T.RandomResizedCrop(input_size, scale=(0.8, 1.0), ratio=(3. / 4., 4. / 3.),
41 | interpolation=InterpolationMode.BICUBIC),
42 | T.ToTensor(),
43 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
44 | ])
45 | else:
46 | if pad2square is False:
47 | transform = T.Compose([
48 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
49 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
50 | T.ToTensor(),
51 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
52 | ])
53 | else:
54 | transform = T.Compose([
55 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
56 | T.Lambda(lambda img: expand2square(img, tuple(int(x * 255) for x in (0.485, 0.456, 0.406)))),
57 | T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
58 | T.ToTensor(),
59 | T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
60 | ])
61 |
62 | return transform
63 |
64 |
65 | ds_collections = {
66 | 'refcoco_val': 'data/refcoco/refcoco_val.jsonl',
67 | 'refcoco_testA': 'data/refcoco/refcoco_testA.jsonl',
68 | 'refcoco_testB': 'data/refcoco/refcoco_testB.jsonl',
69 | 'refcoco+_val': 'data/refcoco/refcoco+_val.jsonl',
70 | 'refcoco+_testA': 'data/refcoco/refcoco+_testA.jsonl',
71 | 'refcoco+_testB': 'data/refcoco/refcoco+_testB.jsonl',
72 | 'refcocog_val': 'data/refcoco/refcocog_val.jsonl',
73 | 'refcocog_test': 'data/refcoco/refcocog_test.jsonl',
74 | }
75 |
76 |
77 | def reserve_square_bbox(box, w, h):
78 | if w == h:
79 | return box
80 | box = box.tolist()[0]
81 | if w > h:
82 | x1, y1, x2, y2 = box
83 | y1 -= (w - h) // 2
84 | y2 -= (w - h) // 2
85 | box = [[x1, y1, x2, y2]]
86 | return torch.tensor(box).reshape(1, 4)
87 | else:
88 | x1, y1, x2, y2 = box
89 | x1 -= (h - w) // 2
90 | x2 -= (h - w) // 2
91 | box = [[x1, y1, x2, y2]]
92 | return torch.tensor(box).reshape(1, 4)
93 |
94 |
95 | class ModelConfig:
96 | def __init__(self, image_aspect_ratio=None):
97 | self.image_aspect_ratio = image_aspect_ratio
98 |
99 | def box_iou(boxes1, boxes2):
100 | area1 = box_area(boxes1)
101 | area2 = box_area(boxes2)
102 |
103 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
104 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
105 |
106 | wh = (rb - lt).clamp(min=0) # [N,M,2]
107 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
108 |
109 | union = area1[:, None] + area2 - inter
110 |
111 | iou = inter / union
112 | return iou, union
113 |
114 | def collate_fn(batch, tokenizer=None):
115 | input_ids, image_tensors, bbox, hw, image_path, text = zip(*batch)
116 | input_ids = torch.stack(input_ids, dim=0)
117 | image_tensors = torch.stack(image_tensors, dim=0)
118 | return input_ids, image_tensors, bbox, hw, image_path, text
119 |
120 |
121 | class RefCOCODataset(torch.utils.data.Dataset):
122 |
123 | def __init__(self, test, prompt, input_size=224, pad2square=False, image_processor=None, model_cfg=None, tokenizer=None):
124 | self.datas = open(test).readlines()
125 | self.prompt = prompt
126 | self.transform = build_transform(is_train=False, input_size=input_size, pad2square=pad2square)
127 | self.image_processor = image_processor
128 | self.model_config = model_cfg
129 | self.tokenizer = tokenizer
130 |
131 | def __len__(self):
132 | return len(self.datas)
133 |
134 | def __getitem__(self, idx):
135 | data = json.loads(self.datas[idx].strip())
136 | image_path = data['image']
137 | text = data['sent']
138 | bbox = data['bbox']
139 |
140 | w, h = data['width'], data['height']
141 | image = os.path.join('/mnt/thuair/gcjtcl/InternVL', image_path)
142 |
143 | image = Image.open(image).convert('RGB')
144 | # pixel_values = self.transform(image).unsqueeze(0)
145 | pixel_values = process_images([image], self.image_processor, self.model_config)[0]
146 | prompt = self.prompt.format(text)
147 | prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
148 | conv = conv_templates["v1"].copy()
149 | conv.append_message(conv.roles[0], prompt)
150 | conv.append_message(conv.roles[1], None)
151 | prompt = conv.get_prompt()
152 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
153 | return input_ids, pixel_values, bbox, (h, w), image_path, text
154 |
155 |
156 | class InferenceSampler(torch.utils.data.sampler.Sampler):
157 |
158 | def __init__(self, size):
159 | self._size = int(size)
160 | assert size > 0
161 | self._rank = torch.distributed.get_rank()
162 | self._world_size = torch.distributed.get_world_size()
163 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
164 |
165 | @staticmethod
166 | def _get_local_indices(total_size, world_size, rank):
167 | shard_size = total_size // world_size
168 | left = total_size % world_size
169 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
170 |
171 | begin = sum(shard_sizes[:rank])
172 | end = min(sum(shard_sizes[:rank + 1]), total_size)
173 | return range(begin, end)
174 |
175 | def __iter__(self):
176 | yield from self._local_indices
177 |
178 | def __len__(self):
179 | return len(self._local_indices)
180 |
181 |
182 | def evaluate_chat_model():
183 | print('prompt:', prompt)
184 | random.seed(args.seed)
185 | summaries = []
186 |
187 | for ds_name in args.datasets:
188 | dataset = RefCOCODataset(
189 | test=os.path.join("/mnt/thuair/gcjtcl/InternVL", ds_collections[ds_name]),
190 | prompt=prompt,
191 | input_size=image_size,
192 | pad2square=pad2square,
193 | image_processor=image_processor,
194 | model_cfg=model_cfg,
195 | tokenizer=tokenizer
196 | )
197 | dataloader = torch.utils.data.DataLoader(
198 | dataset=dataset,
199 | sampler=InferenceSampler(len(dataset)),
200 | batch_size=args.batch_size,
201 | num_workers=args.num_workers,
202 | pin_memory=True,
203 | drop_last=False,
204 | collate_fn=partial(collate_fn),
205 | )
206 |
207 | outputs = []
208 | for _, (questions, pixel_values, bboxes, hws, image_path, text) in enumerate(tqdm(dataloader)):
209 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
210 | output_ids = model.generate(
211 | questions.to(device='cuda', non_blocking=True),
212 | images=pixel_values.to(dtype=torch.float16),
213 | do_sample=False,
214 | temperature=0,
215 | max_new_tokens=100)
216 |
217 | pred = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
218 | answers = [pred]
219 |
220 | for bbox, hw, answer in zip(bboxes, hws, answers):
221 | outputs.append({
222 | 'image_path': image_path,
223 | 'text': text,
224 | 'answer': answer,
225 | 'gt_bbox': bbox,
226 | 'hw': hw,
227 | })
228 |
229 | torch.distributed.barrier()
230 |
231 | world_size = torch.distributed.get_world_size()
232 | merged_outputs = [None for _ in range(world_size)]
233 | torch.distributed.all_gather_object(merged_outputs, outputs)
234 |
235 | merged_outputs = list(itertools.chain.from_iterable(merged_outputs))
236 |
237 | if torch.distributed.get_rank() == 0:
238 | print(f'Evaluating {ds_name} ...')
239 | time_prefix = time.strftime('%y%m%d%H%M%S', time.localtime())
240 | results_file = f'{ds_name}_{time_prefix}.json'
241 | results_file = os.path.join(args.out_dir, results_file)
242 | json.dump(merged_outputs, open(results_file, 'w'))
243 | # with open("/mnt/thuair/gcjtcl/InternVL/internvl_chat/conv768/refcocog_val_240419174233.json", 'r') as f:
244 | # merged_outputs = json.load(f)
245 |
246 | correct = total_cnt = 0
247 | for i, output in enumerate(merged_outputs):
248 | predict_bbox = re.findall(PATTERN, output['answer'])
249 | try:
250 | predict_bbox = (float(predict_bbox[0][0]), float(predict_bbox[0][1]), float(predict_bbox[0][2]),
251 | float(predict_bbox[0][3]))
252 | except Exception:
253 | predict_bbox = (0., 0., 0., 0.)
254 | target_bbox = torch.tensor(output['gt_bbox'],
255 | dtype=torch.float32).view(-1, 4)
256 | predict_bbox = torch.tensor(predict_bbox,
257 | dtype=torch.float32).view(-1, 4)
258 | if predict_bbox.sum() >= 4:
259 | predict_bbox = predict_bbox / 1000
260 | predict_bbox *= max(output['hw'])
261 | w, h = output['hw'][1], output['hw'][0]
262 | predict_bbox = reserve_square_bbox(predict_bbox, w, h)
263 | # print(predict_bbox)
264 | # predict_bbox[:, 0::2] *= output['hw'][1]
265 | # predict_bbox[:, 1::2] *= output['hw'][0]
266 | iou, _ = box_iou(predict_bbox, target_bbox)
267 | iou = iou.item()
268 | total_cnt += 1
269 | if iou >= 0.5:
270 | correct += 1
271 |
272 | print(f'Evaluating {ds_name} ...')
273 | print(f'Precision @ 1: {correct / total_cnt} \n')
274 | summaries.append([args.checkpoint, ds_name, f'Precision @ 1: {correct / total_cnt} \n'])
275 |
276 | torch.distributed.barrier()
277 |
278 | out_path = '_'.join(args.checkpoint.split('/')[-2:])
279 | writer = open(os.path.join(args.out_dir, f'{out_path}.txt'), 'a')
280 | print(f"write results to file {os.path.join(args.out_dir, f'{out_path}.txt')}")
281 | for summary in summaries:
282 | print(summary)
283 | writer.write(f'{summary}\n')
284 | writer.close()
285 |
286 |
287 | if __name__ == '__main__':
288 |
289 | parser = argparse.ArgumentParser()
290 | parser.add_argument('--checkpoint', type=str, default='')
291 | parser.add_argument('--datasets', type=str, default='refcoco_val,refcoco_testA,refcoco_testB,'
292 | 'refcoco+_val,refcoco+_testA,refcoco+_testB,'
293 | 'refcocog_val,refcocog_test')
294 | parser.add_argument('--batch-size', type=int, default=1)
295 | parser.add_argument('--num-workers', type=int, default=1)
296 | parser.add_argument('--num-beams', type=int, default=5)
297 | parser.add_argument('--out-dir', type=str, default='results')
298 | parser.add_argument('--sample', type=bool, default=False)
299 | parser.add_argument('--temperature', type=float, default=0.0)
300 | parser.add_argument('--seed', type=int, default=0)
301 | args = parser.parse_args()
302 |
303 |
304 | args.datasets = args.datasets.split(',')
305 | print('datasets:', args.datasets)
306 | assert args.batch_size == 1, 'Only batch size 1 is supported'
307 |
308 | torch.distributed.init_process_group(
309 | backend='nccl',
310 | world_size=int(os.getenv('WORLD_SIZE', '1')),
311 | rank=int(os.getenv('RANK', '0')),
312 | )
313 | rank = torch.distributed.get_rank()
314 | if rank == 0:
315 | if not os.path.exists(args.out_dir):
316 | os.makedirs(args.out_dir)
317 |
318 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
319 | torch.cuda.set_device(int(os.getenv('RANK', 0)))
320 | print(f"rank: {int(os.getenv('RANK', 0))}")
321 | device = torch.device(int(os.getenv('RANK', 0)))
322 |
323 | # tokenizer = LlamaTokenizer.from_pretrained(args.checkpoint)
324 | PATTERN = re.compile(r'\[*\[(.*?),(.*?),(.*?),(.*?)\]\]*')
325 |
326 | # Create a model_cfg object (process_images only needs image_aspect_ratio)
327 | model_cfg = ModelConfig(image_aspect_ratio="pad")
328 |
329 |
330 | # device =
331 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.checkpoint, None, "llava", device='cpu', device_map='cpu')
332 | model = model.to(device).eval()
333 | vision_tower = model.get_vision_tower().to(device)
334 | model.get_model().mm_projector.to(device)
335 | # model.cuda()
336 | # for p in model.parameters():
337 | # p.cuda()
338 | image_size = 336
339 | pad2square = True
340 | prompt = 'Please provide the bounding box coordinate of the region this sentence describes: {}.'
341 |
342 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
343 | if total_params > 30:
344 | args.num_beams = 1
345 | print(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
346 | else:
347 | print(f'[test] total_params: {total_params}B')
348 | print(f'[test] image_size: {image_size}')
349 | print(f'[test] pad2square: {pad2square}')
350 | print(f'[test] template: v1')
351 |
352 | evaluate_chat_model()
353 |
--------------------------------------------------------------------------------
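
The scoring loop above reduces to a small amount of box arithmetic: parse an [[x1, y1, x2, y2]] box from the generated text, rescale it from the 0-1000 grid to pixels of the padded square image, undo the padding on the shorter side, and count a hit when IoU >= 0.5. A self-contained sketch of that arithmetic on a made-up prediction (all numbers are illustrative, not results):

import re
import torch
from torchvision.ops.boxes import box_area

PATTERN = re.compile(r'\[*\[(.*?),(.*?),(.*?),(.*?)\]\]*')

def box_iou(boxes1, boxes2):
    # Same pairwise IoU computation as in evaluate_grounding.py.
    area1, area2 = box_area(boxes1), box_area(boxes2)
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[:, :, 0] * wh[:, :, 1]
    return inter / (area1[:, None] + area2 - inter)

answer = "[[100, 200, 500, 800]]"                 # hypothetical model output on the 0-1000 grid
h, w = 480, 640                                   # hypothetical original image size
pred = torch.tensor([[float(v) for v in PATTERN.findall(answer)[0]]])
pred = pred / 1000 * max(h, w)                    # to pixel coordinates of the max(h, w) square
pred[:, 1::2] -= (max(h, w) - h) // 2             # undo top/bottom padding, as reserve_square_bbox does for w > h

gt = torch.tensor([[70.0, 50.0, 320.0, 430.0]])   # hypothetical ground-truth box in pixels
iou = box_iou(pred, gt).item()
print(f"IoU = {iou:.3f}, hit = {iou >= 0.5}")
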
/llava/eval/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | from vlmeval.smp import *
4 | from vlmeval.evaluate import *
5 | from vlmeval.inference import infer_data_job
6 | from vlmeval.config import supported_VLM
7 | from vlmeval.utils import dataset_URLs, DATASET_TYPE, abbr2full, MMMU_result_transfer
8 | from functools import partial
9 | from vlmeval.vlm import LLaVA
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--data', type=str, nargs='+', required=True)
15 | parser.add_argument('--model', type=str, nargs='+', required=True)
16 | parser.add_argument('--work-dir', type=str, default='.', help='select the output directory')
17 | parser.add_argument('--mode', type=str, default='all', choices=['all', 'infer'])
18 | parser.add_argument('--nproc', type=int, default=4, help='Parallel API calling')
19 | parser.add_argument('--retry', type=int, default=None, help='retry numbers for API VLMs')
20 | parser.add_argument('--judge', type=str, default=None)
21 | parser.add_argument('--ignore', action='store_true', help='Ignore failed indices. ')
22 | parser.add_argument('--verbose', action='store_true')
23 | parser.add_argument("--llava-path", type=str,
24 | default='liuhaotian/llava-v1.5-7b')
25 | args = parser.parse_args()
26 | return args
27 |
28 |
29 | def main():
30 | logger = get_logger('RUN')
31 |
32 | args = parse_args()
33 | assert len(args.data), "--data should be a list of data files"
34 |
35 | supported_VLM.update(
36 | {"llava_v1.5_7b": partial(LLaVA, model_pth=args.llava_path)})
37 |
38 | if args.retry is not None:
39 | for k, v in supported_VLM.items():
40 | if hasattr(v, 'keywords') and 'retry' in v.keywords:
41 | v.keywords['retry'] = args.retry
42 | supported_VLM[k] = v
43 | if hasattr(v, 'keywords') and 'verbose' in v.keywords:
44 | v.keywords['verbose'] = args.verbose
45 | supported_VLM[k] = v
46 |
47 | rank, world_size = get_rank_and_world_size()
48 | if world_size > 1:
49 | torch.cuda.set_device(rank)
50 | dist.init_process_group(backend='nccl', timeout=datetime.timedelta(seconds=10800))
51 |
52 | for _, model_name in enumerate(args.model):
53 | model = None
54 |
55 | pred_root = osp.join(args.work_dir, model_name)
56 | os.makedirs(pred_root, exist_ok=True)
57 |
58 | for _, dataset_name in enumerate(args.data):
59 | custom_flag = False
60 |
61 | if dataset_name not in dataset_URLs:
62 | dataset_name = abbr2full(dataset_name)
63 |
64 | if dataset_name not in dataset_URLs:
65 | logger.warning(f'Dataset {dataset_name} is not officially supported. ')
66 | file_path = osp.join(LMUDataRoot(), f'{dataset_name}.tsv')
67 | if not osp.exists(file_path):
68 | logger.error(f'Cannot find the local dataset {dataset_name}. ')
69 | continue
70 | else:
71 | custom_flag = True
72 |
73 | result_file = f'{pred_root}/{model_name}_{dataset_name}.xlsx'
74 |
75 | if model is None:
76 | model = model_name # which is only a name
77 |
78 | model = infer_data_job(
79 | model,
80 | work_dir=pred_root,
81 | model_name=model_name,
82 | dataset_name=dataset_name,
83 | verbose=args.verbose,
84 | api_nproc=args.nproc,
85 | ignore_failed=args.ignore)
86 |
87 | if rank == 0:
88 | if dataset_name in ['MMMU_TEST']:
89 | result_json = MMMU_result_transfer(result_file)
90 | logger.info(f'Transfer MMMU_TEST result to json for official evaluation, json file saved in {result_json}') # noqa: E501
91 |
92 | if dataset_name in ['MMBench_TEST_CN', 'MMBench_TEST_EN', 'MMMU_TEST']:
93 | if not MMBenchOfficialServer():
94 | logger.error(
95 | f'Can not evaluate {dataset_name} on non-official servers, '
96 | 'will skip the evaluation. '
97 | )
98 | continue
99 |
100 | judge_kwargs = {
101 | 'nproc': args.nproc,
102 | 'verbose': args.verbose,
103 | }
104 | if args.retry is not None:
105 | judge_kwargs['retry'] = args.retry
106 | if args.judge is not None:
107 | judge_kwargs['model'] = args.judge
108 | else:
109 | if DATASET_TYPE(dataset_name) in ['multi-choice', 'Y/N']:
110 | judge_kwargs['model'] = 'chatgpt-0613'
111 | elif listinstr(['MMVet', 'MathVista', 'LLaVABench'], dataset_name):
112 | judge_kwargs['model'] = 'gpt-4-turbo'
113 | if 'OPENAI_API_KEY_JUDGE' in os.environ and len(os.environ['OPENAI_API_KEY_JUDGE']):
114 | judge_kwargs['key'] = os.environ['OPENAI_API_KEY_JUDGE']
115 | if 'OPENAI_API_BASE_JUDGE' in os.environ and len(os.environ['OPENAI_API_BASE_JUDGE']):
116 | judge_kwargs['api_base'] = os.environ['OPENAI_API_BASE_JUDGE']
117 |
118 | if rank == 0 and args.mode == 'all':
119 | if DATASET_TYPE(dataset_name) == 'multi-choice':
120 | dataset_name = 'default' if custom_flag else dataset_name
121 | multiple_choice_eval(
122 | result_file,
123 | dataset=dataset_name,
124 | **judge_kwargs)
125 |
126 | elif DATASET_TYPE(dataset_name) == 'Y/N':
127 | YOrN_eval(
128 | result_file,
129 | dataset=dataset_name,
130 | **judge_kwargs)
131 |
132 | elif DATASET_TYPE(dataset_name) == 'Caption':
133 | COCO_eval(result_file)
134 | elif dataset_name == 'MMVet':
135 | MMVet_eval(result_file, **judge_kwargs)
136 | elif dataset_name == 'OCRBench':
137 | OCRBench_eval(result_file)
138 | elif listinstr(['OCRVQA', 'TextVQA', 'ChartQA', 'DocVQA', 'InfoVQA'], dataset_name):
139 | VQAEval(result_file, dataset_name)
140 | elif listinstr(['MathVista'], dataset_name):
141 | MathVista_eval(result_file, **judge_kwargs)
142 | elif listinstr(['LLaVABench'], dataset_name):
143 | LLaVABench_eval(result_file, **judge_kwargs)
144 | else:
145 | logger.error(f'Dataset {dataset_name} is not handled by evaluator, will be skipped. ')
146 |
147 |
148 | if __name__ == '__main__':
149 | load_env()
150 | main()
151 |
--------------------------------------------------------------------------------
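
The only project-specific piece of this launcher is the supported_VLM override: the local checkpoint is wrapped in functools.partial so that VLMEvalKit builds it lazily under the stock llava_v1.5_7b key. A minimal sketch of that registration (the checkpoint path is hypothetical):

from functools import partial

from vlmeval.config import supported_VLM
from vlmeval.vlm import LLaVA

ckpt = "checkpoints/conv-llava-v1.5-7b"   # hypothetical path; --llava-path supplies this in the script above
supported_VLM.update({"llava_v1.5_7b": partial(LLaVA, model_pth=ckpt)})

# infer_data_job later calls supported_VLM[model_name]() to instantiate the wrapper on demand.
builder = supported_VLM["llava_v1.5_7b"]
print(builder.func.__name__, builder.keywords)
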
/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 | import torch
5 | import math
6 | import ast
7 |
8 | from transformers import StoppingCriteria
9 | from llava.constants import IMAGE_TOKEN_INDEX
10 |
11 |
12 | def select_best_resolution(original_size, possible_resolutions):
13 | """
14 | Selects the best resolution from a list of possible resolutions based on the original size.
15 |
16 | Args:
17 | original_size (tuple): The original size of the image in the format (width, height).
18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19 |
20 | Returns:
21 | tuple: The best fit resolution in the format (width, height).
22 | """
23 | original_width, original_height = original_size
24 | best_fit = None
25 | max_effective_resolution = 0
26 | min_wasted_resolution = float('inf')
27 |
28 | for width, height in possible_resolutions:
29 | scale = min(width / original_width, height / original_height)
30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32 | wasted_resolution = (width * height) - effective_resolution
33 |
34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35 | max_effective_resolution = effective_resolution
36 | min_wasted_resolution = wasted_resolution
37 | best_fit = (width, height)
38 |
39 | return best_fit
40 |
41 |
42 | def resize_and_pad_image(image, target_resolution):
43 | """
44 | Resize and pad an image to a target resolution while maintaining aspect ratio.
45 |
46 | Args:
47 | image (PIL.Image.Image): The input image.
48 | target_resolution (tuple): The target resolution (width, height) of the image.
49 |
50 | Returns:
51 | PIL.Image.Image: The resized and padded image.
52 | """
53 | original_width, original_height = image.size
54 | target_width, target_height = target_resolution
55 |
56 | scale_w = target_width / original_width
57 | scale_h = target_height / original_height
58 |
59 | if scale_w < scale_h:
60 | new_width = target_width
61 | new_height = min(math.ceil(original_height * scale_w), target_height)
62 | else:
63 | new_height = target_height
64 | new_width = min(math.ceil(original_width * scale_h), target_width)
65 |
66 | # Resize the image
67 | resized_image = image.resize((new_width, new_height))
68 |
69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70 | paste_x = (target_width - new_width) // 2
71 | paste_y = (target_height - new_height) // 2
72 | new_image.paste(resized_image, (paste_x, paste_y))
73 |
74 | return new_image
75 |
76 |
77 | def divide_to_patches(image, patch_size):
78 | """
79 | Divides an image into patches of a specified size.
80 |
81 | Args:
82 | image (PIL.Image.Image): The input image.
83 | patch_size (int): The size of each patch.
84 |
85 | Returns:
86 | list: A list of PIL.Image.Image objects representing the patches.
87 | """
88 | patches = []
89 | width, height = image.size
90 | for i in range(0, height, patch_size):
91 | for j in range(0, width, patch_size):
92 | box = (j, i, j + patch_size, i + patch_size)
93 | patch = image.crop(box)
94 | patches.append(patch)
95 |
96 | return patches
97 |
98 |
99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100 | """
101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102 |
103 | Args:
104 | image_size (tuple): The size of the input image in the format (width, height).
105 | grid_pinpoints (str): A string representation of a list of possible resolutions.
106 | patch_size (int): The size of each image patch.
107 |
108 | Returns:
109 | tuple: The shape of the image patch grid in the format (width, height).
110 | """
111 | if type(grid_pinpoints) is list:
112 | possible_resolutions = grid_pinpoints
113 | else:
114 | possible_resolutions = ast.literal_eval(grid_pinpoints)
115 | width, height = select_best_resolution(image_size, possible_resolutions)
116 | return width // patch_size, height // patch_size
117 |
118 |
119 | def process_anyres_image(image, processor, grid_pinpoints):
120 | """
121 | Process an image with variable resolutions.
122 |
123 | Args:
124 | image (PIL.Image.Image): The input image to be processed.
125 | processor: The image processor object.
126 | grid_pinpoints (str): A string representation of a list of possible resolutions.
127 |
128 | Returns:
129 | torch.Tensor: A tensor containing the processed image patches.
130 | """
131 | if type(grid_pinpoints) is list:
132 | possible_resolutions = grid_pinpoints
133 | else:
134 | possible_resolutions = ast.literal_eval(grid_pinpoints)
135 | best_resolution = select_best_resolution(image.size, possible_resolutions)
136 | image_padded = resize_and_pad_image(image, best_resolution)
137 |
138 | patches = divide_to_patches(image_padded, processor.crop_size['height'])
139 |
140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141 |
142 | image_patches = [image_original_resize] + patches
143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144 | for image_patch in image_patches]
145 | return torch.stack(image_patches, dim=0)
146 |
147 |
148 | def load_image_from_base64(image):
149 | return Image.open(BytesIO(base64.b64decode(image)))
150 |
151 |
152 | def expand2square(pil_img, background_color):
153 | width, height = pil_img.size
154 | if width == height:
155 | return pil_img
156 | elif width > height:
157 | result = Image.new(pil_img.mode, (width, width), background_color)
158 | result.paste(pil_img, (0, (width - height) // 2))
159 | return result
160 | else:
161 | result = Image.new(pil_img.mode, (height, height), background_color)
162 | result.paste(pil_img, ((height - width) // 2, 0))
163 | return result
164 |
165 |
166 | def process_images(images, image_processor, model_cfg):
167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168 | new_images = []
169 | if image_aspect_ratio == 'pad':
170 | for image in images:
171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173 | new_images.append(image)
174 | elif image_aspect_ratio == "anyres":
175 | for image in images:
176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177 | new_images.append(image)
178 | else:
179 | return image_processor(images, return_tensors='pt')['pixel_values']
180 | if all(x.shape == new_images[0].shape for x in new_images):
181 | new_images = torch.stack(new_images, dim=0)
182 | return new_images
183 |
184 |
185 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
186 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
187 |
188 | def insert_separator(X, sep):
189 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
190 |
191 | input_ids = []
192 | offset = 0
193 | # prompt_chunks[0] is the first chunk; if the image token comes first, len(prompt_chunks[0]) == 0
194 | # If it is non-empty and starts with [bos], the sample is likely interleaved image-text data
195 | # A normal image-text pair ends up with offset = 0
196 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
197 | offset = 1
198 | input_ids.append(prompt_chunks[0][0])
199 |
200 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
201 | input_ids.extend(x[offset:])
202 |
203 | if return_tensors is not None:
204 | if return_tensors == 'pt':
205 | return torch.tensor(input_ids, dtype=torch.long)
206 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
207 | return input_ids
208 |
209 |
210 | def get_model_name_from_path(model_path):
211 | model_path = model_path.strip("/")
212 | model_paths = model_path.split("/")
213 | if model_paths[-1].startswith('checkpoint-'):
214 | return model_paths[-2] + "_" + model_paths[-1]
215 | else:
216 | return model_paths[-1]
217 |
218 | class KeywordsStoppingCriteria(StoppingCriteria):
219 | def __init__(self, keywords, tokenizer, input_ids):
220 | self.keywords = keywords
221 | self.keyword_ids = []
222 | self.max_keyword_len = 0
223 | for keyword in keywords:
224 | cur_keyword_ids = tokenizer(keyword).input_ids
225 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
226 | cur_keyword_ids = cur_keyword_ids[1:]
227 | if len(cur_keyword_ids) > self.max_keyword_len:
228 | self.max_keyword_len = len(cur_keyword_ids)
229 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
230 | self.tokenizer = tokenizer
231 | self.start_len = input_ids.shape[1]
232 |
233 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
234 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
235 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
236 | for keyword_id in self.keyword_ids:
237 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
238 | if torch.equal(truncated_output_ids, keyword_id):
239 | return True
240 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
241 | for keyword in self.keywords:
242 | if keyword in outputs:
243 | return True
244 | return False
245 |
246 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
247 | outputs = []
248 | for i in range(output_ids.shape[0]):
249 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
250 | return all(outputs)
251 |
--------------------------------------------------------------------------------
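
select_best_resolution and get_anyres_image_grid_shape are pure functions, so the "anyres" bookkeeping is easy to check in isolation. A small sketch with a made-up landscape image and an example pinpoint list:

from llava.mm_utils import get_anyres_image_grid_shape, select_best_resolution

pinpoints = [(336, 336), (672, 336), (336, 672), (672, 672)]
original = (1000, 600)   # (width, height) of a hypothetical input image

best = select_best_resolution(original, pinpoints)
print(best)              # (672, 672): the largest effective area after aspect-preserving downscaling

grid = get_anyres_image_grid_shape(original, pinpoints, patch_size=336)
print(grid)              # (2, 2): the padded canvas is cut into a 2x2 grid of 336x336 patches
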
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2 |
--------------------------------------------------------------------------------
/llava/model/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import warnings
18 | import shutil
19 | import subprocess
20 |
21 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
22 | import torch
23 | from llava.model import *
24 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
25 |
26 |
27 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", **kwargs):
28 | kwargs = {"device_map": device_map, **kwargs}
29 |
30 | if device != "cuda":
31 | kwargs['device_map'] = {"": device}
32 |
33 | if load_8bit:
34 | kwargs['load_in_8bit'] = True
35 | elif load_4bit:
36 | kwargs['load_in_4bit'] = True
37 | kwargs['quantization_config'] = BitsAndBytesConfig(
38 | load_in_4bit=True,
39 | bnb_4bit_compute_dtype=torch.float16,
40 | bnb_4bit_use_double_quant=True,
41 | bnb_4bit_quant_type='nf4'
42 | )
43 | else:
44 | kwargs['torch_dtype'] = torch.float16
45 |
46 | tokenizer = AutoTokenizer.from_pretrained(
47 | model_path, use_fast=False)
48 | model = LlavaLlamaForCausalLM.from_pretrained(
49 | model_path, low_cpu_mem_usage=True, **kwargs)
50 |
51 | image_processor = None
52 |
53 | mm_use_im_start_end = getattr(
54 | model.config, "mm_use_im_start_end", False)
55 | mm_use_im_patch_token = getattr(
56 | model.config, "mm_use_im_patch_token", True)
57 | if mm_use_im_patch_token:
58 | tokenizer.add_tokens(
59 | [DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
60 | if mm_use_im_start_end:
61 | tokenizer.add_tokens(
62 | [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
63 | model.resize_token_embeddings(len(tokenizer))
64 |
65 | vision_tower = model.get_vision_tower()
66 | if not vision_tower.is_loaded:
67 | vision_tower.load_model()
68 | print(model.get_vision_tower().is_loaded)
69 | vision_tower.to(device=device, dtype=torch.float16)
70 | image_processor = vision_tower.image_processor
71 |
72 | if hasattr(model.config, "max_sequence_length"):
73 | context_len = model.config.max_sequence_length
74 | else:
75 | context_len = 2048
76 |
77 | return tokenizer, model, image_processor, context_len
78 |
--------------------------------------------------------------------------------
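
load_pretrained_model is the single entry point shared by the CLI, the serving worker, and the evaluation scripts. A minimal single-image inference sketch built on top of it (the checkpoint path and prompt are placeholders, not shipped defaults):

import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

model_path = "checkpoints/conv-llava-7b"   # hypothetical checkpoint directory
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name="llava")

image = Image.open("asset/method.png").convert("RGB")
images = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)

conv = conv_templates["v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe the image.")
conv.append_message(conv.roles[1], None)
input_ids = tokenizer_image_token(conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors="pt").unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(input_ids, images=images, do_sample=False, max_new_tokens=128)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
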
/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import AutoConfig, AutoModelForCausalLM, \
22 | LlamaConfig, LlamaModel, LlamaForCausalLM
23 |
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 | from transformers.generation.utils import GenerateOutput
26 |
27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaConfig(LlamaConfig):
31 | model_type = "llava_llama"
32 |
33 |
34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35 | config_class = LlavaConfig
36 |
37 | def __init__(self, config: LlamaConfig):
38 | super(LlavaLlamaModel, self).__init__(config)
39 |
40 |
41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42 | config_class = LlavaConfig
43 |
44 | def __init__(self, config):
45 | super(LlamaForCausalLM, self).__init__(config)
46 | self.model = LlavaLlamaModel(config)
47 | self.pretraining_tp = config.pretraining_tp
48 | self.vocab_size = config.vocab_size
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def print_trainable_parameters(self):
58 | for i, layer in enumerate(self.model.layers):
59 | print(f"LLM Layer {i}")
60 | is_trainable = any(
61 | param.requires_grad for param in layer.parameters())
62 | print(f"LLM Layer {i} is Trainable: {is_trainable}")
63 | for name, param in self.named_parameters():
64 | # if param.requires_grad:
65 | print(name)
66 |
67 | def forward(
68 | self,
69 | input_ids: torch.LongTensor = None,
70 | attention_mask: Optional[torch.Tensor] = None,
71 | position_ids: Optional[torch.LongTensor] = None,
72 | past_key_values: Optional[List[torch.FloatTensor]] = None,
73 | inputs_embeds: Optional[torch.FloatTensor] = None,
74 | labels: Optional[torch.LongTensor] = None,
75 | use_cache: Optional[bool] = None,
76 | output_attentions: Optional[bool] = None,
77 | output_hidden_states: Optional[bool] = None,
78 | images: Optional[torch.FloatTensor] = None,
79 | image_sizes: Optional[List[List[int]]] = None,
80 | return_dict: Optional[bool] = None,
81 | cache_position: Optional[torch.LongTensor] = None,
82 | ) -> Union[Tuple, CausalLMOutputWithPast]:
83 |
84 | if inputs_embeds is None:
85 | (
86 | input_ids,
87 | position_ids,
88 | attention_mask,
89 | past_key_values,
90 | inputs_embeds,
91 | labels
92 | ) = self.prepare_inputs_labels_for_multimodal(
93 | input_ids,
94 | position_ids,
95 | attention_mask,
96 | past_key_values,
97 | labels,
98 | images,
99 | image_sizes
100 | )
101 |
102 | return super().forward(
103 | input_ids=input_ids,
104 | attention_mask=attention_mask,
105 | position_ids=position_ids,
106 | past_key_values=past_key_values,
107 | inputs_embeds=inputs_embeds,
108 | labels=labels,
109 | use_cache=use_cache,
110 | output_attentions=output_attentions,
111 | output_hidden_states=output_hidden_states,
112 | return_dict=return_dict
113 | )
114 |
115 | @torch.no_grad()
116 | def generate(
117 | self,
118 | inputs: Optional[torch.Tensor] = None,
119 | images: Optional[torch.Tensor] = None,
120 | image_sizes: Optional[torch.Tensor] = None,
121 | **kwargs,
122 | ) -> Union[GenerateOutput, torch.LongTensor]:
123 | position_ids = kwargs.pop("position_ids", None)
124 | attention_mask = kwargs.pop("attention_mask", None)
125 | if "inputs_embeds" in kwargs:
126 | raise NotImplementedError("`inputs_embeds` is not supported")
127 |
128 | if images is not None:
129 | (
130 | inputs,
131 | position_ids,
132 | attention_mask,
133 | _,
134 | inputs_embeds,
135 | _
136 | ) = self.prepare_inputs_labels_for_multimodal(
137 | inputs,
138 | position_ids,
139 | attention_mask,
140 | None,
141 | None,
142 | images,
143 | image_sizes=image_sizes
144 | )
145 | else:
146 | inputs_embeds = self.get_model().embed_tokens(inputs)
147 |
148 | return super().generate(
149 | position_ids=position_ids,
150 | attention_mask=attention_mask,
151 | inputs_embeds=inputs_embeds,
152 | **kwargs
153 | )
154 |
155 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
156 | inputs_embeds=None, **kwargs):
157 | images = kwargs.pop("images", None)
158 | image_sizes = kwargs.pop("image_sizes", None)
159 | inputs = super().prepare_inputs_for_generation(
160 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
161 | )
162 | if images is not None:
163 | inputs['images'] = images
164 | if image_sizes is not None:
165 | inputs['image_sizes'] = image_sizes
166 | return inputs
167 |
168 | AutoConfig.register("llava_llama", LlavaConfig)
169 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
170 |
--------------------------------------------------------------------------------
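
Because the last two lines register LlavaConfig and LlavaLlamaForCausalLM with the Auto classes, a checkpoint whose config.json declares "model_type": "llava_llama" can also be loaded through the generic transformers entry points once llava.model has been imported. A short sketch (the checkpoint path is hypothetical):

from transformers import AutoConfig, AutoModelForCausalLM

import llava.model  # noqa: F401 -- importing it runs the AutoConfig/AutoModelForCausalLM registration above

ckpt = "checkpoints/conv-llava-7b"          # hypothetical; its config.json must declare "model_type": "llava_llama"
config = AutoConfig.from_pretrained(ckpt)
print(type(config).__name__)                # LlavaConfig

model = AutoModelForCausalLM.from_pretrained(ckpt, low_cpu_mem_usage=True)
print(type(model).__name__)                 # LlavaLlamaForCausalLM
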
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 | from .convnext_encoder import ConvNeXtCLIPVisionTower
4 | from .siglip_encoder import SiglipVisionTower
5 | from .lknet_encoder import LKNetCLIPVisionTower
6 |
7 |
8 | def build_vision_tower(vision_tower_cfg, **kwargs):
9 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
10 | print(f"now we are building vision tower, the model is {vision_tower}")
11 | if 'siglip' in vision_tower:
12 | print(f'building SiglipVisionTower')
13 | return SiglipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
14 | if vision_tower.startswith("openai") or 'clip-vit' in vision_tower:
15 | print(f'building CLIPVisionTower')
16 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
17 | if 'convnext' in vision_tower:
18 | print(f'building ConvNeXtCLIPVisionTower')
19 | return ConvNeXtCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
20 | if 'lknet' in vision_tower.lower():
21 | print(f'building lknet')
22 | return LKNetCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
23 | return ConvNeXtCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
24 |
--------------------------------------------------------------------------------
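
build_vision_tower dispatches purely on the tower name, so the routing can be exercised with a stub config and delay_load=True, which fetches only the Hugging Face config rather than the weights. A sketch (the fields mirror the training arguments the towers read; the model id is an example):

from types import SimpleNamespace

from llava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",  # any 'openai'/'clip-vit' name routes to CLIPVisionTower
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)
print(type(tower).__name__)   # CLIPVisionTower
print(tower.is_loaded)        # False until load_model() is called

# Names containing 'siglip', 'convnext', or 'lknet' route to the SigLIP, ConvNeXt, and LKNet
# towers respectively; anything else falls through to ConvNeXtCLIPVisionTower.
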
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(
16 | args, 'mm_vision_select_feature', 'patch')
17 |
18 | if not delay_load:
19 | self.load_model()
20 | else:
21 | self.cfg_only = CLIPVisionConfig.from_pretrained(
22 | self.vision_tower_name)
23 |
24 | def load_model(self):
25 | self.image_processor = CLIPImageProcessor.from_pretrained(
26 | self.vision_tower_name)
27 | self.vision_tower = CLIPVisionModel.from_pretrained(
28 | self.vision_tower_name)
29 | self.vision_tower.requires_grad_(False)
30 |
31 | self.is_loaded = True
32 |
33 | def print_trainable_parameters(self):
34 | for i, layer in enumerate(self.vision_tower.vision_model.encoder.layers):
35 | is_trainable = any(
36 | param.requires_grad for param in layer.parameters())
37 | print(f"ViT Layer {i} is trainable: {is_trainable}")
38 |
39 | def feature_select(self, image_forward_outs):
40 | image_features = image_forward_outs.hidden_states[self.select_layer]
41 | if self.select_feature == 'patch':
42 | image_features = image_features[:, 1:]
43 | elif self.select_feature == 'cls_patch':
44 | image_features = image_features
45 | else:
46 | raise ValueError(
47 | f'Unexpected select feature: {self.select_feature}')
48 | return image_features
49 |
50 | def forward(self, images):
51 | if type(images) is list:
52 | image_features = []
53 | for image in images:
54 | image_forward_out = self.vision_tower(image.to(
55 | device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
56 | image_feature = self.feature_select(
57 | image_forward_out).to(image.dtype)
58 | image_features.append(image_feature)
59 | else:
60 | image_forward_outs = self.vision_tower(
61 | images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
62 | image_features = self.feature_select(
63 | image_forward_outs).to(images.dtype)
64 |
65 | return image_features
66 |
67 | @property
68 | def dummy_feature(self):
69 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
70 |
71 | @property
72 | def dtype(self):
73 | return self.vision_tower.dtype
74 |
75 | @property
76 | def device(self):
77 | return self.vision_tower.device
78 |
79 | @property
80 | def config(self):
81 | if self.is_loaded:
82 | return self.vision_tower.config
83 | else:
84 | return self.cfg_only
85 |
86 | @property
87 | def hidden_size(self):
88 | return self.config.hidden_size
89 |
90 | @property
91 | def num_patches_per_side(self):
92 | return self.config.image_size // self.config.patch_size
93 |
94 | @property
95 | def num_patches(self):
96 | return (self.config.image_size // self.config.patch_size) ** 2
--------------------------------------------------------------------------------
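
For a ViT tower the visual token budget follows directly from num_patches: (image_size / patch_size) squared, with the CLS token removed when select_feature is 'patch'. A quick arithmetic check for the usual CLIP ViT-L/14 settings:

# Token counts as computed by CLIPVisionTower.num_patches.
for image_size, patch_size in [(224, 14), (336, 14)]:
    per_side = image_size // patch_size
    print(f"{image_size}px -> {per_side * per_side} visual tokens")
# 224px -> 256 tokens, 336px -> 576 tokens (feature_select drops the extra CLS token).
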
/llava/model/multimodal_encoder/convnext_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | from transformers import CLIPImageProcessor
5 | from transformers import ConvNextModel, ConvNextConfig
6 | from transformers.models.convnext.modeling_convnext import ConvNextStage
7 |
8 |
9 | class ConvNeXtCLIPVisionTower(nn.Module):
10 | def __init__(self, vision_tower, args, delay_load=False):
11 | super().__init__()
12 |
13 | self.is_loaded = False
14 |
15 | self.vision_tower_name = vision_tower
16 | self.select_layer = args.mm_vision_select_layer
17 | self.update_resolution = getattr(
18 | args, 'mm_vision_resolution', 256)
19 | self.vision_add_five_stage = getattr(args, 'vision_add_five_stage', 0)
20 | self.vision_five_stage_width = getattr(args, 'vision_five_stage_width', 1536)
21 |
22 | if not delay_load:
23 | self.load_model()
24 | else:
25 | print(f"deloy_load vision tower is: {self.vision_tower_name}")
26 | self.cfg_only = ConvNextConfig.from_pretrained(
27 | self.vision_tower_name)
28 |
29 | def load_model(self):
30 | print(f"entering load model, load {self.vision_tower_name}")
31 | self.image_processor = CLIPImageProcessor.from_pretrained(
32 | self.vision_tower_name)
33 | self.vision_tower = ConvNextModel.from_pretrained(
34 | self.vision_tower_name)
35 | self.vision_tower.requires_grad_(False)
36 | self.is_loaded = True
37 |
38 | if self.select_layer == -2:
39 | self.select_layer = -1
40 | self.vision_tower.encoder.stages[-1].layers.pop(-1)
41 | print(
42 | f'Last block removed, select layer changed to {self.select_layer}')
43 |
44 | if self.update_resolution > 256:
45 | self.set_crop_size(self.update_resolution)
46 | print(
47 | f'Crop size changed to {self.update_resolution}x{self.update_resolution}')
48 |
49 | if self.vision_add_five_stage != 0:
50 | self.add_stage(self.vision_add_five_stage, self.vision_five_stage_width)
51 | print(
52 | f'Added stage with width {self.vision_five_stage_width}')
53 |
54 | def forward(self, images):
55 | if type(images) is list:
56 | image_features = []
57 | for image in images:
58 | # Get the embeddings of the image
59 | embedding_output = self.vision_tower.embeddings(image.unsqueeze(0))
60 |
61 | # Get the image features
62 | image_feature = self.vision_tower.encoder(embedding_output,
63 | output_hidden_states=True,
64 | return_dict=True)
65 | image_feature = image_feature.hidden_states[-1].permute(0, 2, 3, 1)
66 | image_feature = image_feature.reshape(image_feature.shape[0], -1, image_feature.shape[3]).to(image.dtype)
67 |
68 | image_features.append(image_feature)
69 | else:
70 | embedding_output = self.vision_tower.embeddings(images)
71 | image_features = self.vision_tower.encoder(embedding_output,
72 | output_hidden_states=True,
73 | return_dict=True)
74 | image_features = image_features.hidden_states[-1].permute(0, 2, 3, 1)
75 | image_features = image_features.reshape(image_features.shape[0], -1, image_features.shape[3]).to(images.dtype)
76 |
77 | return image_features
78 |
79 | def make_layers_trainable_after_stage(self, stage_index, layer_index=0):
80 | for i, stage in enumerate(self.vision_tower.encoder.stages):
81 | if i == stage_index:
82 | if layer_index == 0:
83 | stage.downsampling_layer.requires_grad_(True)
84 | for idx, layer in enumerate(stage.layers):
85 | if idx >= layer_index:
86 | for param in layer.parameters():
87 | param.requires_grad = True
88 | if i > stage_index:
89 | stage.downsampling_layer.requires_grad_(True)
90 | for layer in stage.layers:
91 | for param in layer.parameters():
92 | param.requires_grad = True
93 |
94 | def set_crop_size(self, new_size):
95 | size_dict = {'height': new_size, 'width': new_size}
96 | self.image_processor.crop_size = size_dict
97 | self.image_processor.size = {"shortest_edge": new_size}
98 | self.vision_tower.config.image_size = new_size
99 |
100 | def add_stage(self, depths=3, hidden_dims=3072):
101 | self.vision_tower.encoder.stages.append(ConvNextStage(self.config, self.hidden_size, hidden_dims, depth=depths))
102 | self.vision_tower.config.depths.append(depths)
103 | self.vision_tower.config.hidden_sizes.append(hidden_dims)
104 | self.vision_tower.config.stage_names.append('stage5')
105 | self.vision_tower.config.out_features = ['stage5']
106 | self.vision_tower.config.out_indices = [5]
107 | self.vision_tower.config.num_stages += 1
108 | self.vision_tower.config._name_or_path = ''
109 |
110 | def save_config(self, path):
111 | self.vision_tower.config.save_pretrained(path)
112 |
113 | @property
114 | def dummy_feature(self):
115 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
116 |
117 | @property
118 | def dtype(self):
119 | return self.vision_tower.dtype
120 |
121 | @property
122 | def device(self):
123 | return self.vision_tower.device
124 |
125 | @property
126 | def config(self):
127 | if self.is_loaded:
128 | return self.vision_tower.config
129 | else:
130 | return self.cfg_only
131 |
132 | @property
133 | def hidden_size(self):
134 | return self.config.hidden_sizes[-1]
135 |
136 | @property
137 | def num_patches_per_side(self):
138 | return self.config.image_size // self.config.patch_size
139 |
140 | @property
141 | def num_patches(self):
142 | return (self.config.image_size // 32) ** 2
143 |
144 | @property
145 | def crop_size(self):
146 | return self.image_processor.crop_size
147 |
--------------------------------------------------------------------------------
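
With a hierarchical tower the token budget is set by the output stride rather than the patch size: four ConvNeXt stages give a stride of 32, and the optional fifth stage appended by add_stage (a standard ConvNextStage, whose default stride is 2) raises it to 64. A small sketch of the resulting counts, mirroring num_patches above (the resolutions are illustrative):

# Visual token count for a hierarchical encoder: (resolution // output_stride) ** 2.
def conv_tokens(resolution: int, extra_stage: bool) -> int:
    stride = 64 if extra_stage else 32   # 4 stages -> stride 32; the added stage 5 halves the feature map again
    return (resolution // stride) ** 2

for res in (768, 1024, 1536):
    print(f"{res}px: {conv_tokens(res, False)} tokens (4 stages), {conv_tokens(res, True)} tokens (5 stages)")
# e.g. 1536px gives 2304 tokens at stride 32 but only 576 once the extra stage is added.
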
/llava/model/multimodal_encoder/lknet_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | from transformers import CLIPImageProcessor
5 | from transformers import PretrainedConfig
6 | import os
7 | from .unireplknet.unireplknet_encoder import unireplknet_l_plus
8 |
9 |
10 | class LKNetConfig(PretrainedConfig):
11 | model_type = "lknet"
12 |
13 | def __init__(
14 | self,
15 | in_chans=3,
16 | image_size=256,
17 | num_classes=1000,
18 | depths=(3, 3, 27, 3, 6),
19 | dims=(192, 384, 768, 1536, 3072),
20 | drop_path_rate=0.0,
21 | layer_scale_init_value=1e-6,
22 | head_init_scale=1.0,
23 | kernel_sizes=None,
24 | deploy=False,
25 | with_cp=True,
26 | init_cfg=None,
27 | attempt_use_lk_impl=True,
28 | use_sync_bn=False,
29 | **kwargs,
30 | ):
31 | super().__init__(**kwargs)
32 | self.in_chans = in_chans
33 | self.image_size = image_size
34 | self.num_classes = num_classes
35 | self.depths = depths
36 | self.dims = dims
37 | self.drop_path_rate = drop_path_rate
38 | self.layer_scale_init_value = layer_scale_init_value
39 | self.head_init_scale = head_init_scale
40 | self.kernel_sizes = kernel_sizes
41 | self.deploy = deploy
42 | self.with_cp = with_cp
43 | self.init_cfg = init_cfg
44 | self.attempt_use_lk_impl = attempt_use_lk_impl
45 | self.use_sync_bn = use_sync_bn
46 |
47 |
48 | unireplknet_l_plus_config = {
49 | "depths": (3, 3, 27, 3, 6),
50 | "kernel_sizes": (
51 | (3, 3, 3),
52 | (13, 13, 13),
53 | (
54 | 13,
55 | 3,
56 | 3,
57 | 13,
58 | 3,
59 | 3,
60 | 13,
61 | 3,
62 | 3,
63 | 13,
64 | 3,
65 | 3,
66 | 13,
67 | 3,
68 | 3,
69 | 13,
70 | 3,
71 | 3,
72 | 13,
73 | 3,
74 | 3,
75 | 13,
76 | 3,
77 | 3,
78 | 13,
79 | 3,
80 | 3,
81 | ),
82 | (13, 13, 13),
83 | (3, 3, 3, 3, 3, 3),
84 | ),
85 | }
86 |
87 |
88 | class LKNetCLIPVisionTower(nn.Module):
89 | def __init__(self, vision_tower, args, delay_load=False):
90 | super().__init__()
91 |
92 | self.is_loaded = False
93 |
94 | self.vision_tower_name = vision_tower
95 | self.select_layer = args.mm_vision_select_layer
96 | self.update_resolution = getattr(args, "mm_vision_resolution", 256)
97 | self.cfg_only = LKNetConfig.from_dict(unireplknet_l_plus_config)
98 |
99 | if not delay_load:
100 | self.load_model()
101 |
102 | def load_model(self):
103 | print(f"entering load model, load {self.vision_tower_name}")
104 | self.image_processor = CLIPImageProcessor.from_pretrained(
105 | os.path.dirname(self.vision_tower_name)
106 | )
107 | self.vision_tower = unireplknet_l_plus()
108 | self.vision_tower.config = LKNetConfig.from_dict(unireplknet_l_plus_config)
109 | ckpt = torch.load(self.vision_tower_name)
110 | del ckpt["norm.weight"]
111 | del ckpt["norm.bias"]
112 | missing_keys, unexpected_keys = self.vision_tower.load_state_dict(
113 | ckpt, strict=False
114 | )
115 | print("Loaded CLIP Pretrained Models")
116 | print(
117 | f"missing keys are {missing_keys}\n unexpected keys are {unexpected_keys}"
118 | )
119 |
120 | self.vision_tower.requires_grad_(False)
121 | self.is_loaded = True
122 |
123 | if self.update_resolution > 256:
124 | self.set_crop_size(self.update_resolution)
125 | print(
126 | f"Crop size changed to {self.update_resolution}x{self.update_resolution}"
127 | )
128 | self.make_layers_trainable_after_stage(4)
129 |
130 | def forward(self, images):
131 | if type(images) is list:
132 | image_features = []
133 | for image in images:
134 | x = image
135 | for stage_idx in range(5):
136 | x = self.vision_tower.downsample_layers[stage_idx](x)
137 | x = self.vision_tower.stages[stage_idx](x)
138 | image_feature = x.permute(0, 2, 3, 1)
139 | image_feature = image_feature.reshape(x.shape[0], -1, x.shape[1]).to(image.dtype)
140 | image_features.append(image_feature)
141 | else:
142 | x = images
143 | for stage_idx in range(5):
144 | x = self.vision_tower.downsample_layers[stage_idx](x)
145 | x = self.vision_tower.stages[stage_idx](x)
146 | image_features = x.permute(0, 2, 3, 1)
147 | image_features = image_features.reshape(x.shape[0], -1, x.shape[1]).to(images.dtype)
148 | return image_features
149 |
150 | def make_layers_trainable_after_stage(self, stage_index):
151 | for i, stage in enumerate(self.vision_tower.stages):
152 | if i >= stage_index:
153 | for param in stage.parameters():
154 | param.requires_grad = True
155 | for i, stage in enumerate(self.vision_tower.downsample_layers):
156 | if i >= stage_index:
157 | for param in stage.parameters():
158 | param.requires_grad = True
159 | self.print_trainable_parameters()
160 |
161 | def print_trainable_parameters(self):
162 | print("Trainable status of each stage:")
163 | for i, stage in enumerate(self.vision_tower.stages):
164 | trainable = all(param.requires_grad for param in stage.parameters())
165 | print(f"Stage {i}: {'Trainable' if trainable else 'Not Trainable'}")
166 |
167 | print("\nTrainable status of each downsampling layer:")
168 | for i, downsample_layer in enumerate(self.vision_tower.downsample_layers):
169 | trainable = all(param.requires_grad for param in downsample_layer.parameters())
170 | print(f"Downsampling Layer {i}: {'Trainable' if trainable else 'Not Trainable'}")
171 |
172 | def set_crop_size(self, new_size):
173 | size_dict = {"height": new_size, "width": new_size}
174 | self.image_processor.crop_size = size_dict
175 | self.image_processor.size = {"shortest_edge": new_size}
176 | self.vision_tower.config.image_size = new_size
177 | self.config.image_size = new_size
178 |
179 | def save_config(self, path):
180 | self.config.save_pretrained(path)
181 |
182 | @property
183 | def dummy_feature(self):
184 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
185 |
186 | @property
187 | def dtype(self):
188 | return self.vision_tower.downsample_layers[0][0].weight.dtype
189 |
190 | @property
191 | def device(self):
192 | return self.vision_tower.downsample_layers[0][0].weight.device
193 |
194 | @property
195 | def config(self):
196 | return self.cfg_only
197 |
198 | @property
199 | def hidden_size(self):
200 | return self.config.dims[-1]
201 |
202 | @property
203 | def num_patches_per_side(self):
204 | return self.config.image_size // 2 ** (len(self.config.depths) + 1)
205 |
206 | @property
207 | def num_patches(self):
208 | return (self.config.image_size // 2 ** (len(self.config.depths) + 1)) ** 2
209 |
210 | @property
211 | def crop_size(self):
212 | return self.image_processor.crop_size
213 |
214 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/siglip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 |
5 | from transformers import SiglipVisionConfig, SiglipVisionModel, SiglipImageProcessor
6 |
7 |
8 | class SiglipVisionTower(nn.Module):
9 | def __init__(self, vision_tower, args, delay_load=False):
10 | super().__init__()
11 |
12 | self.is_loaded = False
13 |
14 | self.vision_tower_name = vision_tower
15 | self.select_layer = args.mm_vision_select_layer
16 | self.select_feature = getattr(
17 | args, 'mm_vision_select_feature', 'patch')
18 |
19 | if not delay_load:
20 | self.load_model()
21 | else:
22 | self.cfg_only = SiglipVisionConfig.from_pretrained(
23 | self.vision_tower_name)
24 |
25 | def get_input_embeddings(self) -> nn.Module:
26 | return self.vision_tower.embeddings.patch_embedding
27 |
28 | def load_model(self):
29 | self.image_processor = SiglipImageProcessor.from_pretrained(
30 | self.vision_tower_name)
31 | self.vision_tower = SiglipVisionModel.from_pretrained(
32 | self.vision_tower_name)
33 | self.vision_tower.requires_grad_(False)
34 |
35 | self.is_loaded = True
36 |
37 | def print_trainable_parameters(self):
38 | for i, layer in enumerate(self.vision_tower.vision_model.encoder.layers):
39 | is_trainable = any(
40 | param.requires_grad for param in layer.parameters())
41 | print(f"ViT Layer {i} is trainable: {is_trainable}")
42 |
43 | def feature_select(self, image_forward_outs):
44 | image_features = image_forward_outs.hidden_states[self.select_layer]
45 |         if self.select_feature == 'patch':
46 |             image_features = image_features  # SigLIP has no [CLS] token, so all tokens are kept
47 |         elif self.select_feature == 'cls_patch':
48 |             image_features = image_features  # identical to 'patch' for SigLIP
49 | else:
50 | raise ValueError(
51 | f'Unexpected select feature: {self.select_feature}')
52 | return image_features
53 |
54 | def forward(self, images):
55 | if type(images) is list:
56 | image_features = []
57 | for image in images:
58 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
59 | output_hidden_states=True)
60 | image_feature = self.feature_select(
61 | image_forward_out).to(image.dtype)
62 | image_features.append(image_feature)
63 | else:
64 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
65 | output_hidden_states=True)
66 | image_features = self.feature_select(
67 | image_forward_outs).to(images.dtype)
68 |
69 | return image_features
70 |
71 | @property
72 | def dummy_feature(self):
73 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
74 |
75 | @property
76 | def dtype(self):
77 | return list(self.vision_tower.parameters())[0].dtype
78 |
79 | @property
80 | def device(self):
81 | return list(self.vision_tower.parameters())[0].device
82 |
83 | @property
84 | def config(self):
85 | if self.is_loaded:
86 | return self.vision_tower.config
87 | else:
88 | return self.cfg_only
89 |
90 | @property
91 | def hidden_size(self):
92 | return self.config.hidden_size
93 |
94 | @property
95 | def num_patches_per_side(self):
96 | return self.config.image_size // self.config.patch_size
97 |
98 | @property
99 | def num_patches(self):
100 | return (self.config.image_size // self.config.patch_size) ** 2
101 |
102 |
--------------------------------------------------------------------------------
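Note (illustrative sketch, not part of the repository): the SiglipVisionTower above can be exercised on its own roughly as follows. The checkpoint name and the args fields are assumptions for the example, not values taken from this repository's training configs.

    import argparse
    import torch
    from llava.model.multimodal_encoder.siglip_encoder import SiglipVisionTower

    # Minimal stand-in for the training args object; only the two fields read in __init__ are set.
    args = argparse.Namespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
    tower = SiglipVisionTower("google/siglip-so400m-patch14-384", args)  # downloads weights on first use
    pixels = torch.zeros(1, 3, tower.config.image_size, tower.config.image_size, dtype=tower.dtype)
    features = tower(pixels)  # (batch, num_patches, hidden_size)
    print(features.shape, tower.num_patches, tower.hidden_size)

The returned patch features are what the multimodal projector (llava/model/multimodal_projector/builder.py below) maps into the language model's hidden size.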
/llava/model/multimodal_encoder/unireplknet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/model/multimodal_encoder/unireplknet/__init__.py
--------------------------------------------------------------------------------
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 |
29 | def forward(self, x):
30 | x = self.pre_norm(x)
31 | return x + self.proj(x)
32 |
33 |
34 | def build_vision_projector(config, delay_load=False, **kwargs):
35 | projector_type = getattr(config, 'mm_projector_type', 'linear')
36 |
37 | if projector_type == 'linear':
38 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
39 |
40 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
41 | if mlp_gelu_match:
42 | mlp_depth = int(mlp_gelu_match.group(1))
43 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
44 | for _ in range(1, mlp_depth):
45 | modules.append(nn.GELU())
46 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47 | return nn.Sequential(*modules)
48 |
49 | if projector_type == 'identity':
50 | return IdentityMap()
51 |
52 | raise ValueError(f'Unknown projector type: {projector_type}')
53 |
--------------------------------------------------------------------------------
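Note (illustrative sketch, not part of the repository): a minimal exercise of build_vision_projector above with the 'mlp2x_gelu' projector type; the config values are example numbers, not this repository's defaults.

    import torch
    from types import SimpleNamespace
    from llava.model.multimodal_projector.builder import build_vision_projector

    cfg = SimpleNamespace(mm_projector_type='mlp2x_gelu', mm_hidden_size=1152, hidden_size=4096)
    projector = build_vision_projector(cfg)    # Linear(1152->4096) -> GELU -> Linear(4096->4096)
    vision_tokens = torch.randn(1, 729, cfg.mm_hidden_size)
    print(projector(vision_tokens).shape)      # torch.Size([1, 729, 4096])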
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/serve/__init__.py
--------------------------------------------------------------------------------
/llava/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.model.builder import load_pretrained_model
7 | from llava.utils import disable_torch_init
8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9 |
10 | from io import BytesIO
11 | 
12 | import requests
13 | from PIL import Image
14 | from transformers import TextStreamer
15 | 
16 |
17 |
18 | def load_image(image_file):
19 | if image_file.startswith('http://') or image_file.startswith('https://'):
20 | response = requests.get(image_file)
21 | image = Image.open(BytesIO(response.content)).convert('RGB')
22 | else:
23 | image = Image.open(image_file).convert('RGB')
24 | return image
25 |
26 |
27 | def main(args):
28 | # Model
29 | disable_torch_init()
30 |
31 | model_name = get_model_name_from_path(args.model_path)
32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33 |
34 | if 'llama-2' in model_name.lower():
35 | conv_mode = "llava_llama_2"
36 | elif "v1" in model_name.lower():
37 | conv_mode = "llava_v1"
38 | elif "mpt" in model_name.lower():
39 | conv_mode = "mpt"
40 | else:
41 | conv_mode = "llava_v0"
42 |
43 | if args.conv_mode is not None and conv_mode != args.conv_mode:
44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
45 | else:
46 | args.conv_mode = conv_mode
47 |
48 | conv = conv_templates[args.conv_mode].copy()
49 | if "mpt" in model_name.lower():
50 | roles = ('user', 'assistant')
51 | else:
52 | roles = conv.roles
53 |
54 | image = load_image(args.image_file)
55 | # Similar operation in model_worker.py
56 | image_tensor = process_images([image], image_processor, model.config)
57 | if type(image_tensor) is list:
58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
59 | else:
60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16)
61 |
62 | while True:
63 | try:
64 | inp = input(f"{roles[0]}: ")
65 | except EOFError:
66 | inp = ""
67 | if not inp:
68 | print("exit...")
69 | break
70 |
71 | print(f"{roles[1]}: ", end="")
72 |
73 | if image is not None:
74 | # first message
75 | if model.config.mm_use_im_start_end:
76 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
77 | else:
78 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
79 | conv.append_message(conv.roles[0], inp)
80 | image = None
81 | else:
82 | # later messages
83 | conv.append_message(conv.roles[0], inp)
84 | conv.append_message(conv.roles[1], None)
85 | prompt = conv.get_prompt()
86 |
87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
89 | keywords = [stop_str]
90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
92 |
93 | with torch.inference_mode():
94 | output_ids = model.generate(
95 | input_ids,
96 | images=image_tensor,
97 | do_sample=True if args.temperature > 0 else False,
98 | temperature=args.temperature,
99 | max_new_tokens=args.max_new_tokens,
100 | streamer=streamer,
101 | use_cache=True,
102 | stopping_criteria=[stopping_criteria])
103 |
104 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
105 | conv.messages[-1][-1] = outputs
106 |
107 | if args.debug:
108 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
109 |
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
114 | parser.add_argument("--model-base", type=str, default=None)
115 | parser.add_argument("--image-file", type=str, required=True)
116 | parser.add_argument("--device", type=str, default="cuda")
117 | parser.add_argument("--conv-mode", type=str, default=None)
118 | parser.add_argument("--temperature", type=float, default=0.2)
119 | parser.add_argument("--max-new-tokens", type=int, default=512)
120 | parser.add_argument("--load-8bit", action="store_true")
121 | parser.add_argument("--load-4bit", action="store_true")
122 | parser.add_argument("--debug", action="store_true")
123 | args = parser.parse_args()
124 | main(args)
125 |
--------------------------------------------------------------------------------
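Note (usage, editorial): the CLI above is run as a module, e.g. `python -m llava.serve.cli --model-path <checkpoint-or-hf-repo> --image-file <image>` (optionally with --load-8bit or --load-4bit). Submitting an empty line ends the chat loop.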
/llava/serve/controller.py:
--------------------------------------------------------------------------------
1 | """
2 | A controller manages distributed workers.
3 | It sends worker addresses to clients.
4 | """
5 | import argparse
6 | import asyncio
7 | import dataclasses
8 | from enum import Enum, auto
9 | import json
10 | import logging
11 | import time
12 | from typing import List, Union
13 | import threading
14 |
15 | from fastapi import FastAPI, Request
16 | from fastapi.responses import StreamingResponse
17 | import numpy as np
18 | import requests
19 | import uvicorn
20 |
21 | from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22 | from llava.utils import build_logger, server_error_msg
23 |
24 |
25 | logger = build_logger("controller", "controller.log")
26 |
27 |
28 | class DispatchMethod(Enum):
29 | LOTTERY = auto()
30 | SHORTEST_QUEUE = auto()
31 |
32 | @classmethod
33 | def from_str(cls, name):
34 | if name == "lottery":
35 | return cls.LOTTERY
36 | elif name == "shortest_queue":
37 | return cls.SHORTEST_QUEUE
38 | else:
39 |             raise ValueError(f"Invalid dispatch method: {name}")
40 |
41 |
42 | @dataclasses.dataclass
43 | class WorkerInfo:
44 | model_names: List[str]
45 | speed: int
46 | queue_length: int
47 | check_heart_beat: bool
48 |     last_heart_beat: float
49 |
50 |
51 | def heart_beat_controller(controller):
52 | while True:
53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54 |         controller.remove_stale_workers_by_expiration()
55 |
56 |
57 | class Controller:
58 | def __init__(self, dispatch_method: str):
59 | # Dict[str -> WorkerInfo]
60 | self.worker_info = {}
61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62 |
63 | self.heart_beat_thread = threading.Thread(
64 | target=heart_beat_controller, args=(self,))
65 | self.heart_beat_thread.start()
66 |
67 | logger.info("Init controller")
68 |
69 | def register_worker(self, worker_name: str, check_heart_beat: bool,
70 | worker_status: dict):
71 | if worker_name not in self.worker_info:
72 | logger.info(f"Register a new worker: {worker_name}")
73 | else:
74 | logger.info(f"Register an existing worker: {worker_name}")
75 |
76 | if not worker_status:
77 | worker_status = self.get_worker_status(worker_name)
78 | if not worker_status:
79 | return False
80 |
81 | self.worker_info[worker_name] = WorkerInfo(
82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83 | check_heart_beat, time.time())
84 |
85 | logger.info(f"Register done: {worker_name}, {worker_status}")
86 | return True
87 |
88 | def get_worker_status(self, worker_name: str):
89 | try:
90 | r = requests.post(worker_name + "/worker_get_status", timeout=5)
91 | except requests.exceptions.RequestException as e:
92 | logger.error(f"Get status fails: {worker_name}, {e}")
93 | return None
94 |
95 | if r.status_code != 200:
96 | logger.error(f"Get status fails: {worker_name}, {r}")
97 | return None
98 |
99 | return r.json()
100 |
101 | def remove_worker(self, worker_name: str):
102 | del self.worker_info[worker_name]
103 |
104 | def refresh_all_workers(self):
105 | old_info = dict(self.worker_info)
106 | self.worker_info = {}
107 |
108 | for w_name, w_info in old_info.items():
109 | if not self.register_worker(w_name, w_info.check_heart_beat, None):
110 | logger.info(f"Remove stale worker: {w_name}")
111 |
112 | def list_models(self):
113 | model_names = set()
114 |
115 | for w_name, w_info in self.worker_info.items():
116 | model_names.update(w_info.model_names)
117 |
118 | return list(model_names)
119 |
120 | def get_worker_address(self, model_name: str):
121 | if self.dispatch_method == DispatchMethod.LOTTERY:
122 | worker_names = []
123 | worker_speeds = []
124 | for w_name, w_info in self.worker_info.items():
125 | if model_name in w_info.model_names:
126 | worker_names.append(w_name)
127 | worker_speeds.append(w_info.speed)
128 | worker_speeds = np.array(worker_speeds, dtype=np.float32)
129 | norm = np.sum(worker_speeds)
130 | if norm < 1e-4:
131 | return ""
132 | worker_speeds = worker_speeds / norm
133 | if True: # Directly return address
134 | pt = np.random.choice(np.arange(len(worker_names)),
135 | p=worker_speeds)
136 | worker_name = worker_names[pt]
137 | return worker_name
138 |
139 | # Check status before returning
140 | while True:
141 | pt = np.random.choice(np.arange(len(worker_names)),
142 | p=worker_speeds)
143 | worker_name = worker_names[pt]
144 |
145 | if self.get_worker_status(worker_name):
146 | break
147 | else:
148 | self.remove_worker(worker_name)
149 | worker_speeds[pt] = 0
150 | norm = np.sum(worker_speeds)
151 | if norm < 1e-4:
152 | return ""
153 | worker_speeds = worker_speeds / norm
154 | continue
155 | return worker_name
156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157 | worker_names = []
158 | worker_qlen = []
159 | for w_name, w_info in self.worker_info.items():
160 | if model_name in w_info.model_names:
161 | worker_names.append(w_name)
162 | worker_qlen.append(w_info.queue_length / w_info.speed)
163 | if len(worker_names) == 0:
164 | return ""
165 | min_index = np.argmin(worker_qlen)
166 | w_name = worker_names[min_index]
167 | self.worker_info[w_name].queue_length += 1
168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169 | return w_name
170 | else:
171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172 |
173 | def receive_heart_beat(self, worker_name: str, queue_length: int):
174 | if worker_name not in self.worker_info:
175 | logger.info(f"Receive unknown heart beat. {worker_name}")
176 | return False
177 |
178 | self.worker_info[worker_name].queue_length = queue_length
179 | self.worker_info[worker_name].last_heart_beat = time.time()
180 | logger.info(f"Receive heart beat. {worker_name}")
181 | return True
182 |
183 |     def remove_stale_workers_by_expiration(self):
184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185 | to_delete = []
186 | for worker_name, w_info in self.worker_info.items():
187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188 | to_delete.append(worker_name)
189 |
190 | for worker_name in to_delete:
191 | self.remove_worker(worker_name)
192 |
193 | def worker_api_generate_stream(self, params):
194 | worker_addr = self.get_worker_address(params["model"])
195 | if not worker_addr:
196 | logger.info(f"no worker: {params['model']}")
197 | ret = {
198 | "text": server_error_msg,
199 | "error_code": 2,
200 | }
201 | yield json.dumps(ret).encode() + b"\0"
202 |
203 | try:
204 | response = requests.post(worker_addr + "/worker_generate_stream",
205 | json=params, stream=True, timeout=5)
206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207 | if chunk:
208 | yield chunk + b"\0"
209 | except requests.exceptions.RequestException as e:
210 | logger.info(f"worker timeout: {worker_addr}")
211 | ret = {
212 | "text": server_error_msg,
213 | "error_code": 3,
214 | }
215 | yield json.dumps(ret).encode() + b"\0"
216 |
217 |
218 | # Let the controller act as a worker to achieve hierarchical
219 | # management. This can be used to connect isolated sub networks.
220 | def worker_api_get_status(self):
221 | model_names = set()
222 | speed = 0
223 | queue_length = 0
224 |
225 | for w_name in self.worker_info:
226 | worker_status = self.get_worker_status(w_name)
227 | if worker_status is not None:
228 | model_names.update(worker_status["model_names"])
229 | speed += worker_status["speed"]
230 | queue_length += worker_status["queue_length"]
231 |
232 | return {
233 | "model_names": list(model_names),
234 | "speed": speed,
235 | "queue_length": queue_length,
236 | }
237 |
238 |
239 | app = FastAPI()
240 |
241 |
242 | @app.post("/register_worker")
243 | async def register_worker(request: Request):
244 | data = await request.json()
245 | controller.register_worker(
246 | data["worker_name"], data["check_heart_beat"],
247 | data.get("worker_status", None))
248 |
249 |
250 | @app.post("/refresh_all_workers")
251 | async def refresh_all_workers():
252 | models = controller.refresh_all_workers()
253 |
254 |
255 | @app.post("/list_models")
256 | async def list_models():
257 | models = controller.list_models()
258 | return {"models": models}
259 |
260 |
261 | @app.post("/get_worker_address")
262 | async def get_worker_address(request: Request):
263 | data = await request.json()
264 | addr = controller.get_worker_address(data["model"])
265 | return {"address": addr}
266 |
267 |
268 | @app.post("/receive_heart_beat")
269 | async def receive_heart_beat(request: Request):
270 | data = await request.json()
271 | exist = controller.receive_heart_beat(
272 | data["worker_name"], data["queue_length"])
273 | return {"exist": exist}
274 |
275 |
276 | @app.post("/worker_generate_stream")
277 | async def worker_api_generate_stream(request: Request):
278 | params = await request.json()
279 | generator = controller.worker_api_generate_stream(params)
280 | return StreamingResponse(generator)
281 |
282 |
283 | @app.post("/worker_get_status")
284 | async def worker_api_get_status(request: Request):
285 | return controller.worker_api_get_status()
286 |
287 |
288 | if __name__ == "__main__":
289 | parser = argparse.ArgumentParser()
290 | parser.add_argument("--host", type=str, default="localhost")
291 | parser.add_argument("--port", type=int, default=21001)
292 | parser.add_argument("--dispatch-method", type=str, choices=[
293 | "lottery", "shortest_queue"], default="shortest_queue")
294 | args = parser.parse_args()
295 | logger.info(f"args: {args}")
296 |
297 | controller = Controller(args.dispatch_method)
298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
299 |
--------------------------------------------------------------------------------
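Note (illustrative sketch, not part of the repository): a minimal client against the controller endpoints defined above, assuming a controller is already running at its default address with at least one registered worker.

    import requests

    base = "http://localhost:21001"
    requests.post(base + "/refresh_all_workers")
    models = requests.post(base + "/list_models").json()["models"]
    print("models:", models)
    if models:
        ret = requests.post(base + "/get_worker_address", json={"model": models[0]})
        print("worker address:", ret.json()["address"])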
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/conv-llava/3e7ae574c9c2cccaac0d9c1d12fed02f1a2d11bc/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/llava/serve/gradio_web_server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import os
5 | import time
6 |
7 | import gradio as gr
8 | import requests
9 |
10 | from llava.conversation import (default_conversation, conv_templates,
11 | SeparatorStyle)
12 | from llava.constants import LOGDIR
13 | from llava.utils import (build_logger, server_error_msg,
14 | violates_moderation, moderation_msg)
15 | import hashlib
16 |
17 |
18 | logger = build_logger("gradio_web_server", "gradio_web_server.log")
19 |
20 | headers = {"User-Agent": "LLaVA Client"}
21 |
22 | no_change_btn = gr.Button.update()
23 | enable_btn = gr.Button.update(interactive=True)
24 | disable_btn = gr.Button.update(interactive=False)
25 |
26 | priority = {
27 | "vicuna-13b": "aaaaaaa",
28 | "koala-13b": "aaaaaab",
29 | }
30 |
31 |
32 | def get_conv_log_filename():
33 | t = datetime.datetime.now()
34 | name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
35 | return name
36 |
37 |
38 | def get_model_list():
39 | ret = requests.post(args.controller_url + "/refresh_all_workers")
40 | assert ret.status_code == 200
41 | ret = requests.post(args.controller_url + "/list_models")
42 | models = ret.json()["models"]
43 | models.sort(key=lambda x: priority.get(x, x))
44 | logger.info(f"Models: {models}")
45 | return models
46 |
47 |
48 | get_window_url_params = """
49 | function() {
50 | const params = new URLSearchParams(window.location.search);
51 | url_params = Object.fromEntries(params);
52 | console.log(url_params);
53 | return url_params;
54 | }
55 | """
56 |
57 |
58 | def load_demo(url_params, request: gr.Request):
59 | logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
60 |
61 | dropdown_update = gr.Dropdown.update(visible=True)
62 | if "model" in url_params:
63 | model = url_params["model"]
64 | if model in models:
65 | dropdown_update = gr.Dropdown.update(
66 | value=model, visible=True)
67 |
68 | state = default_conversation.copy()
69 | return state, dropdown_update
70 |
71 |
72 | def load_demo_refresh_model_list(request: gr.Request):
73 | logger.info(f"load_demo. ip: {request.client.host}")
74 | models = get_model_list()
75 | state = default_conversation.copy()
76 | dropdown_update = gr.Dropdown.update(
77 | choices=models,
78 | value=models[0] if len(models) > 0 else ""
79 | )
80 | return state, dropdown_update
81 |
82 |
83 | def vote_last_response(state, vote_type, model_selector, request: gr.Request):
84 | with open(get_conv_log_filename(), "a") as fout:
85 | data = {
86 | "tstamp": round(time.time(), 4),
87 | "type": vote_type,
88 | "model": model_selector,
89 | "state": state.dict(),
90 | "ip": request.client.host,
91 | }
92 | fout.write(json.dumps(data) + "\n")
93 |
94 |
95 | def upvote_last_response(state, model_selector, request: gr.Request):
96 | logger.info(f"upvote. ip: {request.client.host}")
97 | vote_last_response(state, "upvote", model_selector, request)
98 | return ("",) + (disable_btn,) * 3
99 |
100 |
101 | def downvote_last_response(state, model_selector, request: gr.Request):
102 | logger.info(f"downvote. ip: {request.client.host}")
103 | vote_last_response(state, "downvote", model_selector, request)
104 | return ("",) + (disable_btn,) * 3
105 |
106 |
107 | def flag_last_response(state, model_selector, request: gr.Request):
108 | logger.info(f"flag. ip: {request.client.host}")
109 | vote_last_response(state, "flag", model_selector, request)
110 | return ("",) + (disable_btn,) * 3
111 |
112 |
113 | def regenerate(state, image_process_mode, request: gr.Request):
114 | logger.info(f"regenerate. ip: {request.client.host}")
115 | state.messages[-1][-1] = None
116 | prev_human_msg = state.messages[-2]
117 | if type(prev_human_msg[1]) in (tuple, list):
118 | prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
119 | state.skip_next = False
120 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
121 |
122 |
123 | def clear_history(request: gr.Request):
124 | logger.info(f"clear_history. ip: {request.client.host}")
125 | state = default_conversation.copy()
126 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
127 |
128 |
129 | def add_text(state, text, image, image_process_mode, request: gr.Request):
130 | logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
131 | if len(text) <= 0 and image is None:
132 | state.skip_next = True
133 | return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
134 | if args.moderate:
135 | flagged = violates_moderation(text)
136 | if flagged:
137 | state.skip_next = True
138 | return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
139 | no_change_btn,) * 5
140 |
141 | text = text[:1536] # Hard cut-off
142 | if image is not None:
143 | text = text[:1200] # Hard cut-off for images
144 |         if '<image>' not in text:
145 |             # text = '<image>' + text
146 |             text = text + '\n<image>'
147 | text = (text, image, image_process_mode)
148 | if len(state.get_images(return_pil=True)) > 0:
149 | state = default_conversation.copy()
150 | state.append_message(state.roles[0], text)
151 | state.append_message(state.roles[1], None)
152 | state.skip_next = False
153 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
154 |
155 |
156 | def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
157 | logger.info(f"http_bot. ip: {request.client.host}")
158 | start_tstamp = time.time()
159 | model_name = model_selector
160 |
161 | if state.skip_next:
162 | # This generate call is skipped due to invalid inputs
163 | yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
164 | return
165 |
166 | if len(state.messages) == state.offset + 2:
167 | # First round of conversation
168 | if "llava" in model_name.lower():
169 | if 'llama-2' in model_name.lower():
170 | template_name = "llava_llama_2"
171 | elif "v1" in model_name.lower():
172 | if 'mmtag' in model_name.lower():
173 | template_name = "v1_mmtag"
174 | elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
175 | template_name = "v1_mmtag"
176 | else:
177 | template_name = "llava_v1"
178 | elif "mpt" in model_name.lower():
179 | template_name = "mpt"
180 | else:
181 | if 'mmtag' in model_name.lower():
182 | template_name = "v0_mmtag"
183 | elif 'plain' in model_name.lower() and 'finetune' not in model_name.lower():
184 | template_name = "v0_mmtag"
185 | else:
186 | template_name = "llava_v0"
187 | elif "mpt" in model_name:
188 | template_name = "mpt_text"
189 | elif "llama-2" in model_name:
190 | template_name = "llama_2"
191 | else:
192 | template_name = "vicuna_v1"
193 | new_state = conv_templates[template_name].copy()
194 | new_state.append_message(new_state.roles[0], state.messages[-2][1])
195 | new_state.append_message(new_state.roles[1], None)
196 | state = new_state
197 |
198 | # Query worker address
199 | controller_url = args.controller_url
200 | ret = requests.post(controller_url + "/get_worker_address",
201 | json={"model": model_name})
202 | worker_addr = ret.json()["address"]
203 | logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
204 |
205 | # No available worker
206 | if worker_addr == "":
207 | state.messages[-1][-1] = server_error_msg
208 | yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
209 | return
210 |
211 | # Construct prompt
212 | prompt = state.get_prompt()
213 |
214 | all_images = state.get_images(return_pil=True)
215 | all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
216 | for image, hash in zip(all_images, all_image_hash):
217 | t = datetime.datetime.now()
218 | filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
219 | if not os.path.isfile(filename):
220 | os.makedirs(os.path.dirname(filename), exist_ok=True)
221 | image.save(filename)
222 |
223 | # Make requests
224 | pload = {
225 | "model": model_name,
226 | "prompt": prompt,
227 | "temperature": float(temperature),
228 | "top_p": float(top_p),
229 | "max_new_tokens": min(int(max_new_tokens), 1536),
230 | "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
231 | "images": f'List of {len(state.get_images())} images: {all_image_hash}',
232 | }
233 | logger.info(f"==== request ====\n{pload}")
234 |
235 | pload['images'] = state.get_images()
236 |
237 | state.messages[-1][-1] = "▌"
238 | yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
239 |
240 | try:
241 | # Stream output
242 | response = requests.post(worker_addr + "/worker_generate_stream",
243 | headers=headers, json=pload, stream=True, timeout=10)
244 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
245 | if chunk:
246 | data = json.loads(chunk.decode())
247 | if data["error_code"] == 0:
248 | output = data["text"][len(prompt):].strip()
249 | state.messages[-1][-1] = output + "▌"
250 | yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
251 | else:
252 | output = data["text"] + f" (error_code: {data['error_code']})"
253 | state.messages[-1][-1] = output
254 | yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
255 | return
256 | time.sleep(0.03)
257 | except requests.exceptions.RequestException as e:
258 | state.messages[-1][-1] = server_error_msg
259 | yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
260 | return
261 |
262 | state.messages[-1][-1] = state.messages[-1][-1][:-1]
263 | yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
264 |
265 | finish_tstamp = time.time()
266 | logger.info(f"{output}")
267 |
268 | with open(get_conv_log_filename(), "a") as fout:
269 | data = {
270 | "tstamp": round(finish_tstamp, 4),
271 | "type": "chat",
272 | "model": model_name,
273 | "start": round(start_tstamp, 4),
274 | "finish": round(finish_tstamp, 4),
275 | "state": state.dict(),
276 | "images": all_image_hash,
277 | "ip": request.client.host,
278 | }
279 | fout.write(json.dumps(data) + "\n")
280 |
281 | title_markdown = ("""
282 | # 🌋 LLaVA: Large Language and Vision Assistant
283 | [[Project Page](https://llava-vl.github.io)] [[Code](https://github.com/haotian-liu/LLaVA)] [[Model](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)] | 📚 [[LLaVA](https://arxiv.org/abs/2304.08485)] [[LLaVA-v1.5](https://arxiv.org/abs/2310.03744)]
284 | """)
285 |
286 | tos_markdown = ("""
287 | ### Terms of use
288 | By using this service, users are required to agree to the following terms:
289 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
290 | Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
291 | For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
292 | """)
293 |
294 |
295 | learn_more_markdown = ("""
296 | ### License
297 | The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
298 | """)
299 |
300 | block_css = """
301 |
302 | #buttons button {
303 | min-width: min(120px,100%);
304 | }
305 |
306 | """
307 |
308 | def build_demo(embed_mode):
309 | textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
310 | with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
311 | state = gr.State()
312 |
313 | if not embed_mode:
314 | gr.Markdown(title_markdown)
315 |
316 | with gr.Row():
317 | with gr.Column(scale=3):
318 | with gr.Row(elem_id="model_selector_row"):
319 | model_selector = gr.Dropdown(
320 | choices=models,
321 | value=models[0] if len(models) > 0 else "",
322 | interactive=True,
323 | show_label=False,
324 | container=False)
325 |
326 | imagebox = gr.Image(type="pil")
327 | image_process_mode = gr.Radio(
328 | ["Crop", "Resize", "Pad", "Default"],
329 | value="Default",
330 | label="Preprocess for non-square image", visible=False)
331 |
332 | cur_dir = os.path.dirname(os.path.abspath(__file__))
333 | gr.Examples(examples=[
334 | [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
335 | [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
336 | ], inputs=[imagebox, textbox])
337 |
338 | with gr.Accordion("Parameters", open=False) as parameter_row:
339 | temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
340 | top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
341 | max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
342 |
343 | with gr.Column(scale=8):
344 | chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA Chatbot", height=550)
345 | with gr.Row():
346 | with gr.Column(scale=8):
347 | textbox.render()
348 | with gr.Column(scale=1, min_width=50):
349 | submit_btn = gr.Button(value="Send", variant="primary")
350 | with gr.Row(elem_id="buttons") as button_row:
351 | upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
352 | downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
353 | flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
354 | #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
355 | regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
356 | clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
357 |
358 | if not embed_mode:
359 | gr.Markdown(tos_markdown)
360 | gr.Markdown(learn_more_markdown)
361 | url_params = gr.JSON(visible=False)
362 |
363 | # Register listeners
364 | btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
365 | upvote_btn.click(
366 | upvote_last_response,
367 | [state, model_selector],
368 | [textbox, upvote_btn, downvote_btn, flag_btn],
369 | queue=False
370 | )
371 | downvote_btn.click(
372 | downvote_last_response,
373 | [state, model_selector],
374 | [textbox, upvote_btn, downvote_btn, flag_btn],
375 | queue=False
376 | )
377 | flag_btn.click(
378 | flag_last_response,
379 | [state, model_selector],
380 | [textbox, upvote_btn, downvote_btn, flag_btn],
381 | queue=False
382 | )
383 |
384 | regenerate_btn.click(
385 | regenerate,
386 | [state, image_process_mode],
387 | [state, chatbot, textbox, imagebox] + btn_list,
388 | queue=False
389 | ).then(
390 | http_bot,
391 | [state, model_selector, temperature, top_p, max_output_tokens],
392 | [state, chatbot] + btn_list
393 | )
394 |
395 | clear_btn.click(
396 | clear_history,
397 | None,
398 | [state, chatbot, textbox, imagebox] + btn_list,
399 | queue=False
400 | )
401 |
402 | textbox.submit(
403 | add_text,
404 | [state, textbox, imagebox, image_process_mode],
405 | [state, chatbot, textbox, imagebox] + btn_list,
406 | queue=False
407 | ).then(
408 | http_bot,
409 | [state, model_selector, temperature, top_p, max_output_tokens],
410 | [state, chatbot] + btn_list
411 | )
412 |
413 | submit_btn.click(
414 | add_text,
415 | [state, textbox, imagebox, image_process_mode],
416 | [state, chatbot, textbox, imagebox] + btn_list,
417 | queue=False
418 | ).then(
419 | http_bot,
420 | [state, model_selector, temperature, top_p, max_output_tokens],
421 | [state, chatbot] + btn_list
422 | )
423 |
424 | if args.model_list_mode == "once":
425 | demo.load(
426 | load_demo,
427 | [url_params],
428 | [state, model_selector],
429 | _js=get_window_url_params,
430 | queue=False
431 | )
432 | elif args.model_list_mode == "reload":
433 | demo.load(
434 | load_demo_refresh_model_list,
435 | None,
436 | [state, model_selector],
437 | queue=False
438 | )
439 | else:
440 | raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
441 |
442 | return demo
443 |
444 |
445 | if __name__ == "__main__":
446 | parser = argparse.ArgumentParser()
447 | parser.add_argument("--host", type=str, default="0.0.0.0")
448 | parser.add_argument("--port", type=int)
449 | parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
450 | parser.add_argument("--concurrency-count", type=int, default=10)
451 | parser.add_argument("--model-list-mode", type=str, default="once",
452 | choices=["once", "reload"])
453 | parser.add_argument("--share", action="store_true")
454 | parser.add_argument("--moderate", action="store_true")
455 | parser.add_argument("--embed", action="store_true")
456 | args = parser.parse_args()
457 | logger.info(f"args: {args}")
458 |
459 | models = get_model_list()
460 |
461 | logger.info(args)
462 | demo = build_demo(args.embed)
463 | demo.queue(
464 | concurrency_count=args.concurrency_count,
465 | api_open=False
466 | ).launch(
467 | server_name=args.host,
468 | server_port=args.port,
469 | share=args.share
470 | )
471 |
--------------------------------------------------------------------------------
/llava/serve/model_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | A model worker executes the model.
3 | """
4 | import argparse
5 | import asyncio
6 | import json
7 | import time
8 | import threading
9 | import uuid
10 |
11 | from fastapi import FastAPI, Request, BackgroundTasks
12 | from fastapi.responses import StreamingResponse
13 | import requests
14 | import torch
15 | import uvicorn
16 | from functools import partial
17 |
18 | from llava.constants import WORKER_HEART_BEAT_INTERVAL
19 | from llava.utils import (build_logger, server_error_msg,
20 | pretty_print_semaphore)
21 | from llava.model.builder import load_pretrained_model
22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
23 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24 | from transformers import TextIteratorStreamer
25 | from threading import Thread
26 |
27 |
28 | GB = 1 << 30
29 |
30 | worker_id = str(uuid.uuid4())[:6]
31 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32 | global_counter = 0
33 |
34 | model_semaphore = None
35 |
36 |
37 | def heart_beat_worker(controller):
38 |
39 | while True:
40 | time.sleep(WORKER_HEART_BEAT_INTERVAL)
41 | controller.send_heart_beat()
42 |
43 |
44 | class ModelWorker:
45 | def __init__(self, controller_addr, worker_addr,
46 | worker_id, no_register,
47 | model_path, model_base, model_name,
48 | load_8bit, load_4bit, device):
49 | self.controller_addr = controller_addr
50 | self.worker_addr = worker_addr
51 | self.worker_id = worker_id
52 | if model_path.endswith("/"):
53 | model_path = model_path[:-1]
54 | if model_name is None:
55 | model_paths = model_path.split("/")
56 | if model_paths[-1].startswith('checkpoint-'):
57 | self.model_name = model_paths[-2] + "_" + model_paths[-1]
58 | else:
59 | self.model_name = model_paths[-1]
60 | else:
61 | self.model_name = model_name
62 |
63 | self.device = device
64 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66 | model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
67 | self.is_multimodal = 'llava' in self.model_name.lower()
68 |
69 | if not no_register:
70 | self.register_to_controller()
71 | self.heart_beat_thread = threading.Thread(
72 | target=heart_beat_worker, args=(self,))
73 | self.heart_beat_thread.start()
74 |
75 | def register_to_controller(self):
76 | logger.info("Register to controller")
77 |
78 | url = self.controller_addr + "/register_worker"
79 | data = {
80 | "worker_name": self.worker_addr,
81 | "check_heart_beat": True,
82 | "worker_status": self.get_status()
83 | }
84 | r = requests.post(url, json=data)
85 | assert r.status_code == 200
86 |
87 | def send_heart_beat(self):
88 | logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90 | f"global_counter: {global_counter}")
91 |
92 | url = self.controller_addr + "/receive_heart_beat"
93 |
94 | while True:
95 | try:
96 | ret = requests.post(url, json={
97 | "worker_name": self.worker_addr,
98 | "queue_length": self.get_queue_length()}, timeout=5)
99 | exist = ret.json()["exist"]
100 | break
101 | except requests.exceptions.RequestException as e:
102 | logger.error(f"heart beat error: {e}")
103 | time.sleep(5)
104 |
105 | if not exist:
106 | self.register_to_controller()
107 |
108 | def get_queue_length(self):
109 | if model_semaphore is None:
110 | return 0
111 | else:
112 | return args.limit_model_concurrency - model_semaphore._value + (len(
113 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114 |
115 | def get_status(self):
116 | return {
117 | "model_names": [self.model_name],
118 | "speed": 1,
119 | "queue_length": self.get_queue_length(),
120 | }
121 |
122 | @torch.inference_mode()
123 | def generate_stream(self, params):
124 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125 |
126 | prompt = params["prompt"]
127 | ori_prompt = prompt
128 | images = params.get("images", None)
129 | num_image_tokens = 0
130 | if images is not None and len(images) > 0 and self.is_multimodal:
131 | if len(images) > 0:
132 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133 |                     raise ValueError("Number of images does not match number of <image> tokens in prompt")
134 |
135 | images = [load_image_from_base64(image) for image in images]
136 | images = process_images(images, image_processor, model.config)
137 |
138 | if type(images) is list:
139 | images = [image.to(self.model.device, dtype=torch.float16) for image in images]
140 | else:
141 | images = images.to(self.model.device, dtype=torch.float16)
142 |
143 | replace_token = DEFAULT_IMAGE_TOKEN
144 | if getattr(self.model.config, 'mm_use_im_start_end', False):
145 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147 |
148 | num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
149 | else:
150 | images = None
151 | image_args = {"images": images}
152 | else:
153 | images = None
154 | image_args = {}
155 |
156 | temperature = float(params.get("temperature", 1.0))
157 | top_p = float(params.get("top_p", 1.0))
158 | max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
159 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
160 | stop_str = params.get("stop", None)
161 | do_sample = True if temperature > 0.001 else False
162 |
163 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
164 | keywords = [stop_str]
165 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
166 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
167 |
168 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
169 |
170 | if max_new_tokens < 1:
171 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
172 | return
173 |
174 | thread = Thread(target=model.generate, kwargs=dict(
175 | inputs=input_ids,
176 | do_sample=do_sample,
177 | temperature=temperature,
178 | top_p=top_p,
179 | max_new_tokens=max_new_tokens,
180 | streamer=streamer,
181 | stopping_criteria=[stopping_criteria],
182 | use_cache=True,
183 | **image_args
184 | ))
185 | thread.start()
186 |
187 | generated_text = ori_prompt
188 | for new_text in streamer:
189 | generated_text += new_text
190 | if generated_text.endswith(stop_str):
191 | generated_text = generated_text[:-len(stop_str)]
192 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
193 |
194 | def generate_stream_gate(self, params):
195 | try:
196 | for x in self.generate_stream(params):
197 | yield x
198 | except ValueError as e:
199 | print("Caught ValueError:", e)
200 | ret = {
201 | "text": server_error_msg,
202 | "error_code": 1,
203 | }
204 | yield json.dumps(ret).encode() + b"\0"
205 | except torch.cuda.CudaError as e:
206 | print("Caught torch.cuda.CudaError:", e)
207 | ret = {
208 | "text": server_error_msg,
209 | "error_code": 1,
210 | }
211 | yield json.dumps(ret).encode() + b"\0"
212 | except Exception as e:
213 | print("Caught Unknown Error", e)
214 | ret = {
215 | "text": server_error_msg,
216 | "error_code": 1,
217 | }
218 | yield json.dumps(ret).encode() + b"\0"
219 |
220 |
221 | app = FastAPI()
222 |
223 |
224 | def release_model_semaphore(fn=None):
225 | model_semaphore.release()
226 | if fn is not None:
227 | fn()
228 |
229 |
230 | @app.post("/worker_generate_stream")
231 | async def generate_stream(request: Request):
232 | global model_semaphore, global_counter
233 | global_counter += 1
234 | params = await request.json()
235 |
236 | if model_semaphore is None:
237 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
238 | await model_semaphore.acquire()
239 | worker.send_heart_beat()
240 | generator = worker.generate_stream_gate(params)
241 | background_tasks = BackgroundTasks()
242 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
243 | return StreamingResponse(generator, background=background_tasks)
244 |
245 |
246 | @app.post("/worker_get_status")
247 | async def get_status(request: Request):
248 | return worker.get_status()
249 |
250 |
251 | if __name__ == "__main__":
252 | parser = argparse.ArgumentParser()
253 | parser.add_argument("--host", type=str, default="localhost")
254 | parser.add_argument("--port", type=int, default=21002)
255 | parser.add_argument("--worker-address", type=str,
256 | default="http://localhost:21002")
257 | parser.add_argument("--controller-address", type=str,
258 | default="http://localhost:21001")
259 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
260 | parser.add_argument("--model-base", type=str, default=None)
261 | parser.add_argument("--model-name", type=str)
262 | parser.add_argument("--device", type=str, default="cuda")
263 | parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
264 | parser.add_argument("--limit-model-concurrency", type=int, default=5)
265 | parser.add_argument("--stream-interval", type=int, default=1)
266 | parser.add_argument("--no-register", action="store_true")
267 | parser.add_argument("--load-8bit", action="store_true")
268 | parser.add_argument("--load-4bit", action="store_true")
269 | args = parser.parse_args()
270 | logger.info(f"args: {args}")
271 |
272 | if args.multi_modal:
273 | logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
274 |
275 | worker = ModelWorker(args.controller_address,
276 | args.worker_address,
277 | worker_id,
278 | args.no_register,
279 | args.model_path,
280 | args.model_base,
281 | args.model_name,
282 | args.load_8bit,
283 | args.load_4bit,
284 | args.device)
285 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
286 |
--------------------------------------------------------------------------------
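Note (usage, editorial): a worker is normally started after the controller, e.g. `python -m llava.serve.model_worker --controller-address http://localhost:21001 --worker-address http://localhost:21002 --model-path <checkpoint>`. Unless --no-register is passed it registers itself with the controller and then serves /worker_generate_stream and /worker_get_status on the given port.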
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default=
59 | "Tell me a story with more than 1000 words.")
60 | args = parser.parse_args()
61 |
62 | main()
63 |
--------------------------------------------------------------------------------
/llava/train/llava_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torch.utils.data import Sampler
6 |
7 | from transformers import Trainer
8 | from transformers.trainer import (
9 | is_sagemaker_mp_enabled,
10 | get_parameter_names,
11 | has_length,
12 | ALL_LAYERNORM_LAYERS,
13 | logger,
14 | )
15 | from typing import List, Optional
16 |
17 |
18 | def maybe_zero_3(param, ignore_status=False, name=None):
19 | from deepspeed import zero
20 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
21 | if hasattr(param, "ds_id"):
22 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
23 | if not ignore_status:
24 | print(name, 'no ignore status')
25 | with zero.GatheredParameters([param]):
26 | param = param.data.detach().cpu().clone()
27 | else:
28 | param = param.detach().cpu().clone()
29 | return param
30 |
31 |
32 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
33 | to_return = {k: t for k, t in named_params if any(
34 | key_match in k for key_match in keys_to_match)}
35 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
36 | for k, v in to_return.items()}
37 | return to_return
38 |
39 |
40 | def split_to_even_chunks(indices, lengths, num_chunks):
41 | """
42 |     Split a list of indices into `num_chunks` chunks of roughly equal lengths.
43 | """
44 |
45 | if len(indices) % num_chunks != 0:
46 | return [indices[i::num_chunks] for i in range(num_chunks)]
47 |
48 | num_indices_per_chunk = len(indices) // num_chunks
49 |
50 | chunks = [[] for _ in range(num_chunks)]
51 | chunks_lengths = [0 for _ in range(num_chunks)]
52 | for index in indices:
53 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
54 | chunks[shortest_chunk].append(index)
55 | chunks_lengths[shortest_chunk] += lengths[index]
56 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
57 | chunks_lengths[shortest_chunk] = float("inf")
58 |
59 | return chunks
60 |
61 |
62 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
63 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
64 | assert all(l != 0 for l in lengths), "Should not have zero length."
65 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
66 | # all samples are in the same modality
67 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
68 | mm_indices, mm_lengths = zip(*[(i, l)
69 | for i, l in enumerate(lengths) if l > 0])
70 | lang_indices, lang_lengths = zip(
71 | *[(i, -l) for i, l in enumerate(lengths) if l < 0])
72 |
73 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(
74 | mm_lengths, batch_size, world_size, generator=None)]
75 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(
76 | lang_lengths, batch_size, world_size, generator=None)]
77 | megabatch_size = world_size * batch_size
78 | mm_megabatches = [mm_shuffle[i: i + megabatch_size]
79 | for i in range(0, len(mm_shuffle), megabatch_size)]
80 | lang_megabatches = [lang_shuffle[i: i + megabatch_size]
81 | for i in range(0, len(lang_shuffle), megabatch_size)]
82 |
83 | last_mm = mm_megabatches[-1]
84 | last_lang = lang_megabatches[-1]
85 | additional_batch = last_mm + last_lang
86 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
87 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
88 | megabatches = [megabatches[i] for i in megabatch_indices]
89 |
90 | if len(additional_batch) > 0:
91 | megabatches.append(sorted(additional_batch))
92 |
93 | return [i for megabatch in megabatches for i in megabatch]
94 |
95 |
96 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
97 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
98 | indices = torch.randperm(len(lengths), generator=generator)
99 | megabatch_size = world_size * batch_size
100 | megabatches = [indices[i: i + megabatch_size].tolist()
101 | for i in range(0, len(lengths), megabatch_size)]
102 | megabatches = [sorted(megabatch, key=lambda i: lengths[i],
103 | reverse=True) for megabatch in megabatches]
104 | megabatches = [split_to_even_chunks(
105 | megabatch, lengths, world_size) for megabatch in megabatches]
106 |
107 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
108 |
109 |
110 | class LengthGroupedSampler(Sampler):
111 | r"""
112 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
113 | keeping a bit of randomness.
114 | """
115 |
116 | def __init__(
117 | self,
118 | batch_size: int,
119 | world_size: int,
120 | lengths: Optional[List[int]] = None,
121 | generator=None,
122 | group_by_modality: bool = False,
123 | ):
124 | if lengths is None:
125 | raise ValueError("Lengths must be provided.")
126 |
127 | self.batch_size = batch_size
128 | self.world_size = world_size
129 | self.lengths = lengths
130 | self.generator = generator
131 | self.group_by_modality = group_by_modality
132 |
133 | def __len__(self):
134 | return len(self.lengths)
135 |
136 | def __iter__(self):
137 | if self.group_by_modality:
138 | indices = get_modality_length_grouped_indices(
139 | self.lengths, self.batch_size, self.world_size, generator=self.generator)
140 | else:
141 | indices = get_length_grouped_indices(
142 | self.lengths, self.batch_size, self.world_size, generator=self.generator)
143 | return iter(indices)
144 |
145 |
146 | class VisionLLaVATrainer(Trainer):
147 |
148 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
149 | if self.train_dataset is None or not has_length(self.train_dataset):
150 | return None
151 |
152 | if self.args.group_by_modality_length:
153 | lengths = self.train_dataset.modality_lengths
154 | return LengthGroupedSampler(
155 | self.args.train_batch_size,
156 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
157 | lengths=lengths,
158 | group_by_modality=True,
159 | )
160 | else:
161 | return super()._get_train_sampler()
162 |
163 | def create_optimizer(self):
164 | """
165 | Setup the optimizer.
166 |
167 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple through
168 | the Trainer's `optimizers` init argument, or override this method in a subclass.
169 | """
170 | if is_sagemaker_mp_enabled():
171 | return super().create_optimizer()
172 |
173 | opt_model = self.model
174 |
175 | if self.optimizer is None:
176 | decay_parameters = get_parameter_names(
177 | opt_model, ALL_LAYERNORM_LAYERS)
178 | decay_parameters = [
179 | name for name in decay_parameters if "bias" not in name]
180 | if self.args.mm_projector_lr is not None and self.args.mm_projector_lr != 0:
181 | projector_parameters = [
182 | name for name, _ in opt_model.named_parameters() if "mm_projector" in name]
183 | if self.args.vision_tower_lr is not None and self.args.vision_tower_lr != 0:
184 | vision_tower_parameters = [
185 | name for name, _ in opt_model.named_parameters() if "vision_tower" in name]
186 | optimizer_grouped_parameters = [
187 | {
188 | "params": [
189 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad)
190 | ],
191 | "weight_decay": self.args.weight_decay,
192 | },
193 | {
194 | "params": [
195 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad)
196 | ],
197 | "weight_decay": self.args.weight_decay,
198 | "lr": self.args.vision_tower_lr,
199 | },
200 | {
201 | "params": [
202 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n not in vision_tower_parameters and p.requires_grad)
203 | ],
204 | "weight_decay": 0.0,
205 | },
206 | {
207 | "params": [
208 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and n in vision_tower_parameters and p.requires_grad)
209 | ],
210 | "weight_decay": 0.0,
211 | "lr": self.args.vision_tower_lr,
212 | },
213 | {
214 | "params": [
215 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
216 | ],
217 | "weight_decay": self.args.weight_decay,
218 | "lr": self.args.mm_projector_lr,
219 | },
220 | {
221 | "params": [
222 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
223 | ],
224 | "weight_decay": 0.0,
225 | "lr": self.args.mm_projector_lr,
226 | },
227 | ]
228 | else:
229 | optimizer_grouped_parameters = [
230 | {
231 | "params": [
232 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
233 | ],
234 | "weight_decay": self.args.weight_decay,
235 | },
236 | {
237 | "params": [
238 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
239 | ],
240 | "weight_decay": 0.0,
241 | },
242 | {
243 | "params": [
244 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
245 | ],
246 | "weight_decay": self.args.weight_decay,
247 | "lr": self.args.mm_projector_lr,
248 | },
249 | {
250 | "params": [
251 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
252 | ],
253 | "weight_decay": 0.0,
254 | "lr": self.args.mm_projector_lr,
255 | },
256 | ]
257 | else:
258 | optimizer_grouped_parameters = [
259 | {
260 | "params": [
261 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
262 | ],
263 | "weight_decay": self.args.weight_decay,
264 | },
265 | {
266 | "params": [
267 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
268 | ],
269 | "weight_decay": 0.0,
270 | },
271 | ]
272 |
273 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
274 | self.args)
275 |
276 | self.optimizer = optimizer_cls(
277 | optimizer_grouped_parameters, **optimizer_kwargs)
278 | if optimizer_cls.__name__ == "Adam8bit":
279 | import bitsandbytes
280 |
281 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
282 |
283 | skipped = 0
284 | for module in opt_model.modules():
285 | if isinstance(module, nn.Embedding):
286 | skipped += sum({p.data_ptr(): p.numel()
287 | for p in module.parameters()}.values())
288 | logger.info(
289 | f"skipped {module}: {skipped/2**20}M params")
290 | manager.register_module_override(
291 | module, "weight", {"optim_bits": 32})
292 | logger.debug(
293 | f"bitsandbytes: will optimize {module} in fp32")
294 | logger.info(f"skipped: {skipped/2**20}M params")
295 |
296 | return self.optimizer
297 |
298 | def _save_checkpoint(self, model, trial, metrics=None):
299 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
300 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
301 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
302 |
303 | run_dir = self._get_output_dir(trial=trial)
304 | output_dir = os.path.join(run_dir, checkpoint_folder)
305 |
306 | # Only save Adapter
307 | keys_to_match = ['mm_projector', 'vision_resampler']
308 | if getattr(self.args, "use_im_start_end", False):
309 | keys_to_match.extend(['embed_tokens', 'embed_in'])
310 |
311 | weight_to_save = get_mm_adapter_state_maybe_zero_3(
312 | self.model.named_parameters(), keys_to_match)
313 |
314 | if self.args.local_rank == 0 or self.args.local_rank == -1:
315 | self.model.config.save_pretrained(output_dir)
316 | torch.save(weight_to_save, os.path.join(
317 | output_dir, 'mm_projector.bin'))
318 | else:
319 | super(VisionLLaVATrainer, self)._save_checkpoint(
320 | model, trial, metrics)
321 |
322 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
323 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
324 | pass
325 | else:
326 | super(VisionLLaVATrainer, self)._save(output_dir, state_dict)
327 |
--------------------------------------------------------------------------------
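
A minimal usage sketch of the length-grouped sampler defined above (not part of the repository). It assumes the repo is on PYTHONPATH and that torch and transformers are installed, since importing llava_trainer pulls them in; the toy `lengths` list is illustrative, with positive values standing for multimodal samples and negative values for text-only ones, following the sign convention used by get_modality_length_grouped_indices.

import torch
from llava.train.llava_trainer import LengthGroupedSampler, get_length_grouped_indices

# Toy lengths: sign encodes modality, magnitude encodes sequence length.
lengths = [120, 80, -35, 200, -50, 75, 90, -20]

sampler = LengthGroupedSampler(
    batch_size=2,
    world_size=2,
    lengths=lengths,
    generator=torch.Generator().manual_seed(0),
    group_by_modality=True,
)

# A permutation of range(len(lengths)): mega-batches of world_size * batch_size indices are
# mostly single-modality with roughly balanced lengths; leftovers from both modalities are
# merged into one final mixed batch.
print(list(iter(sampler)))

# The lower-level helper can also be called directly on raw (positive) lengths.
print(get_length_grouped_indices([abs(l) for l in lengths], batch_size=2, world_size=2))
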
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them, so this is still cross-platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | import json
110 | text = text.replace("\n", "")
111 | data = json.dumps({"input": text}).encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
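
A small usage sketch for the logging helpers above (not part of the repository). It assumes LOGDIR in llava/constants.py points to a writable directory, that requests is installed (llava.utils imports it at module level), and that torch is available for disable_torch_init.

from llava.utils import build_logger, disable_torch_init

logger = build_logger("demo", "demo.log")

# sys.stdout is now a StreamToLogger, so plain prints also end up in the rotating log file.
print("captured by StreamToLogger")
logger.info("regular logger calls land in the same TimedRotatingFileHandler")

# Skip torch's default weight initialization before loading pretrained checkpoints.
disable_torch_init()
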
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava"
7 | version = "0.0.1"
8 | description = "Advanced large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.1.2", "torchvision==0.16.2",
17 | "transformers==4.39.3", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
18 | "accelerate==0.27.2", "peft", "bitsandbytes",
19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20 | "gradio==4.16.0", "gradio_client==0.8.1",
21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi",
22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
23 | ]
24 |
25 | [project.optional-dependencies]
26 | train = ["deepspeed==0.13.1", "ninja", "wandb"]
27 | build = ["build", "twine"]
28 |
29 | [tool.setuptools.packages.find]
30 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
31 |
32 | [tool.wheel]
33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
--------------------------------------------------------------------------------
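
A quick environment check (not from the repository) against a few of the versions pinned above; the selection of packages below is illustrative, not exhaustive.

from importlib.metadata import PackageNotFoundError, version

pinned = {"torch": "2.1.2", "transformers": "4.39.3", "accelerate": "0.27.2", "timm": "0.6.13"}
for name, expected in pinned.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = "not installed"
    marker = "OK" if installed == expected else "MISMATCH"
    print(f"{name}: pinned {expected}, found {installed} -> {marker}")
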
/scripts/eval-lmms.sh:
--------------------------------------------------------------------------------
1 | # you can refer to this issue: https://github.com/EvolvingLMMs-Lab/lmms-eval/issues/15
2 |
3 | export HF_HOME=/path-to-save-dir
4 |
5 | path=/path-to-your-model
6 |
7 | accelerate launch --num_processes=8 -m lmms_eval \
8 | --model llava \
9 | --model_args pretrained="${path}" \
10 | --tasks mme \
11 | --batch_size 1 \
12 | --log_samples --log_samples_suffix convllava \
13 | --output_path ./logs/
--------------------------------------------------------------------------------
/scripts/evaluation.sh:
--------------------------------------------------------------------------------
1 | export LMUData=path_to_your_saved_data
2 | export OMP_NUM_THREADS=1
3 |
4 | eval_dataset="MMBench_DEV_EN"
5 |
6 | llava_path=path_to_your_weights
7 |
8 | work_dir=path_to_your_work_dir
9 | gpu=2
10 |
11 | # If you want to use ChatGPT for evaluation, set OPENAI_API_KEY
12 | # export OPENAI_API_KEY="sk-1234"
13 |
14 | torchrun --nproc-per-node=${gpu} llava/eval/run.py \
15 | --data ${eval_dataset} \
16 | --model llava_v1.5_7b \
17 | --verbose \
18 | --work-dir ${work_dir}/vlmeval \
19 | --llava-path=${llava_path}
20 |
21 |
22 |
--------------------------------------------------------------------------------
/scripts/refcoco.sh:
--------------------------------------------------------------------------------
1 | CHECKPOINT=/path/to/convllava
2 |
3 | torchrun \
4 | --nnodes=1 \
5 | --node_rank=0 \
6 | --master_addr=127.0.0.1 \
7 | --nproc_per_node=8 \
8 | --master_port=25908 \
9 | llava/eval/evaluate_grounding.py --checkpoint ${CHECKPOINT} --out-dir output
10 |
11 |
12 |
--------------------------------------------------------------------------------
/scripts/stage_1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # You may need to modify: model_name_or_path, dataset, vision_tower, output_dir
3 |
4 | deepspeed llava/train/train.py \
5 | --deepspeed ./scripts/zero2.json \
6 | --model_name_or_path lmsys/vicuna-7b-v1.5 \
7 | --version v1 \
8 | --dataset "dataset_you_want_train" \
9 | --vision_tower path_to_original_convnext \
10 | --mm_vision_resolution 768 \
11 | --mm_projector_type mlp2x_gelu \
12 | --freeze_backbone True \
13 | --vision_add_five_stage 6 \
14 | --vision_five_stage_width 3072 \
15 | --mm_vision_select_layer -1 \
16 | --mm_use_im_start_end False \
17 | --mm_use_im_patch_token False \
18 | --bf16 True \
19 | --output_dir ./checkpoints/convllava/stage1 \
20 | --num_train_epochs 1 \
21 | --per_device_train_batch_size 32 \
22 | --per_device_eval_batch_size 4 \
23 | --gradient_accumulation_steps 1 \
24 | --image_aspect_ratio pad \
25 | --group_by_modality_length True \
26 | --evaluation_strategy "no" \
27 | --save_strategy "steps" \
28 | --save_steps 24000 \
29 | --save_total_limit 1 \
30 | --learning_rate 3e-4 \
31 | --weight_decay 0. \
32 | --warmup_ratio 0.03 \
33 | --lr_scheduler_type "cosine" \
34 | --logging_steps 1 \
35 | --tf32 True \
36 | --model_max_length 2048 \
37 | --gradient_checkpointing True \
38 | --dataloader_num_workers 4 \
39 | --lazy_preprocess True \
40 | --report_to wandb
41 |
--------------------------------------------------------------------------------
/scripts/stage_2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # You may need to modify: model_name_or_path, dataset, vision_tower, output_dir
3 |
4 | deepspeed llava/train/train.py \
5 | --deepspeed ./scripts/zero3.json \
6 | --model_name_or_path path_to_stage1_llm \
7 | --version v1 \
8 | --dataset "dataset_you_want_train" \
9 |     --vision_tower path_to_stage1_convnext \
10 | --mm_projector_type mlp2x_gelu \
11 | --mm_vision_resolution 768 \
12 | --mm_vision_select_layer -1 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --bf16 True \
16 | --output_dir ./checkpoints/convllava/stage2 \
17 | --tune_vision_tower True \
18 | --tune_vit_from_layer 2 \
19 | --tune_entire_model False \
20 | --num_train_epochs 1 \
21 | --per_device_train_batch_size 32 \
22 | --per_device_eval_batch_size 4 \
23 | --gradient_accumulation_steps 1 \
24 | --image_aspect_ratio pad \
25 | --group_by_modality_length True \
26 | --evaluation_strategy "no" \
27 | --save_strategy "steps" \
28 | --save_steps 24000 \
29 | --save_total_limit 1 \
30 | --learning_rate 1e-3 \
31 | --weight_decay 0. \
32 | --warmup_ratio 0.03 \
33 | --lr_scheduler_type "cosine" \
34 | --logging_steps 1 \
35 | --tf32 True \
36 | --model_max_length 2048 \
37 | --gradient_checkpointing True \
38 | --dataloader_num_workers 4 \
39 | --lazy_preprocess True \
40 | --report_to wandb
41 |
--------------------------------------------------------------------------------
/scripts/stage_3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | # You may need to modify: model_name_or_path, dataset, vision_tower, output_dir
3 |
4 | deepspeed llava/train/train.py \
5 | --deepspeed ./scripts/zero3.json \
6 | --model_name_or_path path_to_stage2_llm \
7 | --version v1 \
8 | --dataset "dataset_you_want_train" \
9 | --vision_tower path_to_stage2_convnext \
10 | --mm_projector_type mlp2x_gelu \
11 | --mm_vision_resolution 768 \
12 | --mm_vision_select_layer -1 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --bf16 True \
16 | --output_dir ./checkpoints/convllava/stage3 \
17 | --num_train_epochs 1 \
18 | --per_device_train_batch_size 16 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 1 \
21 | --image_aspect_ratio pad \
22 | --group_by_modality_length True \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 50000 \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-5 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0.03 \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 1 \
32 | --tf32 True \
33 | --model_max_length 2048 \
34 | --gradient_checkpointing True \
35 | --dataloader_num_workers 4 \
36 | --lazy_preprocess True \
37 | --report_to wandb
38 |
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
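
A sanity-check sketch (not part of the repository) that loads the three DeepSpeed configs above and prints their ZeRO stage and optimizer-offload device; the "auto" placeholders are resolved by the HuggingFace Trainer's DeepSpeed integration at launch time, so they are left untouched here.

import json

for path in ("scripts/zero2.json", "scripts/zero3.json", "scripts/zero3_offload.json"):
    with open(path) as f:
        cfg = json.load(f)
    zero = cfg["zero_optimization"]
    offload = zero.get("offload_optimizer", {}).get("device", "none")
    print(f"{path}: ZeRO stage {zero['stage']}, optimizer offload -> {offload}")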