├── .gitignore
├── ACKNOWLEDGEMENTS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── LICENSE_MODEL
├── README.md
├── app
│   ├── Configuration
│   │   └── Build.xcconfig
│   ├── FastVLM App
│   │   ├── Assets.xcassets
│   │   │   ├── AccentColor.colorset
│   │   │   │   └── Contents.json
│   │   │   ├── AppIcon.appiconset
│   │   │   │   ├── Contents.json
│   │   │   │   ├── FastVLM - 150 Blue - Dark@2x.png
│   │   │   │   ├── FastVLM - 150 Blue - Light@2x.png
│   │   │   │   ├── FastVLM - 150 White - Tinted@2x.png
│   │   │   │   ├── FastVLM - MacOS - Dark@1x.png
│   │   │   │   └── FastVLM - MacOS - Dark@2x.png
│   │   │   └── Contents.json
│   │   ├── ContentView.swift
│   │   ├── FastVLM.entitlements
│   │   ├── FastVLMApp.swift
│   │   ├── FastVLMModel.swift
│   │   ├── Info.plist
│   │   ├── InfoView.swift
│   │   └── Preview Content
│   │       └── Preview Assets.xcassets
│   │           └── Contents.json
│   ├── FastVLM.xcodeproj
│   │   ├── project.pbxproj
│   │   └── xcshareddata
│   │       └── xcschemes
│   │           └── FastVLM App.xcscheme
│   ├── FastVLM
│   │   ├── FastVLM.h
│   │   ├── FastVLM.swift
│   │   └── MediaProcessingExtensions.swift
│   ├── README.md
│   ├── Video
│   │   ├── CameraController.swift
│   │   ├── CameraControlsView.swift
│   │   ├── CameraType.swift
│   │   ├── Video.h
│   │   └── VideoFrameView.swift
│   └── get_pretrained_mlx_model.sh
├── docs
│   ├── acc_vs_latency_qwen-2.png
│   ├── fastvlm-counting.gif
│   ├── fastvlm-emoji.gif
│   ├── fastvlm-flexible_prompts.png
│   └── fastvlm-handwriting.gif
├── get_models.sh
├── llava
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── builder.py
│   │   ├── consolidate.py
│   │   ├── language_model
│   │   │   ├── llava_llama.py
│   │   │   ├── llava_mistral.py
│   │   │   ├── llava_mpt.py
│   │   │   └── llava_qwen.py
│   │   ├── llava_arch.py
│   │   ├── make_delta.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   ├── clip_encoder.py
│   │   │   ├── mobileclip
│   │   │   │   ├── __init__.py
│   │   │   │   ├── configs
│   │   │   │   │   └── mobileclip_l.json
│   │   │   │   └── mci.py
│   │   │   └── mobileclip_encoder.py
│   │   ├── multimodal_projector
│   │   │   └── builder.py
│   │   └── utils.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── extreme_ironing.jpg
│   │   │   └── waterview.jpg
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   ├── register_worker.py
│   │   ├── sglang_worker.py
│   │   └── test_message.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llama_xformers_attn_monkey_patch.py
│   │   ├── llava_trainer.py
│   │   ├── train.py
│   │   ├── train_mem.py
│   │   ├── train_qwen.py
│   │   └── train_xformers.py
│   └── utils.py
├── model_export
│   ├── README.md
│   ├── export_vision_encoder.py
│   └── fastvlm_mlx-vlm.patch
├── predict.py
└── pyproject.toml
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # macOS
132 | **/.DS_Store
133 |
134 | # PyCharm project settings
135 | .idea/
136 |
137 | # Xcode
138 | *.xcworkspace
139 |
140 | # FastVLM models
141 | app/FastVLM/model
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
12 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
41 | -------------------------------------------------------------------------------
42 | SOFTWARE DISTRIBUTED WITH ML-FASTVLM:
43 |
44 | The ml-fastvlm software includes a number of subcomponents with separate
45 | copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
46 |
47 | The ml-fastvlm model weights copyright and license terms can be
48 | found in LICENSE_MODEL file.
49 | -------------------------------------------------------------------------------
50 |
--------------------------------------------------------------------------------
/LICENSE_MODEL:
--------------------------------------------------------------------------------
1 | Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is
2 | specifically developed and released by Apple Inc. ("Apple") for the sole purpose
3 | of scientific research of artificial intelligence and machine-learning
4 | technology. “Apple Machine Learning Research Model” means the model, including
5 | but not limited to algorithms, formulas, trained model weights, parameters,
6 | configurations, checkpoints, and any related materials (including
7 | documentation).
8 |
9 | This Apple Machine Learning Research Model is provided to You by
10 | Apple in consideration of your agreement to the following terms, and your use,
11 | modification, creation of Model Derivatives, and or redistribution of the Apple
12 | Machine Learning Research Model constitutes acceptance of this Agreement. If You
13 | do not agree with these terms, please do not use, modify, create Model
14 | Derivatives of, or distribute this Apple Machine Learning Research Model or
15 | Model Derivatives.
16 |
17 | * License Scope: In consideration of your agreement to abide by the following
18 | terms, and subject to these terms, Apple hereby grants you a personal,
19 | non-exclusive, worldwide, non-transferable, royalty-free, revocable, and
20 | limited license, to use, copy, modify, distribute, and create Model
21 | Derivatives (defined below) of the Apple Machine Learning Research Model
22 | exclusively for Research Purposes. You agree that any Model Derivatives You
23 | may create or that may be created for You will be limited to Research Purposes
24 | as well. “Research Purposes” means non-commercial scientific research and
25 | academic development activities, such as experimentation, analysis, testing
26 | conducted by You with the sole intent to advance scientific knowledge and
27 | research. “Research Purposes” does not include any commercial exploitation,
28 | product development or use in any commercial product or service.
29 |
30 | * Distribution of Apple Machine Learning Research Model and Model Derivatives:
31 | If you choose to redistribute Apple Machine Learning Research Model or its
32 | Model Derivatives, you must provide a copy of this Agreement to such third
33 | party, and ensure that the following attribution notice be provided: “Apple
34 | Machine Learning Research Model is licensed under the Apple Machine Learning
35 | Research Model License Agreement.” Additionally, all Model Derivatives must
36 | clearly be identified as such, including disclosure of modifications and
37 | changes made to the Apple Machine Learning Research Model. The name,
38 | trademarks, service marks or logos of Apple may not be used to endorse or
39 | promote Model Derivatives or the relationship between You and Apple. “Model
40 | Derivatives” means any models or any other artifacts created by modifications,
41 | improvements, adaptations, alterations to the architecture, algorithm or
42 | training processes of the Apple Machine Learning Research Model, or by any
43 | retraining, fine-tuning of the Apple Machine Learning Research Model.
44 |
45 | * No Other License: Except as expressly stated in this notice, no other rights
46 | or licenses, express or implied, are granted by Apple herein, including but
47 | not limited to any patent, trademark, and similar intellectual property rights
48 | worldwide that may be infringed by the Apple Machine Learning Research Model,
49 | the Model Derivatives or by other works in which the Apple Machine Learning
50 | Research Model may be incorporated.
51 |
52 | * Compliance with Laws: Your use of Apple Machine Learning Research Model must
53 | be in compliance with all applicable laws and regulations.
54 |
55 | * Term and Termination: The term of this Agreement will begin upon your
56 | acceptance of this Agreement or use of the Apple Machine Learning Research
57 | Model and will continue until terminated in accordance with the following
58 | terms. Apple may terminate this Agreement at any time if You are in breach of
59 | any term or condition of this Agreement. Upon termination of this Agreement,
60 | You must cease to use all Apple Machine Learning Research Models and Model
61 | Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will
62 | survive termination.
63 |
64 | * Disclaimer and Limitation of Liability: This Apple Machine Learning Research
65 | Model and any outputs generated by the Apple Machine Learning Research Model
66 | are provided on an “AS IS” basis. APPLE MAKES NO WARRANTIES, EXPRESS OR
67 | IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF
68 | NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE,
69 | REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY
70 | THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for
71 | determining the appropriateness of using or redistributing the Apple Machine
72 | Learning Research Model and any outputs of the Apple Machine Learning Research
73 | Model and assume any risks associated with Your use of the Apple Machine
74 | Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE
75 | LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
76 | IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF
77 | THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE
78 | LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT,
79 | TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS
80 | BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
81 |
82 | * Governing Law: This Agreement will be governed by and construed under the laws
83 | of the State of California without regard to its choice of law principles. The
84 | Convention on Contracts for the International Sale of Goods shall not apply to
85 | the Agreement except that the arbitration clause and any arbitration hereunder
86 | shall be governed by the Federal Arbitration Act, Chapters 1 and 2.
87 |
88 | Copyright (C) 2025 Apple Inc. All Rights Reserved.
89 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FastVLM: Efficient Vision Encoding for Vision Language Models
2 |
3 | This is the official repository of
4 | **[FastVLM: Efficient Vision Encoding for Vision Language Models](https://www.arxiv.org/abs/2412.13303). (CVPR 2025)**
5 |
6 | [//]: # ()
7 |
8 |
9 |
10 |
11 | ### Highlights
12 | * We introduce FastViTHD, a novel hybrid vision encoder designed to output fewer tokens and significantly reduce encoding time for high-resolution images.
13 | * Our smallest variant outperforms LLaVA-OneVision-0.5B with 85x faster Time-to-First-Token (TTFT) and 3.4x smaller vision encoder.
14 | * Our larger variants using Qwen2-7B LLM outperform recent works like Cambrian-1-8B while using a single image encoder with a 7.9x faster TTFT.
15 | * Demo iOS app to demonstrate the performance of our model on a mobile device.
16 |
17 |
18 |
19 |  |
20 |  |
21 |  |
22 |
23 |
24 |
25 | ## Getting Started
26 | We use the LLaVA codebase to train FastVLM variants. In order to train or finetune your own variants,
27 | please follow the instructions provided in the [LLaVA](https://github.com/haotian-liu/LLaVA) codebase.
28 | Below, we provide instructions for running inference with our models.
29 |
30 | ### Setup
31 | ```bash
32 | conda create -n fastvlm python=3.10
33 | conda activate fastvlm
34 | pip install -e .
35 | ```
36 |
37 | ### Model Zoo
38 | For detailed information on various evaluations, please refer to our [paper](https://www.arxiv.org/abs/2412.13303).
39 |
40 | | Model | Stage | Pytorch Checkpoint (url) |
41 | |:-------------|:-----:|:---------------------------------------------------------------------------------------------------------------:|
42 | | FastVLM-0.5B | 2 | [fastvlm_0.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage2.zip) |
43 | | | 3 | [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3.zip) |
44 | | FastVLM-1.5B | 2 | [fastvlm_1.5b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage2.zip) |
45 | | | 3 | [fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3.zip) |
46 | | FastVLM-7B | 2 | [fastvlm_7b_stage2](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage2.zip) |
47 | | | 3 | [fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3.zip) |
48 |
49 | To download all the pretrained checkpoints, run the command below (note that this might take some time depending on your connection, so it might be good to grab a ☕️ while you wait).
50 |
51 | ```bash
52 | bash get_models.sh # Files will be downloaded to `checkpoints` directory.
53 | ```
54 |
55 | ### Usage Example
56 | To run inference with a PyTorch checkpoint, follow the instructions below:
57 | ```bash
58 | python predict.py --model-path /path/to/checkpoint-dir \
59 | --image-file /path/to/image.png \
60 | --prompt "Describe the image."
61 | ```
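If you prefer to call the model from Python rather than through `predict.py`, the sketch below shows roughly what the script does using the LLaVA-style helpers bundled in this repo (`llava.model.builder.load_pretrained_model`, `llava.mm_utils`, `llava.conversation`). It is a minimal sketch under stated assumptions, not the official entry point: the `qwen_2` conversation template name, the checkpoint path, and the device/dtype choices are assumptions — adjust them to your setup and prefer `predict.py` for the supported path.

```python
# Minimal programmatic sketch of PyTorch inference (assumptions noted above).
import torch
from PIL import Image

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates
from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from llava.model.builder import load_pretrained_model

model_path = "checkpoints/llava-fastvithd_0.5b_stage3"   # hypothetical local path
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path))

# Build the prompt with the image placeholder token.
conv = conv_templates["qwen_2"].copy()                    # template name: assumption
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe the image.")
conv.append_message(conv.roles[1], None)

# Preprocess the image and tokenize the prompt.
image = Image.open("image.png").convert("RGB")
image_tensor = process_images([image], image_processor, model.config).to(
    model.device, dtype=torch.float16)
input_ids = tokenizer_image_token(
    conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
).unsqueeze(0).to(model.device)

# Generate and decode.
with torch.inference_mode():
    output_ids = model.generate(
        input_ids, images=image_tensor, image_sizes=[image.size], max_new_tokens=256)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip())
```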
62 |
63 | ### Inference on Apple Silicon
64 | To run inference on Apple Silicon, PyTorch checkpoints have to be exported to a format
65 | suitable for running on Apple Silicon. Detailed instructions and code can be found in the [`model_export`](model_export/) subfolder.
66 | Please see the README there for more details.
67 |
68 | For convenience, we provide 3 models that are already in an Apple Silicon compatible format: [fastvlm_0.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3_llm.fp16.zip),
69 | [fastvlm_1.5b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3_llm.int8.zip),
70 | [fastvlm_7b_stage3](https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3_llm.int4.zip).
71 | We encourage developers to export the model of their choice with the appropriate quantization levels following
72 | the instructions in [`model_export`](model_export/).
73 |
74 | ### Inference on Apple Devices
75 | To run inference on Apple devices like iPhone, iPad or Mac, see [`app`](app/) subfolder for more details.
76 |
77 | ## Citation
78 | If you found this code useful, please cite the following paper:
79 | ```
80 | @InProceedings{fastvlm2025,
81 | author = {Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel, Hadi Pouransari},
82 | title = {FastVLM: Efficient Vision Encoding for Vision Language Models},
83 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
84 | month = {June},
85 | year = {2025},
86 | }
87 | ```
88 |
89 | ## Acknowledgements
90 | Our codebase is built using multiple opensource contributions, please see [ACKNOWLEDGEMENTS](ACKNOWLEDGEMENTS) for more details.
91 |
92 | ## License
93 | Please check out the repository [LICENSE](LICENSE) before using the provided code and
94 | [LICENSE_MODEL](LICENSE_MODEL) for the released models.
95 |
--------------------------------------------------------------------------------
/app/Configuration/Build.xcconfig:
--------------------------------------------------------------------------------
1 | // The `DISAMBIGUATOR` configuration is to make it easier to build
2 | // and run a sample code project. Once you set your project's development team,
3 | // you'll have a unique bundle identifier. This is because the bundle identifier
4 | // is derived based on the 'DISAMBIGUATOR' value. Do not use this
5 | // approach in your own projects—it's only useful for example projects because
6 | // they are frequently downloaded and don't have a development team set.
7 | DISAMBIGUATOR=${DEVELOPMENT_TEAM}
8 |
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "colors" : [
3 | {
4 | "idiom" : "universal"
5 | }
6 | ],
7 | "info" : {
8 | "author" : "xcode",
9 | "version" : 1
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "images" : [
3 | {
4 | "filename" : "FastVLM - 150 Blue - Light@2x.png",
5 | "idiom" : "universal",
6 | "platform" : "ios",
7 | "size" : "1024x1024"
8 | },
9 | {
10 | "appearances" : [
11 | {
12 | "appearance" : "luminosity",
13 | "value" : "dark"
14 | }
15 | ],
16 | "filename" : "FastVLM - 150 Blue - Dark@2x.png",
17 | "idiom" : "universal",
18 | "platform" : "ios",
19 | "size" : "1024x1024"
20 | },
21 | {
22 | "appearances" : [
23 | {
24 | "appearance" : "luminosity",
25 | "value" : "tinted"
26 | }
27 | ],
28 | "filename" : "FastVLM - 150 White - Tinted@2x.png",
29 | "idiom" : "universal",
30 | "platform" : "ios",
31 | "size" : "1024x1024"
32 | },
33 | {
34 | "idiom" : "mac",
35 | "scale" : "1x",
36 | "size" : "16x16"
37 | },
38 | {
39 | "idiom" : "mac",
40 | "scale" : "2x",
41 | "size" : "16x16"
42 | },
43 | {
44 | "idiom" : "mac",
45 | "scale" : "1x",
46 | "size" : "32x32"
47 | },
48 | {
49 | "idiom" : "mac",
50 | "scale" : "2x",
51 | "size" : "32x32"
52 | },
53 | {
54 | "idiom" : "mac",
55 | "scale" : "1x",
56 | "size" : "128x128"
57 | },
58 | {
59 | "idiom" : "mac",
60 | "scale" : "2x",
61 | "size" : "128x128"
62 | },
63 | {
64 | "idiom" : "mac",
65 | "scale" : "1x",
66 | "size" : "256x256"
67 | },
68 | {
69 | "idiom" : "mac",
70 | "scale" : "2x",
71 | "size" : "256x256"
72 | },
73 | {
74 | "filename" : "FastVLM - MacOS - Dark@1x.png",
75 | "idiom" : "mac",
76 | "scale" : "1x",
77 | "size" : "512x512"
78 | },
79 | {
80 | "filename" : "FastVLM - MacOS - Dark@2x.png",
81 | "idiom" : "mac",
82 | "scale" : "2x",
83 | "size" : "512x512"
84 | }
85 | ],
86 | "info" : {
87 | "author" : "xcode",
88 | "version" : 1
89 | }
90 | }
91 |
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Dark@2x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Dark@2x.png
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Light@2x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 Blue - Light@2x.png
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 White - Tinted@2x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - 150 White - Tinted@2x.png
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@1x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@1x.png
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@2x.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/app/FastVLM App/Assets.xcassets/AppIcon.appiconset/FastVLM - MacOS - Dark@2x.png
--------------------------------------------------------------------------------
/app/FastVLM App/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/app/FastVLM App/FastVLM.entitlements:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | com.apple.developer.kernel.increased-memory-limit
6 |
7 | com.apple.security.app-sandbox
8 |
9 | com.apple.security.device.camera
10 |
11 | com.apple.security.files.user-selected.read-only
12 |
13 | com.apple.security.network.client
14 |
15 |
16 |
17 |
--------------------------------------------------------------------------------
/app/FastVLM App/FastVLMApp.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import SwiftUI
7 |
8 | @main
9 | struct FastVLMApp: App {
10 | var body: some Scene {
11 | WindowGroup {
12 | ContentView()
13 | }
14 | }
15 | }
16 |
--------------------------------------------------------------------------------
/app/FastVLM App/FastVLMModel.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import CoreImage
7 | import FastVLM
8 | import Foundation
9 | import MLX
10 | import MLXLMCommon
11 | import MLXRandom
12 | import MLXVLM
13 |
14 | @Observable
15 | @MainActor
16 | class FastVLMModel {
17 |
18 | public var running = false
19 | public var modelInfo = ""
20 | public var output = ""
21 | public var promptTime: String = ""
22 |
23 | enum LoadState {
24 | case idle
25 | case loaded(ModelContainer)
26 | }
27 |
28 | private let modelConfiguration = FastVLM.modelConfiguration
29 |
30 | /// parameters controlling the output
31 | let generateParameters = GenerateParameters(temperature: 0.0)
32 | let maxTokens = 240
33 |
34 | /// update the display every N tokens -- 4 looks like it updates continuously
35 | /// and is low overhead. observed ~15% reduction in tokens/s when updating
36 | /// on every token
37 | let displayEveryNTokens = 4
38 |
39 | private var loadState = LoadState.idle
40 | private var currentTask: Task<Void, Never>?
41 |
42 | enum EvaluationState: String, CaseIterable {
43 | case idle = "Idle"
44 | case processingPrompt = "Processing Prompt"
45 | case generatingResponse = "Generating Response"
46 | }
47 |
48 | public var evaluationState = EvaluationState.idle
49 |
50 | public init() {
51 | FastVLM.register(modelFactory: VLMModelFactory.shared)
52 | }
53 |
54 | private func _load() async throws -> ModelContainer {
55 | switch loadState {
56 | case .idle:
57 | // limit the buffer cache
58 | MLX.GPU.set(cacheLimit: 20 * 1024 * 1024)
59 |
60 | let modelContainer = try await VLMModelFactory.shared.loadContainer(
61 | configuration: modelConfiguration
62 | ) {
63 | [modelConfiguration] progress in
64 | Task { @MainActor in
65 | self.modelInfo =
66 | "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%"
67 | }
68 | }
69 | self.modelInfo = "Loaded"
70 | loadState = .loaded(modelContainer)
71 | return modelContainer
72 |
73 | case .loaded(let modelContainer):
74 | return modelContainer
75 | }
76 | }
77 |
78 | public func load() async {
79 | do {
80 | _ = try await _load()
81 | } catch {
82 | self.modelInfo = "Error loading model: \(error)"
83 | }
84 | }
85 |
86 | public func generate(_ userInput: UserInput) async -> Task<Void, Never> {
87 | if let currentTask, running {
88 | return currentTask
89 | }
90 |
91 | running = true
92 |
93 | // Cancel any existing task
94 | currentTask?.cancel()
95 |
96 | // Create new task and store reference
97 | let task = Task {
98 | do {
99 | let modelContainer = try await _load()
100 |
101 | // each time you generate you will get something new
102 | MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000))
103 |
104 | // Check if task was cancelled
105 | if Task.isCancelled { return }
106 |
107 | let result = try await modelContainer.perform { context in
108 | // Measure the time it takes to prepare the input
109 |
110 | Task { @MainActor in
111 | evaluationState = .processingPrompt
112 | }
113 |
114 | let llmStart = Date()
115 | let input = try await context.processor.prepare(input: userInput)
116 |
117 | var seenFirstToken = false
118 |
119 | // FastVLM generates the output
120 | let result = try MLXLMCommon.generate(
121 | input: input, parameters: generateParameters, context: context
122 | ) { tokens in
123 | // Check if task was cancelled
124 | if Task.isCancelled {
125 | return .stop
126 | }
127 |
128 | if !seenFirstToken {
129 | seenFirstToken = true
130 |
131 | // produced first token, update the time to first token,
132 | // the processing state and start displaying the text
133 | let llmDuration = Date().timeIntervalSince(llmStart)
134 | let text = context.tokenizer.decode(tokens: tokens)
135 | Task { @MainActor in
136 | evaluationState = .generatingResponse
137 | self.output = text
138 | self.promptTime = "\(Int(llmDuration * 1000)) ms"
139 | }
140 | }
141 |
142 | // Show the text in the view as it generates
143 | if tokens.count % displayEveryNTokens == 0 {
144 | let text = context.tokenizer.decode(tokens: tokens)
145 | Task { @MainActor in
146 | self.output = text
147 | }
148 | }
149 |
150 | if tokens.count >= maxTokens {
151 | return .stop
152 | } else {
153 | return .more
154 | }
155 | }
156 |
157 | // Return the duration of the LLM and the result
158 | return result
159 | }
160 |
161 | // Check if task was cancelled before updating UI
162 | if !Task.isCancelled {
163 | self.output = result.output
164 | }
165 |
166 | } catch {
167 | if !Task.isCancelled {
168 | output = "Failed: \(error)"
169 | }
170 | }
171 |
172 | if evaluationState == .generatingResponse {
173 | evaluationState = .idle
174 | }
175 |
176 | running = false
177 | }
178 |
179 | currentTask = task
180 | return task
181 | }
182 |
183 | public func cancel() {
184 | currentTask?.cancel()
185 | currentTask = nil
186 | running = false
187 | output = ""
188 | promptTime = ""
189 | }
190 | }
191 |
--------------------------------------------------------------------------------
/app/FastVLM App/Info.plist:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/app/FastVLM App/InfoView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 | import SwiftUI
8 |
9 | struct InfoView: View {
10 | @Environment(\.dismiss) var dismiss
11 |
12 | let paragraph1 = "**FastVLM¹** is a new family of Vision-Language models that makes use of **FastViTHD**, a hierarchical hybrid vision encoder that produces a small number of high-quality tokens at low latencies, resulting in significantly faster time-to-first-token (TTFT)."
13 | let paragraph2 = "This app showcases the **FastVLM** model in action, allowing users to freely customize the prompt. FastVLM utilizes Qwen2-Instruct LLMs without additional safety tuning, so please exercise caution when modifying the prompt."
14 | let footer = "1. **FastVLM: Efficient Vision Encoding for Vision Language Models.** (CVPR 2025) Pavan Kumar Anasosalu Vasu, Fartash Faghri, Chun-Liang Li, Cem Koc, Nate True, Albert Antony, Gokul Santhanam, James Gabriel, Peter Grasch, Oncel Tuzel, Hadi Pouransari"
15 |
16 | var body: some View {
17 | NavigationStack {
18 | VStack(alignment: .leading, spacing: 20.0) {
19 | // I'm not going to lie, this doesn't make sense...
20 | // Wrapping `String`s with `.init()` turns them into `LocalizedStringKey`s
21 | // which gives us all of the fun Markdown formatting while retaining the
22 | // ability to use `String` variables. ¯\_(ツ)_/¯
23 | Text("\(.init(paragraph1))\n\n\(.init(paragraph2))\n\n")
24 | .font(.body)
25 |
26 | Spacer()
27 |
28 | Text(.init(footer))
29 | .font(.caption)
30 | .foregroundStyle(.secondary)
31 | }
32 | .padding()
33 | .frame(maxWidth: .infinity, maxHeight: .infinity, alignment: .top)
34 | .textSelection(.enabled)
35 | .navigationTitle("Information")
36 | #if os(iOS)
37 | .navigationBarTitleDisplayMode(.inline)
38 | #endif
39 | .toolbar {
40 | #if os(iOS)
41 | ToolbarItem(placement: .navigationBarLeading) {
42 | Button {
43 | dismiss()
44 | } label: {
45 | Image(systemName: "xmark.circle")
46 | .resizable()
47 | .frame(width: 25, height: 25)
48 | .foregroundStyle(.secondary)
49 | }
50 | .buttonStyle(.plain)
51 | }
52 | #elseif os(macOS)
53 | ToolbarItem(placement: .cancellationAction) {
54 | Button("Done") {
55 | dismiss()
56 | }
57 | .buttonStyle(.bordered)
58 | }
59 | #endif
60 | }
61 | }
62 | }
63 | }
64 |
65 | #Preview {
66 | InfoView()
67 | }
68 |
--------------------------------------------------------------------------------
/app/FastVLM App/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 | "info" : {
3 | "author" : "xcode",
4 | "version" : 1
5 | }
6 | }
7 |
--------------------------------------------------------------------------------
/app/FastVLM.xcodeproj/xcshareddata/xcschemes/FastVLM App.xcscheme:
--------------------------------------------------------------------------------
1 |
2 |
5 |
9 |
10 |
16 |
22 |
23 |
24 |
25 |
26 |
32 |
33 |
43 |
45 |
51 |
52 |
53 |
54 |
60 |
62 |
68 |
69 |
70 |
71 |
73 |
74 |
77 |
78 |
79 |
--------------------------------------------------------------------------------
/app/FastVLM/FastVLM.h:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | #ifndef FastVLM_h
7 | #define FastVLM_h
8 |
9 |
10 | #endif /* FastVLM_h */
11 |
--------------------------------------------------------------------------------
/app/FastVLM/MediaProcessingExtensions.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Accelerate
7 | import CoreImage
8 | import MLX
9 | import MLXLMCommon
10 | import MLXVLM
11 |
12 | /// Additions to MediaProcessing -- not currently present in mlx-libraries
13 | enum MediaProcessingExtensions {
14 |
15 | // this function is not exported in current mlx-swift-examples -- local copy until it is exposed
16 | // properly
17 | public static func apply(_ image: CIImage, processing: UserInput.Processing?) -> CIImage {
18 | var image = image
19 |
20 | if let resize = processing?.resize {
21 | let scale = MediaProcessing.bestFitScale(image.extent.size, in: resize)
22 | image = image.transformed(by: CGAffineTransform(scaleX: scale, y: scale))
23 | }
24 |
25 | return image
26 | }
27 |
28 | public static func rectSmallerOrEqual(_ extent: CGRect, size: CGSize) -> Bool {
29 | return extent.width <= size.width && extent.height <= size.height
30 | }
31 |
32 | public static func centerCrop(_ extent: CGRect, size: CGSize) -> CGRect {
33 | let targetWidth = min(extent.width, size.width)
34 | let targetHeight = min(extent.height, size.height)
35 |
36 | return CGRect(
37 | x: (extent.maxX - targetWidth) / 2,
38 | y: (extent.maxY - targetHeight) / 2,
39 | width: targetWidth, height: targetHeight
40 | )
41 | }
42 |
43 | public static func centerCrop(_ image: CIImage, size: CGSize) -> CIImage {
44 | let extent = image.extent
45 | if rectSmallerOrEqual(extent, size: size) {
46 | return image
47 | }
48 |
49 | let crop = centerCrop(extent, size: size)
50 | return
51 | image
52 | .cropped(to: crop)
53 | .transformed(by: CGAffineTransform(translationX: -crop.minX, y: -crop.minY))
54 | }
55 |
56 | public static func fitIn(_ size: CGSize, shortestEdge: Int) -> CGSize {
57 | let floatShortestEdge = CGFloat(shortestEdge)
58 |
59 | let (short, long) =
60 | size.width <= size.height ? (size.width, size.height) : (size.height, size.width)
61 | let newShort = floatShortestEdge
62 | let newLong = floatShortestEdge * long / short
63 |
64 | return size.width <= size.height
65 | ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort)
66 | }
67 |
68 | public static func fitIn(_ size: CGSize, longestEdge: Int) -> CGSize {
69 | let floatLongestEdge = CGFloat(longestEdge)
70 |
71 | var (newShort, newLong) =
72 | size.width <= size.height ? (size.width, size.height) : (size.height, size.width)
73 |
74 | if newLong > floatLongestEdge {
75 | newLong = floatLongestEdge
76 | newShort = floatLongestEdge * newShort / newLong
77 | }
78 |
79 | return size.width <= size.height
80 | ? CGSize(width: newShort, height: newLong) : CGSize(width: newLong, height: newShort)
81 | }
82 |
83 | // version of function from https://github.com/ml-explore/mlx-swift-examples/pull/222
84 | public static func resampleBicubic(_ image: CIImage, to size: CGSize) -> CIImage {
85 | // Create a bicubic scale filter
86 |
87 | let yScale = size.height / image.extent.height
88 | let xScale = size.width / image.extent.width
89 |
90 | let filter = CIFilter.bicubicScaleTransform()
91 | filter.inputImage = image
92 | filter.scale = Float(yScale)
93 | filter.aspectRatio = Float(xScale / yScale)
94 | let scaledImage = filter.outputImage!
95 |
96 | // Create a rect with the exact dimensions we want
97 | let exactRect = CGRect(
98 | x: 0,
99 | y: 0,
100 | width: size.width,
101 | height: size.height
102 | )
103 | // Crop to ensure exact dimensions
104 | return scaledImage.cropped(to: exactRect)
105 | }
106 |
107 | static let context = CIContext()
108 |
109 | /// Convert the CIImage into a planar 3 channel MLXArray `[1, C, H, W]`.
110 | ///
111 | /// This physically moves the channels into a planar configuration -- this is
112 | /// required for feeding into the CoreML model and is faster to use
113 | /// dedicated functions than transforming into contiguous memory
114 | /// on readout.
115 | static public func asPlanarMLXArray(_ image: CIImage, colorSpace: CGColorSpace? = nil)
116 | -> MLXArray
117 | {
118 | let size = image.extent.size
119 | let w = Int(size.width.rounded())
120 | let h = Int(size.height.rounded())
121 |
122 | // probably not strictly necessary, but this is what happens in
123 | // e.g. image_processing_siglip in transformers (float32)
124 | let format = CIFormat.RGBAf
125 | let componentsPerPixel = 4
126 | let bytesPerComponent: Int = MemoryLayout<Float32>.size
127 | let bytesPerPixel = componentsPerPixel * bytesPerComponent
128 | let bytesPerRow = w * bytesPerPixel
129 |
130 | var data = Data(count: w * h * bytesPerPixel)
131 | var planarData = Data(count: 3 * w * h * bytesPerComponent)
132 | data.withUnsafeMutableBytes { ptr in
133 | context.render(
134 | image, toBitmap: ptr.baseAddress!, rowBytes: bytesPerRow, bounds: image.extent,
135 | format: format, colorSpace: colorSpace)
136 | context.clearCaches()
137 |
138 | let vh = vImagePixelCount(h)
139 | let vw = vImagePixelCount(w)
140 |
141 | // convert from RGBAf -> RGBf in place
142 | let rgbBytesPerRow = w * 3 * bytesPerComponent
143 | var rgbaSrc = vImage_Buffer(
144 | data: ptr.baseAddress!, height: vh, width: vw, rowBytes: bytesPerRow)
145 | var rgbDest = vImage_Buffer(
146 | data: ptr.baseAddress!, height: vh, width: vw, rowBytes: rgbBytesPerRow)
147 |
148 | vImageConvert_RGBAFFFFtoRGBFFF(&rgbaSrc, &rgbDest, vImage_Flags(kvImageNoFlags))
149 |
150 | // and convert to planar data in a second buffer
151 | planarData.withUnsafeMutableBytes { planarPtr in
152 | let planeBytesPerRow = w * bytesPerComponent
153 |
154 | var rDest = vImage_Buffer(
155 | data: planarPtr.baseAddress!.advanced(by: 0 * planeBytesPerRow * h), height: vh,
156 | width: vw, rowBytes: planeBytesPerRow)
157 | var gDest = vImage_Buffer(
158 | data: planarPtr.baseAddress!.advanced(by: 1 * planeBytesPerRow * h), height: vh,
159 | width: vw, rowBytes: planeBytesPerRow)
160 | var bDest = vImage_Buffer(
161 | data: planarPtr.baseAddress!.advanced(by: 2 * planeBytesPerRow * h), height: vh,
162 | width: vw, rowBytes: planeBytesPerRow)
163 |
164 | vImageConvert_RGBFFFtoPlanarF(
165 | &rgbDest, &rDest, &gDest, &bDest, vImage_Flags(kvImageNoFlags))
166 | }
167 | }
168 |
169 | return MLXArray(planarData, [1, 3, h, w], type: Float32.self)
170 | }
171 |
172 | }
173 |
--------------------------------------------------------------------------------
/app/README.md:
--------------------------------------------------------------------------------
1 | # FastVLM
2 |
3 | Demonstrates the performance of **FastVLM** models for on-device visual question answering.
4 |
5 |
6 |
7 |  |
8 |  |
9 |  |
10 |
11 |
12 |
13 | ## Features
14 |
15 | - FastVLM runs on iOS (18.2+) and macOS (15.2+).
16 | - View Time-To-First-Token (TTFT) with every inference.
17 | - All predictions are processed privately and securely using on-device models.
18 |
19 | ### Flexible Prompting
20 |
21 |
22 |
23 | The app includes a set of built-in prompts to help you get started quickly. Tap the **Prompts** button in the top-right corner to explore them. Selecting a prompt will immediately update the active input. To create new prompts or edit existing ones, choose **Customize…** from the **Prompts** menu.
24 |
25 | ## Pretrained Model Options
26 |
27 | There are 3 pretrained sizes of FastVLM to choose from:
28 |
29 | - **FastVLM 0.5B**: Small and fast - great for mobile devices where speed matters.
30 | - **FastVLM 1.5B**: Well balanced - great for larger devices where both speed and accuracy matter.
31 | - **FastVLM 7B**: Fast and accurate - ideal for situations where accuracy matters more than speed.
32 |
33 | To download any of the FastVLM models listed above, use the [get_pretrained_mlx_model.sh](get_pretrained_mlx_model.sh) script. The script downloads the model from the web and places it in the appropriate location. Once a model has been downloaded using the steps below, no additional steps are needed to build the app in Xcode.
34 |
35 | To explore how the other models work for your use-case, simply re-run the `get_pretrained_mlx_model.sh` with the new model selected, follow the prompts, and rebuild your app in Xcode.
36 |
37 | ### Download Instructions
38 |
39 | 1. Make the script executable
40 |
41 | ```shell
42 | chmod +x app/get_pretrained_mlx_model.sh
43 | ```
44 |
45 | 2. Download FastVLM
46 |
47 | ```shell
48 | app/get_pretrained_mlx_model.sh --model 0.5b --dest app/FastVLM/model
49 | ```
50 |
51 | 3. Open the app in Xcode, Build, and Run.
52 |
53 | ### Custom Model
54 |
55 | In addition to the pretrained sizes of FastVLM, you can further quantize or fine-tune FastVLM to best fit your needs. To learn more, check out our documentation on how to [`export the model`](../model_export#export-vlm).
56 | Please clear the existing model in `app/FastVLM/model` before downloading or copying a new model.
57 |
--------------------------------------------------------------------------------
/app/Video/CameraController.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import AVFoundation
7 | import CoreImage
8 |
9 | #if os(iOS)
10 | import UIKit
11 | #endif
12 |
13 | @Observable
14 | public class CameraController: NSObject {
15 |
16 | private var framesContinuation: AsyncStream<CMSampleBuffer>.Continuation?
17 |
18 | public var backCamera = true {
19 | didSet {
20 | stop()
21 | start()
22 | }
23 | }
24 |
25 | public var devices = [AVCaptureDevice]()
26 |
27 | public var device: AVCaptureDevice = AVCaptureDevice.default(for: .video)! {
28 | didSet {
29 | stop()
30 | start()
31 | }
32 | }
33 |
34 | private var permissionGranted = true
35 | private var captureSession: AVCaptureSession?
36 | private let sessionQueue = DispatchQueue(label: "sessionQueue")
37 | @objc dynamic private var rotationCoordinator : AVCaptureDevice.RotationCoordinator?
38 | private var rotationObservation: NSKeyValueObservation?
39 |
40 | public func attach(continuation: AsyncStream<CMSampleBuffer>.Continuation) {
41 | sessionQueue.async {
42 | self.framesContinuation = continuation
43 | }
44 | }
45 |
46 | public func detatch() {
47 | sessionQueue.async {
48 | self.framesContinuation = nil
49 | }
50 | }
51 |
52 | public func stop() {
53 | sessionQueue.sync { [self] in
54 | captureSession?.stopRunning()
55 | captureSession = nil
56 | }
57 |
58 | }
59 |
60 | public func start() {
61 | sessionQueue.async { [self] in
62 | let captureSession = AVCaptureSession()
63 | self.captureSession = captureSession
64 |
65 | self.checkPermission()
66 | self.setupCaptureSession(position: backCamera ? .back : .front)
67 | captureSession.startRunning()
68 | }
69 | }
70 |
71 | #if os(iOS)
72 | private func setOrientation(_ orientation: UIDeviceOrientation) {
73 | guard let captureSession else { return }
74 |
75 | let angle: Double?
76 | switch orientation {
77 | case .unknown, .faceDown:
78 | angle = nil
79 | case .portrait, .faceUp:
80 | angle = 90
81 | case .portraitUpsideDown:
82 | angle = 270
83 | case .landscapeLeft:
84 | angle = 0
85 | case .landscapeRight:
86 | angle = 180
87 | @unknown default:
88 | angle = nil
89 | }
90 |
91 | if let angle {
92 | for output in captureSession.outputs {
93 | output.connection(with: .video)?.videoRotationAngle = angle
94 | }
95 | }
96 | }
97 |
98 | private func updateRotation(rotation : CGFloat) {
99 | guard let captureSession else { return }
100 | for output in captureSession.outputs {
101 | output.connection(with: .video)?.videoRotationAngle = rotation
102 | }
103 | }
104 | #endif
105 |
106 | func checkPermission() {
107 | switch AVCaptureDevice.authorizationStatus(for: .video) {
108 | case .authorized:
109 | // The user has previously granted access to the camera.
110 | self.permissionGranted = true
111 |
112 | case .notDetermined:
113 | // The user has not yet been asked for camera access.
114 | self.requestPermission()
115 |
116 | // Combine the two other cases into the default case
117 | default:
118 | self.permissionGranted = false
119 | }
120 | }
121 |
122 | func requestPermission() {
123 | // Strong reference not a problem here but might become one in the future.
124 | AVCaptureDevice.requestAccess(for: .video) { [unowned self] granted in
125 | self.permissionGranted = granted
126 | }
127 | }
128 |
129 | func setupCaptureSession(position: AVCaptureDevice.Position) {
130 | guard let captureSession else { return }
131 |
132 | let videoOutput = AVCaptureVideoDataOutput()
133 |
134 | guard permissionGranted else {
135 | print("No permission for camera")
136 | return
137 | }
138 |
139 | let deviceTypes: [AVCaptureDevice.DeviceType]
140 | #if os(iOS)
141 | deviceTypes = [.builtInDualCamera, .builtInWideAngleCamera]
142 | #else
143 | deviceTypes = [.external, .continuityCamera, .builtInWideAngleCamera]
144 | #endif
145 |
146 | let videoDeviceDiscoverySession = AVCaptureDevice.DiscoverySession(
147 | deviceTypes: deviceTypes,
148 | mediaType: .video,
149 | position: position)
150 |
151 | let videoDevice: AVCaptureDevice?
152 | if videoDeviceDiscoverySession.devices.contains(self.device) {
153 | videoDevice = self.device
154 | } else {
155 | videoDevice = videoDeviceDiscoverySession.devices.first
156 | }
157 |
158 | if devices.isEmpty {
159 | self.devices = videoDeviceDiscoverySession.devices
160 | }
161 |
162 | guard
163 | let videoDevice
164 | else {
165 | print("Unable to find video device")
166 | return
167 | }
168 | guard let videoDeviceInput = try? AVCaptureDeviceInput(device: videoDevice) else {
169 | print("Unable to create AVCaptureDeviceInput")
170 | return
171 | }
172 | guard captureSession.canAddInput(videoDeviceInput) else {
173 | print("Unable to add input")
174 | return
175 | }
176 | captureSession.addInput(videoDeviceInput)
177 |
178 | videoOutput.setSampleBufferDelegate(self, queue: DispatchQueue(label: "sampleBufferQueue"))
179 | captureSession.addOutput(videoOutput)
180 | captureSession.sessionPreset = AVCaptureSession.Preset.hd1920x1080
181 |
182 | #if os(iOS)
183 | rotationCoordinator = AVCaptureDevice.RotationCoordinator(device: videoDevice, previewLayer: nil)
184 | rotationObservation = observe(\.rotationCoordinator!.videoRotationAngleForHorizonLevelCapture, options: [.initial, .new]) { [weak self] _, change in
185 | if let nv = change.newValue {
186 | self?.updateRotation(rotation: nv)
187 | }
188 | }
189 | #endif
190 | }
191 | }
192 |
193 | extension CameraController: AVCaptureVideoDataOutputSampleBufferDelegate {
194 | public func captureOutput(
195 | _ output: AVCaptureOutput, didOutput sampleBuffer: CMSampleBuffer,
196 | from connection: AVCaptureConnection
197 | ) {
198 | if sampleBuffer.isValid && sampleBuffer.imageBuffer != nil {
199 | framesContinuation?.yield(sampleBuffer)
200 | }
201 | }
202 | }
203 |
--------------------------------------------------------------------------------
/app/Video/CameraControlsView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import AVFoundation
7 | import SwiftUI
8 |
9 | public struct CameraControlsView: View {
10 |
11 | @Binding public var backCamera: Bool
12 | @Binding public var device: AVCaptureDevice
13 | @Binding public var devices: [AVCaptureDevice]
14 |
15 | public init(
16 | backCamera: Binding<Bool>,
17 | device: Binding<AVCaptureDevice>,
18 | devices: Binding<[AVCaptureDevice]>
19 | ) {
20 | self._backCamera = backCamera
21 | self._device = device
22 | self._devices = devices
23 | }
24 |
25 | public var body: some View {
26 | Button {
27 | backCamera.toggle()
28 | } label: {
29 | RoundedRectangle(cornerRadius: 8.0)
30 | .fill(.regularMaterial)
31 | .frame(width: 32.0, height: 32.0)
32 | .overlay(alignment: .center) {
33 | // Switch cameras image
34 | Image(systemName: "arrow.triangle.2.circlepath.camera.fill")
35 | .foregroundStyle(.primary)
36 | .padding(6.0)
37 | }
38 | }
39 | }
40 | }
41 |
--------------------------------------------------------------------------------
/app/Video/CameraType.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import Foundation
7 |
8 | public enum CameraType: String, CaseIterable {
9 | case continuous
10 | case single
11 | }
12 |
--------------------------------------------------------------------------------
/app/Video/Video.h:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | #import <Foundation/Foundation.h>
7 |
8 | //! Project version number for Video.
9 | FOUNDATION_EXPORT double VideoVersionNumber;
10 |
11 | //! Project version string for Video.
12 | FOUNDATION_EXPORT const unsigned char VideoVersionString[];
13 |
--------------------------------------------------------------------------------
/app/Video/VideoFrameView.swift:
--------------------------------------------------------------------------------
1 | //
2 | // For licensing see accompanying LICENSE file.
3 | // Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | //
5 |
6 | import AVFoundation
7 | import CoreImage
8 | import Foundation
9 | import SwiftUI
10 |
11 | /// Displays a stream of video frames
12 | public struct VideoFrameView: View {
13 | @Environment(\.colorScheme) private var colorScheme
14 |
15 | public let frames: AsyncStream<CVImageBuffer>
16 | public let cameraType: CameraType
17 | public let action: ((CVImageBuffer) -> Void)?
18 |
19 | @State private var hold: Bool = false
20 | @State private var videoFrame: CVImageBuffer?
21 |
22 | private var backgroundColor: Color {
23 | #if os(iOS)
24 | return Color(.secondarySystemBackground)
25 | #elseif os(macOS)
26 | return Color(.secondarySystemFill)
27 | #else
28 | // When in doubt, fall back to these captured values, which approximate iOS's secondarySystemBackground
29 | if colorScheme == .dark {
30 | return Color(red: 0.11, green: 0.11, blue: 0.12)
31 | } else {
32 | return Color(red: 0.95, green: 0.95, blue: 0.97)
33 | }
34 | #endif
35 | }
36 |
37 | public init(
38 | frames: AsyncStream<CVImageBuffer>,
39 | cameraType: CameraType,
40 | action: ((CVImageBuffer) -> Void)?
41 | ) {
42 | self.frames = frames
43 | self.cameraType = cameraType
44 | self.action = action
45 | }
46 |
47 | public var body: some View {
48 | Group {
49 | if let videoFrame {
50 | _ImageView(image: videoFrame)
51 | .overlay(alignment: .bottom) {
52 | if cameraType == .single {
53 | Button {
54 | tap()
55 | } label: {
56 | if hold {
57 | Label("Resume", systemImage: "play.fill")
58 | } else {
59 | Label("Capture Photo", systemImage: "camera.fill")
60 | }
61 | }
62 | .clipShape(.capsule)
63 | .buttonStyle(.borderedProminent)
64 | .tint(hold ? .gray : .accentColor)
65 | .foregroundColor(.white)
66 | .padding()
67 | }
68 | }
69 | } else {
70 | // spinner before the camera comes up
71 | ProgressView()
72 | .controlSize(.large)
73 | }
74 | }
75 | // This ensures that we take up the full 4/3 aspect ratio
76 | // even if we don't have an image to display
77 | .frame(maxWidth: .infinity, maxHeight: .infinity)
78 | .background(backgroundColor)
79 | .clipShape(RoundedRectangle(cornerRadius: 10.0))
80 | .task {
81 | // feed frames to the _ImageView
82 | if Task.isCancelled {
83 | return
84 | }
85 | for await frame in frames {
86 | if !hold {
87 | videoFrame = frame
88 | }
89 | }
90 | }
91 | .onChange(of: cameraType) { _, newType in
92 | // No matter what, when the user switches to .continuous,
93 | // we need to continue showing updated frames
94 | if newType == .continuous {
95 | hold = false
96 | }
97 | }
98 | }
99 |
100 | private func tap() {
101 | if hold {
102 | // resume
103 | hold = false
104 | } else if let videoFrame {
105 | hold = true
106 | if let action {
107 | action(videoFrame)
108 | }
109 | }
110 | }
111 | }
112 |
113 | #if os(iOS)
114 | /// Internal view to display a CVImageBuffer
115 | private struct _ImageView: UIViewRepresentable {
116 |
117 | let image: Any
118 | var gravity = CALayerContentsGravity.resizeAspectFill
119 |
120 | func makeUIView(context: Context) -> UIView {
121 | let view = UIView()
122 | view.layer.contentsGravity = gravity
123 | return view
124 | }
125 |
126 | func updateUIView(_ uiView: UIView, context: Context) {
127 | uiView.layer.contents = image
128 | }
129 | }
130 | #else
131 | private struct _ImageView: NSViewRepresentable {
132 |
133 | let image: Any
134 | var gravity = CALayerContentsGravity.resizeAspectFill
135 |
136 | func makeNSView(context: Context) -> NSView {
137 | let view = NSView()
138 | view.wantsLayer = true
139 | view.layer?.contentsGravity = gravity
140 | return view
141 | }
142 |
143 | func updateNSView(_ nsView: NSView, context: Context) {
144 | nsView.layer?.contents = image
145 | }
146 | }
147 |
148 | #endif
149 |
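// Illustrative usage (editorial, not part of the original file; `camera.frames`
// is an assumed AsyncStream<CVImageBuffer> exposed elsewhere in the app):
//
//     VideoFrameView(frames: camera.frames, cameraType: .single) { buffer in
//         // hand the captured frame to the model, save it, etc.
//     }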
--------------------------------------------------------------------------------
/app/get_pretrained_mlx_model.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # For licensing see accompanying LICENSE_MODEL file.
4 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
5 | #
6 | set -e
7 |
8 | # Help function
9 | show_help() {
10 | local is_error=${1:-true} # Default to error mode if no argument provided
11 |
12 | echo "Usage: $0 --model <model_size> --dest <dest_dir>"
13 | echo
14 | echo "Required arguments:"
15 | echo " --model Size of the model to download"
16 | echo " --dest Directory where the model will be downloaded"
17 | echo
18 | echo "Available model sizes:"
19 | echo " 0.5b - 0.5B parameter model (FP16)"
20 | echo " 1.5b - 1.5B parameter model (INT8)"
21 | echo " 7b - 7B parameter model (INT4)"
22 | echo
23 | echo "Options:"
24 | echo " --help Show help message"
25 |
26 | # Exit with success (0) for help flag, error (1) for usage errors
27 | if [ "$is_error" = "false" ]; then
28 | exit 0
29 | else
30 | exit 1
31 | fi
32 | }
33 |
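# Example invocation (destination path is illustrative):
#   ./get_pretrained_mlx_model.sh --model 0.5b --dest FastVLM/model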
34 | # Parse command line arguments
35 | while [[ "$#" -gt 0 ]]; do
36 | case $1 in
37 | --model) model_size="$2"; shift ;;
38 | --dest) dest_dir="$2"; shift ;;
39 | --help) show_help false ;; # Explicit help request
40 | *) echo -e "Unknown parameter: $1\n"; show_help true ;; # Error case
41 | esac
42 | shift
43 | done
44 |
45 | # Validate required parameters
46 | if [ -z "$model_size" ]; then
47 | echo -e "Error: --model parameter is required\n"
48 | show_help true
49 | fi
50 |
51 | if [ -z "$dest_dir" ]; then
52 | echo -e "Error: --dest parameter is required\n"
53 | show_help true
54 | fi
55 |
56 | # Map model size to full model name
57 | case "$model_size" in
58 | "0.5b") model="llava-fastvithd_0.5b_stage3_llm.fp16" ;;
59 | "1.5b") model="llava-fastvithd_1.5b_stage3_llm.int8" ;;
60 | "7b") model="llava-fastvithd_7b_stage3_llm.int4" ;;
61 | *)
62 | echo -e "Error: Invalid model size '$model_size'\n"
63 | show_help true
64 | ;;
65 | esac
66 |
67 | cleanup() {
68 | rm -rf "$tmp_dir"
69 | }
70 |
71 | download_model() {
72 | # Download directory
73 | tmp_dir=$(mktemp -d)
74 |
75 | # Model paths
76 | base_url="https://ml-site.cdn-apple.com/datasets/fastvlm"
77 |
78 | # Create destination directory if it doesn't exist
79 | if [ ! -d "$dest_dir" ]; then
80 | echo "Creating destination directory: $dest_dir"
81 | mkdir -p "$dest_dir"
82 | elif [ "$(ls -A "$dest_dir")" ]; then
83 | echo -e "Destination directory '$dest_dir' exists and is not empty.\n"
84 | read -p "Do you want to clear it and continue? [y/N]: " confirm
85 | if [[ ! "$confirm" =~ ^[Yy]$ ]]; then
86 | echo -e "\nStopping."
87 | exit 1
88 | fi
89 | echo -e "\nClearing existing contents in '$dest_dir'"
90 | rm -rf "${dest_dir:?}"/*
91 | fi
92 |
93 | # Create temp variables
94 | tmp_zip_file="${tmp_dir}/${model}.zip"
95 | tmp_extract_dir="${tmp_dir}/${model}"
96 |
97 | # Create temp extract directory
98 | mkdir -p "$tmp_extract_dir"
99 |
100 | # Download model
101 | echo -e "\nDownloading '${model}' model ...\n"
102 | wget -q --progress=bar:noscroll --show-progress -O "$tmp_zip_file" "$base_url/$model.zip"
103 |
104 | # Unzip model
105 | echo -e "\nUnzipping model..."
106 | unzip -q "$tmp_zip_file" -d "$tmp_extract_dir"
107 |
108 | # Copy model files to destination directory
109 | echo -e "\nCopying model files to destination directory..."
110 | cp -r "$tmp_extract_dir/$model"/* "$dest_dir"
111 |
112 | # Verify destination directory exists and is not empty
113 | if [ ! -d "$dest_dir" ] || [ -z "$(ls -A "$dest_dir")" ]; then
114 | echo -e "\nModel extraction failed. Destination directory '$dest_dir' is missing or empty."
115 | exit 1
116 | fi
117 |
118 | echo -e "\nModel downloaded and extracted to '$dest_dir'"
119 | }
120 |
121 | # Cleanup download directory on exit
122 | trap cleanup EXIT INT TERM
123 |
124 | # Download models
125 | download_model
--------------------------------------------------------------------------------
/docs/acc_vs_latency_qwen-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/acc_vs_latency_qwen-2.png
--------------------------------------------------------------------------------
/docs/fastvlm-counting.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-counting.gif
--------------------------------------------------------------------------------
/docs/fastvlm-emoji.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-emoji.gif
--------------------------------------------------------------------------------
/docs/fastvlm-flexible_prompts.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-flexible_prompts.png
--------------------------------------------------------------------------------
/docs/fastvlm-handwriting.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/docs/fastvlm-handwriting.gif
--------------------------------------------------------------------------------
/get_models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # For licensing see accompanying LICENSE_MODEL file.
4 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
5 | #
6 |
7 | mkdir -p checkpoints
8 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage2.zip -P checkpoints
9 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_0.5b_stage3.zip -P checkpoints
10 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage2.zip -P checkpoints
11 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_1.5b_stage3.zip -P checkpoints
12 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage2.zip -P checkpoints
13 | wget https://ml-site.cdn-apple.com/datasets/fastvlm/llava-fastvithd_7b_stage3.zip -P checkpoints
14 |
15 | # Extract models
16 | cd checkpoints
17 | unzip -qq llava-fastvithd_0.5b_stage2.zip
18 | unzip -qq llava-fastvithd_0.5b_stage3.zip
19 | unzip -qq llava-fastvithd_1.5b_stage2.zip
20 | unzip -qq llava-fastvithd_1.5b_stage3.zip
21 | unzip -qq llava-fastvithd_7b_stage2.zip
22 | unzip -qq llava-fastvithd_7b_stage3.zip
23 |
24 | # Clean up
25 | rm llava-fastvithd_0.5b_stage2.zip
26 | rm llava-fastvithd_0.5b_stage3.zip
27 | rm llava-fastvithd_1.5b_stage2.zip
28 | rm llava-fastvithd_1.5b_stage3.zip
29 | rm llava-fastvithd_7b_stage2.zip
30 | rm llava-fastvithd_7b_stage3.zip
31 | cd -
32 |
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM, LlavaQwen2ForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | from PIL import Image
3 | PIL.Image.MAX_IMAGE_PIXELS=500000000
4 | from io import BytesIO
5 | import base64
6 | import torch
7 | import math
8 | import ast
9 |
10 | from transformers import StoppingCriteria
11 | from llava.constants import IMAGE_TOKEN_INDEX
12 |
13 |
14 | def select_best_resolution(original_size, possible_resolutions):
15 | """
16 | Selects the best resolution from a list of possible resolutions based on the original size.
17 |
18 | Args:
19 | original_size (tuple): The original size of the image in the format (width, height).
20 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
21 |
22 | Returns:
23 | tuple: The best fit resolution in the format (width, height).
24 | """
25 | original_width, original_height = original_size
26 | best_fit = None
27 | max_effective_resolution = 0
28 | min_wasted_resolution = float('inf')
29 |
30 | for width, height in possible_resolutions:
31 | scale = min(width / original_width, height / original_height)
32 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
33 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
34 | wasted_resolution = (width * height) - effective_resolution
35 |
36 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
37 | max_effective_resolution = effective_resolution
38 | min_wasted_resolution = wasted_resolution
39 | best_fit = (width, height)
40 |
41 | return best_fit
42 |
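# Illustrative example (editorial, not in the original file):
#   >>> select_best_resolution((1000, 800), [(672, 672), (336, 1344)])
#   (672, 672)
# The 672x672 grid wins: it preserves far more effective pixels (672x537 after
# downscaling) while wasting less of the target area than 336x1344 would.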
43 |
44 | def resize_and_pad_image(image, target_resolution):
45 | """
46 | Resize and pad an image to a target resolution while maintaining aspect ratio.
47 |
48 | Args:
49 | image (PIL.Image.Image): The input image.
50 | target_resolution (tuple): The target resolution (width, height) of the image.
51 |
52 | Returns:
53 | PIL.Image.Image: The resized and padded image.
54 | """
55 | original_width, original_height = image.size
56 | target_width, target_height = target_resolution
57 |
58 | scale_w = target_width / original_width
59 | scale_h = target_height / original_height
60 |
61 | if scale_w < scale_h:
62 | new_width = target_width
63 | new_height = min(math.ceil(original_height * scale_w), target_height)
64 | else:
65 | new_height = target_height
66 | new_width = min(math.ceil(original_width * scale_h), target_width)
67 |
68 | # Resize the image
69 | resized_image = image.resize((new_width, new_height))
70 |
71 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
72 | paste_x = (target_width - new_width) // 2
73 | paste_y = (target_height - new_height) // 2
74 | new_image.paste(resized_image, (paste_x, paste_y))
75 |
76 | return new_image
77 |
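# Illustrative example (editorial, not in the original file): a 1000x800 image
# with target_resolution=(672, 672) is resized to 672x538 (aspect ratio kept)
# and pasted centered onto a black 672x672 canvas, padding the top and bottom.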
78 |
79 | def divide_to_patches(image, patch_size):
80 | """
81 | Divides an image into patches of a specified size.
82 |
83 | Args:
84 | image (PIL.Image.Image): The input image.
85 | patch_size (int): The size of each patch.
86 |
87 | Returns:
88 | list: A list of PIL.Image.Image objects representing the patches.
89 | """
90 | patches = []
91 | width, height = image.size
92 | for i in range(0, height, patch_size):
93 | for j in range(0, width, patch_size):
94 | box = (j, i, j + patch_size, i + patch_size)
95 | patch = image.crop(box)
96 | patches.append(patch)
97 |
98 | return patches
99 |
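# Illustrative example (editorial, not in the original file):
#   >>> len(divide_to_patches(Image.new('RGB', (672, 672)), 336))
#   4
# Patches are produced row by row, left to right.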
100 |
101 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
102 | """
103 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
104 |
105 | Args:
106 | image_size (tuple): The size of the input image in the format (width, height).
107 | grid_pinpoints (str): A string representation of a list of possible resolutions.
108 | patch_size (int): The size of each image patch.
109 |
110 | Returns:
111 | tuple: The shape of the image patch grid in the format (width, height).
112 | """
113 | if type(grid_pinpoints) is list:
114 | possible_resolutions = grid_pinpoints
115 | else:
116 | possible_resolutions = ast.literal_eval(grid_pinpoints)
117 | width, height = select_best_resolution(image_size, possible_resolutions)
118 | return width // patch_size, height // patch_size
119 |
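# Illustrative example (editorial, not in the original file; the pinpoints are
# made up for the sake of the example):
#   >>> get_anyres_image_grid_shape((1000, 800), "[(672, 672), (336, 1344)]", 336)
#   (2, 2)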
120 |
121 | def process_anyres_image(image, processor, grid_pinpoints):
122 | """
123 | Process an image with variable resolutions.
124 |
125 | Args:
126 | image (PIL.Image.Image): The input image to be processed.
127 | processor: The image processor object.
128 | grid_pinpoints (str): A string representation of a list of possible resolutions.
129 |
130 | Returns:
131 | torch.Tensor: A tensor containing the processed image patches.
132 | """
133 | if type(grid_pinpoints) is list:
134 | possible_resolutions = grid_pinpoints
135 | else:
136 | possible_resolutions = ast.literal_eval(grid_pinpoints)
137 | best_resolution = select_best_resolution(image.size, possible_resolutions)
138 | image_padded = resize_and_pad_image(image, best_resolution)
139 |
140 | patches = divide_to_patches(image_padded, processor.crop_size['height'])
141 |
142 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
143 |
144 | image_patches = [image_original_resize] + patches
145 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
146 | for image_patch in image_patches]
147 | return torch.stack(image_patches, dim=0)
148 |
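# Note (editorial, not in the original file): the returned tensor stacks the
# globally resized image first and then the crop-size patches, so its first
# dimension is 1 + (grid_width * grid_height) for the selected resolution.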
149 |
150 | def load_image_from_base64(image):
151 | return Image.open(BytesIO(base64.b64decode(image)))
152 |
153 |
154 | def expand2square(pil_img, background_color):
155 | width, height = pil_img.size
156 | if width == height:
157 | return pil_img
158 | elif width > height:
159 | result = Image.new(pil_img.mode, (width, width), background_color)
160 | result.paste(pil_img, (0, (width - height) // 2))
161 | return result
162 | else:
163 | result = Image.new(pil_img.mode, (height, height), background_color)
164 | result.paste(pil_img, ((height - width) // 2, 0))
165 | return result
166 |
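# Note (editorial, not in the original file): expand2square pads the shorter
# side with `background_color` so the result is square with the original
# content centered; process_images uses it for the 'pad' aspect-ratio mode.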
167 |
168 | def process_images(images, image_processor, model_cfg):
169 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
170 | new_images = []
171 | if image_aspect_ratio == 'pad':
172 | for image in images:
173 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
174 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
175 | new_images.append(image)
176 | elif image_aspect_ratio == "anyres":
177 | for image in images:
178 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
179 | new_images.append(image)
180 | else:
181 | return image_processor(images, return_tensors='pt')['pixel_values']
182 | if all(x.shape == new_images[0].shape for x in new_images):
183 | new_images = torch.stack(new_images, dim=0)
184 | return new_images
185 |
186 |
187 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
188 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
189 |
190 | def insert_separator(X, sep):
191 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
192 |
193 | input_ids = []
194 | offset = 0
195 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
196 | offset = 1
197 | input_ids.append(prompt_chunks[0][0])
198 |
199 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
200 | input_ids.extend(x[offset:])
201 |
202 | if return_tensors is not None:
203 | if return_tensors == 'pt':
204 | return torch.tensor(input_ids, dtype=torch.long)
205 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
206 | return input_ids
207 |
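# Illustrative example (editorial, not in the original file): for a prompt such
# as "USER: <image>\nDescribe the image. ASSISTANT:", the surrounding text is
# tokenized normally and IMAGE_TOKEN_INDEX (-200) is spliced in where the
# <image> placeholder appeared, keeping a single BOS token at the start.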
208 |
209 | def get_model_name_from_path(model_path):
210 | model_path = model_path.strip("/")
211 | model_paths = model_path.split("/")
212 | if model_paths[-1].startswith('checkpoint-'):
213 | return model_paths[-2] + "_" + model_paths[-1]
214 | else:
215 | return model_paths[-1]
216 |
217 |
218 | class KeywordsStoppingCriteria(StoppingCriteria):
219 | def __init__(self, keywords, tokenizer, input_ids):
220 | self.keywords = keywords
221 | self.keyword_ids = []
222 | self.max_keyword_len = 0
223 | for keyword in keywords:
224 | cur_keyword_ids = tokenizer(keyword).input_ids
225 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
226 | cur_keyword_ids = cur_keyword_ids[1:]
227 | if len(cur_keyword_ids) > self.max_keyword_len:
228 | self.max_keyword_len = len(cur_keyword_ids)
229 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
230 | self.tokenizer = tokenizer
231 | self.start_len = input_ids.shape[1]
232 |
233 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
234 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
235 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
236 | for keyword_id in self.keyword_ids:
237 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
238 | if torch.equal(truncated_output_ids, keyword_id):
239 | return True
240 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
241 | for keyword in self.keywords:
242 | if keyword in outputs:
243 | return True
244 | return False
245 |
246 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
247 | outputs = []
248 | for i in range(output_ids.shape[0]):
249 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
250 | return all(outputs)
251 |
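# Note (editorial, not in the original file): KeywordsStoppingCriteria stops
# generation once any keyword (typically the conversation separator) appears at
# the end of the generated ids, checking both exact keyword-token matches and
# the decoded text so tokenizer merges do not hide a match.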
--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | # try:
2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
5 | from .language_model.llava_qwen import LlavaQwen2ForCausalLM, LlavaConfig
6 | # except:
7 | # pass
8 |
9 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/llava/model/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import warnings
18 | import shutil
19 |
20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21 | import torch
22 | from llava.model import *
23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24 |
25 |
26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27 | kwargs = {"device_map": device_map, **kwargs}
28 |
29 | if device != "cuda":
30 | kwargs['device_map'] = {"": device}
31 |
32 | if load_8bit:
33 | kwargs['load_in_8bit'] = True
34 | elif load_4bit:
35 | kwargs['load_in_4bit'] = True
36 | kwargs['quantization_config'] = BitsAndBytesConfig(
37 | load_in_4bit=True,
38 | bnb_4bit_compute_dtype=torch.float16,
39 | bnb_4bit_use_double_quant=True,
40 | bnb_4bit_quant_type='nf4'
41 | )
42 | else:
43 | kwargs['torch_dtype'] = torch.float16
44 |
45 | if use_flash_attn:
46 | kwargs['attn_implementation'] = 'flash_attention_2'
47 |
48 | if 'llava' in model_name.lower():
49 | # Load LLaVA model
50 | if 'lora' in model_name.lower() and model_base is None:
51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.')
52 | if 'lora' in model_name.lower() and model_base is not None:
53 | from llava.model.language_model.llava_llama import LlavaConfig
54 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path)
55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56 | print('Loading LLaVA from base model...')
57 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
58 | token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
59 | if model.lm_head.weight.shape[0] != token_num:
60 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
61 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
62 |
63 | print('Loading additional LLaVA weights...')
64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
66 | else:
67 | # this is probably from HF Hub
68 | from huggingface_hub import hf_hub_download
69 |
70 | def load_from_hf(repo_id, filename, subfolder=None):
71 | cache_file = hf_hub_download(
72 | repo_id=repo_id,
73 | filename=filename,
74 | subfolder=subfolder)
75 | return torch.load(cache_file, map_location='cpu')
76 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
77 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
78 | if any(k.startswith('model.model.') for k in non_lora_trainables):
79 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
80 | model.load_state_dict(non_lora_trainables, strict=False)
81 |
82 | from peft import PeftModel
83 | print('Loading LoRA weights...')
84 | model = PeftModel.from_pretrained(model, model_path)
85 | print('Merging LoRA weights...')
86 | model = model.merge_and_unload()
87 | print('Model is loaded...')
88 | elif model_base is not None:
89 | # this may be mm projector only
90 | print('Loading LLaVA from base model...')
91 | if 'mpt' in model_name.lower():
92 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')):
93 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py'))
94 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True)
95 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
96 | model = LlavaMptForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
97 | else:
98 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
99 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
100 | # model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
101 | model = LlavaQwen2ForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
102 |
103 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
104 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
105 | model.load_state_dict(mm_projector_weights, strict=False)
106 | else:
107 | if 'mpt' in model_name.lower():
108 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
109 | model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
110 | elif 'mistral' in model_name.lower():
111 | tokenizer = AutoTokenizer.from_pretrained(model_path)
112 | model = LlavaMistralForCausalLM.from_pretrained(
113 | model_path,
114 | low_cpu_mem_usage=True,
115 | **kwargs
116 | )
117 | elif 'dclm' in model_name.lower():
118 | tokenizer = AutoTokenizer.from_pretrained(model_path)
119 | model = LlavaOpenlmForCausalLM.from_pretrained(
120 | model_path,
121 | low_cpu_mem_usage=True,
122 | **kwargs
123 | )
124 | else:
125 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
126 | # model = LlavaLlamaForCausalLM.from_pretrained(
127 | # model_path,
128 | # low_cpu_mem_usage=True,
129 | # **kwargs
130 | # )
131 | model = LlavaQwen2ForCausalLM.from_pretrained(
132 | model_path,
133 | low_cpu_mem_usage=True,
134 | **kwargs
135 | )
136 | else:
137 | # Load language model
138 | if model_base is not None:
139 | # PEFT model
140 | from peft import PeftModel
141 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
142 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
143 | print(f"Loading LoRA weights from {model_path}")
144 | model = PeftModel.from_pretrained(model, model_path)
145 | print("Merging weights")
146 | model = model.merge_and_unload()
147 | print('Convert to FP16...')
148 | model.to(torch.float16)
149 | else:
150 | use_fast = False
151 | if 'mpt' in model_name.lower():
152 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
153 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
154 | else:
155 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
156 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
157 |
158 | image_processor = None
159 |
160 | if 'llava' in model_name.lower():
161 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
162 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
163 | if mm_use_im_patch_token:
164 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
165 | if mm_use_im_start_end:
166 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
167 | model.resize_token_embeddings(len(tokenizer))
168 |
169 | vision_tower = model.get_vision_tower()
170 | if not vision_tower.is_loaded:
171 | vision_tower.load_model(device_map=device_map)
172 | if device_map != 'auto':
173 | vision_tower.to(device=device_map, dtype=torch.float16)
174 | image_processor = vision_tower.image_processor
175 |
176 | if hasattr(model.config, "max_sequence_length"):
177 | context_len = model.config.max_sequence_length
178 | else:
179 | context_len = 2048
180 |
181 | return tokenizer, model, image_processor, context_len
182 |
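# Illustrative usage (editorial, not in the original file; the checkpoint path
# is an assumption):
#   from llava.mm_utils import get_model_name_from_path
#   model_path = "checkpoints/llava-fastvithd_0.5b_stage3"
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       model_path, model_base=None, model_name=get_model_name_from_path(model_path))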
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import AutoConfig, AutoModelForCausalLM, \
22 | LlamaConfig, LlamaModel, LlamaForCausalLM
23 |
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 | from transformers.generation.utils import GenerateOutput
26 |
27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaConfig(LlamaConfig):
31 | model_type = "llava_llama"
32 |
33 |
34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35 | config_class = LlavaConfig
36 |
37 | def __init__(self, config: LlamaConfig):
38 | super(LlavaLlamaModel, self).__init__(config)
39 |
40 |
41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42 | config_class = LlavaConfig
43 |
44 | def __init__(self, config):
45 | super(LlamaForCausalLM, self).__init__(config)
46 | self.model = LlavaLlamaModel(config)
47 | self.pretraining_tp = config.pretraining_tp
48 | self.vocab_size = config.vocab_size
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | cache_position=None,
72 | ) -> Union[Tuple, CausalLMOutputWithPast]:
73 |
74 | if inputs_embeds is None:
75 | (
76 | input_ids,
77 | position_ids,
78 | attention_mask,
79 | past_key_values,
80 | inputs_embeds,
81 | labels
82 | ) = self.prepare_inputs_labels_for_multimodal(
83 | input_ids,
84 | position_ids,
85 | attention_mask,
86 | past_key_values,
87 | labels,
88 | images,
89 | image_sizes
90 | )
91 |
92 | return super().forward(
93 | input_ids=input_ids,
94 | attention_mask=attention_mask,
95 | position_ids=position_ids,
96 | past_key_values=past_key_values,
97 | inputs_embeds=inputs_embeds,
98 | labels=labels,
99 | use_cache=use_cache,
100 | output_attentions=output_attentions,
101 | output_hidden_states=output_hidden_states,
102 | return_dict=return_dict
103 | )
104 |
105 | @torch.no_grad()
106 | def generate(
107 | self,
108 | inputs: Optional[torch.Tensor] = None,
109 | images: Optional[torch.Tensor] = None,
110 | image_sizes: Optional[torch.Tensor] = None,
111 | **kwargs,
112 | ) -> Union[GenerateOutput, torch.LongTensor]:
113 | position_ids = kwargs.pop("position_ids", None)
114 | attention_mask = kwargs.pop("attention_mask", None)
115 | if "inputs_embeds" in kwargs:
116 | raise NotImplementedError("`inputs_embeds` is not supported")
117 |
118 | if images is not None:
119 | (
120 | inputs,
121 | position_ids,
122 | attention_mask,
123 | _,
124 | inputs_embeds,
125 | _
126 | ) = self.prepare_inputs_labels_for_multimodal(
127 | inputs,
128 | position_ids,
129 | attention_mask,
130 | None,
131 | None,
132 | images,
133 | image_sizes=image_sizes
134 | )
135 | else:
136 | inputs_embeds = self.get_model().embed_tokens(inputs)
137 |
138 | return super().generate(
139 | position_ids=position_ids,
140 | attention_mask=attention_mask,
141 | inputs_embeds=inputs_embeds,
142 | **kwargs
143 | )
144 |
145 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
146 | inputs_embeds=None, **kwargs):
147 | images = kwargs.pop("images", None)
148 | image_sizes = kwargs.pop("image_sizes", None)
149 | inputs = super().prepare_inputs_for_generation(
150 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
151 | )
152 | if images is not None:
153 | inputs['images'] = images
154 | if image_sizes is not None:
155 | inputs['image_sizes'] = image_sizes
156 | return inputs
157 |
158 | AutoConfig.register("llava_llama", LlavaConfig)
159 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
160 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mistral.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.nn import CrossEntropyLoss
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | MistralConfig, MistralModel, MistralForCausalLM
24 |
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 | from transformers.generation.utils import GenerateOutput
27 |
28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29 |
30 |
31 | class LlavaMistralConfig(MistralConfig):
32 | model_type = "llava_mistral"
33 |
34 |
35 | class LlavaMistralModel(LlavaMetaModel, MistralModel):
36 | config_class = LlavaMistralConfig
37 |
38 | def __init__(self, config: MistralConfig):
39 | super(LlavaMistralModel, self).__init__(config)
40 |
41 |
42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43 | config_class = LlavaMistralConfig
44 |
45 | def __init__(self, config):
46 | super(MistralForCausalLM, self).__init__(config)
47 | self.model = LlavaMistralModel(config)
48 |
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | ) -> Union[Tuple, CausalLMOutputWithPast]:
72 |
73 | if inputs_embeds is None:
74 | (
75 | input_ids,
76 | position_ids,
77 | attention_mask,
78 | past_key_values,
79 | inputs_embeds,
80 | labels
81 | ) = self.prepare_inputs_labels_for_multimodal(
82 | input_ids,
83 | position_ids,
84 | attention_mask,
85 | past_key_values,
86 | labels,
87 | images,
88 | image_sizes
89 | )
90 |
91 | return super().forward(
92 | input_ids=input_ids,
93 | attention_mask=attention_mask,
94 | position_ids=position_ids,
95 | past_key_values=past_key_values,
96 | inputs_embeds=inputs_embeds,
97 | labels=labels,
98 | use_cache=use_cache,
99 | output_attentions=output_attentions,
100 | output_hidden_states=output_hidden_states,
101 | return_dict=return_dict
102 | )
103 |
104 | @torch.no_grad()
105 | def generate(
106 | self,
107 | inputs: Optional[torch.Tensor] = None,
108 | images: Optional[torch.Tensor] = None,
109 | image_sizes: Optional[torch.Tensor] = None,
110 | **kwargs,
111 | ) -> Union[GenerateOutput, torch.LongTensor]:
112 | position_ids = kwargs.pop("position_ids", None)
113 | attention_mask = kwargs.pop("attention_mask", None)
114 | if "inputs_embeds" in kwargs:
115 | raise NotImplementedError("`inputs_embeds` is not supported")
116 |
117 | if images is not None:
118 | (
119 | inputs,
120 | position_ids,
121 | attention_mask,
122 | _,
123 | inputs_embeds,
124 | _
125 | ) = self.prepare_inputs_labels_for_multimodal(
126 | inputs,
127 | position_ids,
128 | attention_mask,
129 | None,
130 | None,
131 | images,
132 | image_sizes=image_sizes
133 | )
134 | else:
135 | inputs_embeds = self.get_model().embed_tokens(inputs)
136 |
137 | return super().generate(
138 | position_ids=position_ids,
139 | attention_mask=attention_mask,
140 | inputs_embeds=inputs_embeds,
141 | **kwargs
142 | )
143 |
144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145 | inputs_embeds=None, **kwargs):
146 | images = kwargs.pop("images", None)
147 | image_sizes = kwargs.pop("image_sizes", None)
148 | inputs = super().prepare_inputs_for_generation(
149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150 | )
151 | if images is not None:
152 | inputs['images'] = images
153 | if image_sizes is not None:
154 | inputs['image_sizes'] = image_sizes
155 | return inputs
156 |
157 | AutoConfig.register("llava_mistral", LlavaMistralConfig)
158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
159 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 |
20 | from transformers import AutoConfig, AutoModelForCausalLM, \
21 | MptConfig, MptForCausalLM, MptModel
22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23 |
24 |
25 | class LlavaMptConfig(MptConfig):
26 | model_type = "llava_mpt"
27 |
28 |
29 | class LlavaMptModel(LlavaMetaModel, MptModel):
30 | config_class = LlavaMptConfig
31 |
32 | def __init__(self, config: MptConfig):
33 | config.hidden_size = config.d_model
34 | super(LlavaMptModel, self).__init__(config)
35 |
36 | def embed_tokens(self, x):
37 | return self.wte(x)
38 |
39 |
40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41 | config_class = LlavaMptConfig
42 | supports_gradient_checkpointing = True
43 |
44 | def __init__(self, config):
45 | super(MptForCausalLM, self).__init__(config)
46 |
47 | self.transformer = LlavaMptModel(config)
48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.transformer
55 |
56 | def _set_gradient_checkpointing(self, module, value=False):
57 | if isinstance(module, LlavaMptModel):
58 | module.gradient_checkpointing = value
59 |
60 | def forward(
61 | self,
62 | input_ids: Optional[torch.LongTensor] = None,
63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64 | attention_mask: Optional[torch.Tensor] = None,
65 | inputs_embeds: Optional[torch.Tensor] = None,
66 | labels: Optional[torch.Tensor] = None,
67 | use_cache: Optional[bool] = None,
68 | output_attentions: Optional[bool] = None,
69 | output_hidden_states: Optional[bool] = None,
70 | return_dict: Optional[bool] = None,
71 | images=None):
72 |
73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74 |
75 | return super().forward(
76 | input_ids,
77 | past_key_values=past_key_values,
78 | attention_mask=attention_mask,
79 | inputs_embeds=inputs_embeds,
80 | labels=labels,
81 | use_cache=use_cache,
82 | output_attentions=output_attentions,
83 | output_hidden_states=output_hidden_states,
84 | return_dict=return_dict,
85 | )
86 |
87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88 | images = kwargs.pop("images", None)
89 | _inputs = super().prepare_inputs_for_generation(
90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91 | )
92 | _inputs['images'] = images
93 | return _inputs
94 |
95 |
96 | AutoConfig.register("llava_mpt", LlavaMptConfig)
97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
98 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_qwen.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright 2023 Haotian Liu
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from typing import List, Optional, Tuple, Union
18 |
19 | import torch
20 | import torch.nn as nn
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, Qwen2Config, Qwen2Model, Qwen2ForCausalLM
23 |
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 | from transformers.generation.utils import GenerateOutput
26 |
27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaConfig(Qwen2Config):
31 | model_type = "llava_qwen2"
32 |
33 |
34 | class LlavaQwen2Model(LlavaMetaModel, Qwen2Model):
35 | config_class = LlavaConfig
36 |
37 | def __init__(self, config: Qwen2Config):
38 | super(LlavaQwen2Model, self).__init__(config)
39 |
40 |
41 | class LlavaQwen2ForCausalLM(Qwen2ForCausalLM, LlavaMetaForCausalLM):
42 | config_class = LlavaConfig
43 |
44 | def __init__(self, config):
45 | super(Qwen2ForCausalLM, self).__init__(config)
46 | self.model = LlavaQwen2Model(config)
47 | # self.pretraining_tp = config.pretraining_tp
48 | self.vocab_size = config.vocab_size
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | cache_position=None,
72 | ) -> Union[Tuple, CausalLMOutputWithPast]:
73 |
74 | if inputs_embeds is None:
75 | (
76 | input_ids,
77 | position_ids,
78 | attention_mask,
79 | past_key_values,
80 | inputs_embeds,
81 | labels
82 | ) = self.prepare_inputs_labels_for_multimodal(
83 | input_ids,
84 | position_ids,
85 | attention_mask,
86 | past_key_values,
87 | labels,
88 | images,
89 | image_sizes
90 | )
91 |
92 | return super().forward(
93 | input_ids=input_ids,
94 | attention_mask=attention_mask,
95 | position_ids=position_ids,
96 | past_key_values=past_key_values,
97 | inputs_embeds=inputs_embeds,
98 | labels=labels,
99 | use_cache=use_cache,
100 | output_attentions=output_attentions,
101 | output_hidden_states=output_hidden_states,
102 | return_dict=return_dict
103 | )
104 |
105 | @torch.no_grad()
106 | def generate(
107 | self,
108 | inputs: Optional[torch.Tensor] = None,
109 | images: Optional[torch.Tensor] = None,
110 | image_sizes: Optional[torch.Tensor] = None,
111 | **kwargs,
112 | ) -> Union[GenerateOutput, torch.LongTensor]:
113 | position_ids = kwargs.pop("position_ids", None)
114 | attention_mask = kwargs.pop("attention_mask", None)
115 | if "inputs_embeds" in kwargs:
116 | raise NotImplementedError("`inputs_embeds` is not supported")
117 |
118 | if images is not None:
119 | (
120 | inputs,
121 | position_ids,
122 | attention_mask,
123 | _,
124 | inputs_embeds,
125 | _
126 | ) = self.prepare_inputs_labels_for_multimodal(
127 | inputs,
128 | position_ids,
129 | attention_mask,
130 | None,
131 | None,
132 | images,
133 | image_sizes=image_sizes
134 | )
135 | else:
136 | inputs_embeds = self.get_model().embed_tokens(inputs)
137 |
138 | return super().generate(
139 | position_ids=position_ids,
140 | attention_mask=attention_mask,
141 | inputs_embeds=inputs_embeds,
142 | **kwargs
143 | )
144 |
145 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
146 | inputs_embeds=None, **kwargs):
147 | images = kwargs.pop("images", None)
148 | image_sizes = kwargs.pop("image_sizes", None)
149 | inputs = super().prepare_inputs_for_generation(
150 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
151 | )
152 | if images is not None:
153 | inputs['images'] = images
154 | if image_sizes is not None:
155 | inputs['image_sizes'] = image_sizes
156 | return inputs
157 |
158 |
159 | AutoConfig.register("llava_qwen2", LlavaConfig)
160 | AutoModelForCausalLM.register(LlavaConfig, LlavaQwen2ForCausalLM)
161 |
--------------------------------------------------------------------------------
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
3 | from .mobileclip_encoder import MobileCLIPVisionTower
4 |
5 |
6 | def build_vision_tower(vision_tower_cfg, **kwargs):
7 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
8 | is_absolute_path_exists = os.path.exists(vision_tower)
9 | use_s2 = getattr(vision_tower_cfg, 's2', False)
10 |
11 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
12 | if use_s2:
13 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
14 | else:
15 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
16 | elif "mobileclip" in vision_tower.lower():
17 | return MobileCLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
18 |
19 | raise ValueError(f'Unknown vision tower: {vision_tower}')
20 |
--------------------------------------------------------------------------------
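`build_vision_tower` dispatches only on the tower name string plus a handful of config attributes. A sketch with stand-in configs (the attribute names match what the builder and tower classes read; the tower names are illustrative, and `delay_load=True` avoids downloading weights):

```python
from types import SimpleNamespace

from llava.model.multimodal_encoder.builder import build_vision_tower

# CLIP-style name -> CLIPVisionTower (or CLIPVisionTowerS2 when s2=True).
clip_cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",
    mm_vision_select_layer=-2,
    s2=False,
)
clip_tower = build_vision_tower(clip_cfg, delay_load=True)

# "mobileclip_*" name -> MobileCLIPVisionTower; the trailing number is parsed as the input resolution.
mci_cfg = SimpleNamespace(mm_vision_tower="mobileclip_l_1024", s2=False)
mci_tower = build_vision_tower(mci_cfg, delay_load=True)
```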
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16 | self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
17 | self.input_image_size = getattr(args, 'input_image_size', None)
18 |
19 | if self.tune_vision_tower:
20 | print("CLIP Vision tower is set to tunable")
21 |
22 | if not delay_load:
23 | self.load_model()
24 | elif getattr(args, 'unfreeze_mm_vision_tower', False):
25 | self.load_model()
26 | else:
27 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
28 | if self.input_image_size is not None:
29 | self.cfg_only.image_size = self.input_image_size
30 |
31 | def load_model(self, device_map=None):
32 | if self.is_loaded:
33 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
34 | return
35 |
36 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
37 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
38 | if not self.tune_vision_tower:
39 | self.vision_tower.requires_grad_(False)
40 |
41 | if self.input_image_size is not None:
42 | print("Using input image size: {}".format(self.input_image_size))
43 | self.image_processor.size['shortest_edge'] = self.input_image_size
44 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.input_image_size
45 |
46 | self.is_loaded = True
47 |
48 | def feature_select(self, image_forward_outs):
49 | image_features = image_forward_outs.hidden_states[self.select_layer]
50 | if self.select_feature == 'patch':
51 | image_features = image_features[:, 1:]
52 | elif self.select_feature == 'cls_patch':
53 | image_features = image_features
54 | else:
55 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
56 | return image_features
57 |
58 | def forward(self, images):
59 | if self.tune_vision_tower:
60 | return self.forward_images(images)
61 | else:
62 | with torch.no_grad():
63 | return self.forward_images(images)
64 |
65 | def forward_images(self, images):
66 | if type(images) is list:
67 | image_features = []
68 | for image in images:
69 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
70 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
71 | image_features.append(image_feature)
72 | else:
73 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
74 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
75 |
76 | return image_features
77 |
78 | @property
79 | def dummy_feature(self):
80 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
81 |
82 | @property
83 | def dtype(self):
84 | return self.vision_tower.dtype
85 |
86 | @property
87 | def device(self):
88 | return self.vision_tower.device
89 |
90 | @property
91 | def config(self):
92 | if self.is_loaded:
93 | return self.vision_tower.config
94 | else:
95 | return self.cfg_only
96 |
97 | @property
98 | def hidden_size(self):
99 | return self.config.hidden_size
100 |
101 | @property
102 | def num_patches_per_side(self):
103 | return self.config.image_size // self.config.patch_size
104 |
105 | @property
106 | def num_patches(self):
107 | return (self.config.image_size // self.config.patch_size) ** 2
108 |
109 |
110 |
111 | class CLIPVisionTowerS2(CLIPVisionTower):
112 | def __init__(self, vision_tower, args, delay_load=False):
113 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
114 | self.s2_scales = list(map(int, self.s2_scales.split(',')))
115 | self.s2_scales.sort()
116 | self.s2_split_size = self.s2_scales[0]
117 | self.s2_image_size = self.s2_scales[-1]
118 |
119 | super().__init__(vision_tower, args, delay_load)
120 |
121 | try:
122 | from s2wrapper import forward as multiscale_forward
123 | except ImportError:
124 | raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
125 | self.multiscale_forward = multiscale_forward
126 |
127 | # change resize/crop size in preprocessing to the largest image size in s2_scale
128 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
129 | self.image_processor.size['shortest_edge'] = self.s2_image_size
130 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
131 |
132 | def load_model(self, device_map=None):
133 | if self.is_loaded:
134 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
135 | return
136 |
137 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
138 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
139 | self.vision_tower.requires_grad_(False)
140 |
141 | self.image_processor.size['shortest_edge'] = self.s2_image_size
142 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
143 |
144 | self.is_loaded = True
145 |
146 | @torch.no_grad()
147 | def forward_feature(self, images):
148 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
149 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
150 | return image_features
151 |
152 | @torch.no_grad()
153 | def forward(self, images):
154 | if type(images) is list:
155 | image_features = []
156 | for image in images:
157 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
158 | image_features.append(image_feature)
159 | else:
160 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
161 |
162 | return image_features
163 |
164 | @property
165 | def hidden_size(self):
166 | return self.config.hidden_size * len(self.s2_scales)
167 |
--------------------------------------------------------------------------------
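The patch-count bookkeeping above is simple: with `select_feature='patch'` the CLS token is dropped, so `forward` returns `(image_size // patch_size) ** 2` tokens of width `hidden_size`. A config-only sketch (the tower name is an example; `delay_load=True` only fetches the `CLIPVisionConfig`):

```python
from types import SimpleNamespace

from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args, delay_load=True)

print(tower.num_patches_per_side)  # 336 // 14 = 24
print(tower.num_patches)           # 24 ** 2  = 576 patch tokens per image
print(tower.hidden_size)           # 1024 for the ViT-L/14 vision encoder
```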
/llava/model/multimodal_encoder/mobileclip/__init__.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | #
5 | import os
6 | import json
7 | from typing import Any
8 |
9 | import torch.nn as nn
10 | from timm.models import create_model
11 |
12 | from .mci import GlobalPool2D
13 |
14 |
15 | def load_model_config(
16 | model_name: str,
17 | ) -> Any:
18 | # Strip suffixes to model name
19 | model_name = "_".join(model_name.split("_")[0:2])
20 |
21 | # Config files
22 | root_dir = os.path.dirname(os.path.abspath(__file__))
23 | configs_dir = os.path.join(root_dir, "configs")
24 | model_cfg_file = os.path.join(configs_dir, model_name + ".json")
25 |
26 | # Get config from json file
27 | if not os.path.exists(model_cfg_file):
28 | raise ValueError(f"Unsupported model name: {model_name}")
29 | model_cfg = json.load(open(model_cfg_file, "r"))
30 |
31 | return model_cfg
32 |
33 |
34 | class MCi(nn.Module):
35 | """
36 | This class implements MCi models.
37 | """
38 |
39 | def __init__(self, model_name: str, *args, **kwargs) -> None:
40 | super().__init__()
41 | self.projection_dim = None
42 | if "projection_dim" in kwargs:
43 | self.projection_dim = kwargs.get("projection_dim")
44 |
45 | # Create model
46 | self.model = create_model(model_name, projection_dim=self.projection_dim)
47 |
48 | # Build out projection head.
49 | if self.projection_dim is not None:
50 | if hasattr(self.model, "head"):
51 | self.model.head = MCi._update_image_classifier(
52 | image_classifier=self.model.head, projection_dim=self.projection_dim
53 | )
54 |
55 | def forward(self, x: Any, *args, **kwargs) -> Any:
56 | """A forward function of the model."""
57 | x = self.model(x, *args, **kwargs)
58 | return x
59 |
60 | @staticmethod
61 | def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
62 | """Return the input feature dimension to the image classification head."""
63 | in_features = None
64 | if isinstance(image_classifier, nn.Sequential):
65 | # Classifier that uses nn.Sequential usually has global pooling and
66 | # multiple linear layers. Find the first linear layer and get its
67 | # in_features
68 | for layer in image_classifier:
69 | if isinstance(layer, nn.Linear):
70 | in_features = layer.in_features
71 | break
72 | elif isinstance(image_classifier, nn.Linear):
73 | in_features = image_classifier.in_features
74 |
75 | if in_features is None:
76 | raise NotImplementedError(
77 | f"Cannot get input feature dimension of {image_classifier}."
78 | )
79 | return in_features
80 |
81 | @staticmethod
82 | def _update_image_classifier(
83 | image_classifier: nn.Module, projection_dim: int, *args, **kwargs
84 | ) -> nn.Module:
85 | in_features = MCi._get_in_feature_dimension(image_classifier)
86 | new_img_classifier = GlobalPool2D(in_dim=in_features, out_dim=projection_dim)
87 | return new_img_classifier
88 |
--------------------------------------------------------------------------------
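`load_model_config` keeps only the first two underscore-separated tokens of the tower name, so a resolution suffix such as `_1024` is ignored when locating the JSON config. A short sketch (the printed values follow from `configs/mobileclip_l.json` below):

```python
from llava.model.multimodal_encoder.mobileclip import MCi, load_model_config

cfg = load_model_config("mobileclip_l_1024")   # resolves to configs/mobileclip_l.json
print(cfg["embed_dim"])                        # 768
print(cfg["image_cfg"]["model_name"])          # "fastvithd"

# MCi wraps timm's create_model; with projection_dim set, the classification
# head is replaced by a GlobalPool2D projection to that dimension.
encoder = MCi(model_name=cfg["image_cfg"]["model_name"], projection_dim=cfg["embed_dim"])
```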
/llava/model/multimodal_encoder/mobileclip/configs/mobileclip_l.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "image_cfg": {
4 | "image_size": 1024,
5 | "model_name": "fastvithd",
6 | "embed_dim": 3072,
7 | "patch_size": 64
8 | },
9 | "text_cfg": {
10 | "context_length": 77,
11 | "vocab_size": 49408,
12 | "dim": 768,
13 | "ffn_multiplier_per_layer": 4.0,
14 | "n_heads_per_layer": 12,
15 | "n_transformer_layers": 12,
16 | "norm_layer": "layer_norm_fp32",
17 | "causal_masking": false,
18 | "model_name": "base"
19 | }
20 | }
21 |
--------------------------------------------------------------------------------
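The numbers in this config determine the token budget the LLM sees: a 1024-pixel input with 64-pixel patches yields a 16x16 grid, i.e. 256 visual tokens, and the encoder wrapper below reports `image_cfg.embed_dim` (3072) as the per-token width; the top-level `embed_dim` of 768 is the CLIP projection dimension passed to `MCi`. A quick arithmetic check:

```python
# Values taken from the JSON above.
image_size, patch_size, token_dim = 1024, 64, 3072

patches_per_side = image_size // patch_size   # 16
num_tokens = patches_per_side ** 2            # 256 visual tokens per image
print(patches_per_side, num_tokens, token_dim)
```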
/llava/model/multimodal_encoder/mobileclip_encoder.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | #
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from transformers import CLIPImageProcessor
10 | import llava.model.multimodal_encoder.mobileclip as mobileclip
11 |
12 |
13 | class MobileCLIPVisionTower(nn.Module):
14 | def __init__(self, vision_tower, args, delay_load=False):
15 | super().__init__()
16 |
17 | self.is_loaded = False
18 | self.vision_tower_name = vision_tower
19 | self.tune_vision_tower = getattr(args, 'unfreeze_mm_vision_tower', False)
20 | self.input_image_size = int(vision_tower.split("_")[-1])
21 |
22 | # Delay load is disabled for now
23 | if not delay_load:
24 | self.load_model()
25 | elif getattr(args, 'unfreeze_mm_vision_tower', False):
26 | self.load_model()
27 | else:
28 | model_cfg = mobileclip.load_model_config(self.vision_tower_name)
29 | self.cfg_only = model_cfg
30 |
31 | def load_model(self, device_map=None):
32 | if self.is_loaded:
33 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
34 | return
35 |
36 | # Load model config
37 | model_cfg = mobileclip.load_model_config(self.vision_tower_name)
38 |
39 | # Override default image resolution
40 | model_cfg["image_cfg"]["image_size"] = self.input_image_size
41 |
42 | self.cfg_only = model_cfg
43 |
44 | # Build HF CLIPImageProcessor with MobileCLIP parameters
45 | self.image_processor = CLIPImageProcessor(crop_size={"height": model_cfg["image_cfg"]["image_size"],
46 | "width": model_cfg["image_cfg"]["image_size"]},
47 | image_mean=[0.0, 0.0, 0.0],
48 | image_std=[1.0, 1.0, 1.0],
49 | size={"shortest_edge": model_cfg["image_cfg"]["image_size"]})
50 |
51 | # Instantiate the image encoder
52 | self.vision_tower = mobileclip.MCi(model_name=model_cfg["image_cfg"]["model_name"],
53 | projection_dim=model_cfg["embed_dim"])
54 |
55 | if not self.tune_vision_tower:
56 | self.vision_tower.requires_grad_(False)
57 |
58 | self.is_loaded = True
59 |
60 | def feature_select(self, image_forward_outs):
61 | # Features from penultimate layer
62 | image_features = image_forward_outs["image_embeddings"]
63 |
64 | # Reshape 4D tensor to 3D
65 | B, C, H, W = image_features.shape
66 | image_features = image_features.reshape(B, C, H*W)
67 | image_features = image_features.transpose(1, 2)
68 | return image_features
69 |
70 | def forward(self, images):
71 | if self.tune_vision_tower:
72 | return self.forward_images(images)
73 | else:
74 | with torch.no_grad():
75 | return self.forward_images(images)
76 |
77 | def forward_images(self, images):
78 | if type(images) is list:
79 | image_features = []
80 | for image in images:
81 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), return_image_embeddings=True)
82 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
83 | image_features.append(image_feature)
84 | else:
85 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), return_image_embeddings=True)
86 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
87 |
88 | return image_features
89 |
90 | @property
91 | def dummy_feature(self):
92 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
93 |
94 | @property
95 | def dtype(self):
96 | return next(self.vision_tower.parameters()).dtype
97 |
98 | @property
99 | def device(self):
100 | return next(self.vision_tower.parameters()).device
101 |
102 | @property
103 | def config(self):
104 | return self.cfg_only
105 |
106 | @property
107 | def hidden_size(self):
108 | return self.config["image_cfg"]["embed_dim"]
109 |
110 | @property
111 | def num_patches_per_side(self):
112 | return self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]
113 |
114 | @property
115 | def num_patches(self):
116 | return (self.config["image_cfg"]["image_size"] // self.config["image_cfg"]["patch_size"]) ** 2
117 |
--------------------------------------------------------------------------------
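`feature_select` here simply flattens the 4D FastViTHD feature map into the `(batch, tokens, channels)` layout that the multimodal projector expects. A shape-only sketch with a dummy tensor (sizes match `mobileclip_l` at a 1024-pixel input):

```python
import torch

B, C, H, W = 2, 3072, 16, 16
image_embeddings = torch.randn(B, C, H, W)

# (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C)
tokens = image_embeddings.reshape(B, C, H * W).transpose(1, 2)
print(tokens.shape)  # torch.Size([2, 256, 3072])
```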
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import re
3 |
4 |
5 | class IdentityMap(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 |
9 | def forward(self, x, *args, **kwargs):
10 | return x
11 |
12 | @property
13 | def config(self):
14 | return {"mm_projector_type": 'identity'}
15 |
16 |
17 | def build_vision_projector(config, delay_load=False, **kwargs):
18 | projector_type = getattr(config, 'mm_projector_type', 'linear')
19 |
20 | if projector_type == 'linear':
21 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
22 |
23 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
24 | if mlp_gelu_match:
25 | mlp_depth = int(mlp_gelu_match.group(1))
26 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
27 | for _ in range(1, mlp_depth):
28 | modules.append(nn.GELU())
29 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
30 | return nn.Sequential(*modules)
31 |
32 | if projector_type == 'identity':
33 | return IdentityMap()
34 |
35 | raise ValueError(f'Unknown projector type: {projector_type}')
36 |
--------------------------------------------------------------------------------
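The projector type string is parsed with a regex, so `"mlp2x_gelu"` builds Linear -> GELU -> Linear. A sketch with a stand-in config (the builder only reads `mm_projector_type`, `mm_hidden_size`, and `hidden_size`; the sizes below are illustrative):

```python
from types import SimpleNamespace

import torch

from llava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=3072, hidden_size=896)
projector = build_vision_projector(cfg)  # Linear(3072->896) -> GELU -> Linear(896->896)

tokens = torch.randn(2, 256, cfg.mm_hidden_size)
print(projector(tokens).shape)           # torch.Size([2, 256, 896])
```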
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using a newer LLaVA code base, while the v0 checkpoint is from an older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/serve/__init__.py
--------------------------------------------------------------------------------
/llava/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.model.builder import load_pretrained_model
7 | from llava.utils import disable_torch_init
8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9 |
10 | from PIL import Image
11 |
12 | import requests
13 |
14 | from io import BytesIO
15 | from transformers import TextStreamer
16 |
17 |
18 | def load_image(image_file):
19 | if image_file.startswith('http://') or image_file.startswith('https://'):
20 | response = requests.get(image_file)
21 | image = Image.open(BytesIO(response.content)).convert('RGB')
22 | else:
23 | image = Image.open(image_file).convert('RGB')
24 | return image
25 |
26 |
27 | def main(args):
28 | # Model
29 | disable_torch_init()
30 |
31 | model_name = get_model_name_from_path(args.model_path)
32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33 |
34 | if "llama-2" in model_name.lower():
35 | conv_mode = "llava_llama_2"
36 | elif "mistral" in model_name.lower():
37 | conv_mode = "mistral_instruct"
38 | elif "v1.6-34b" in model_name.lower():
39 | conv_mode = "chatml_direct"
40 | elif "v1" in model_name.lower():
41 | conv_mode = "llava_v1"
42 | elif "mpt" in model_name.lower():
43 | conv_mode = "mpt"
44 | else:
45 | conv_mode = "llava_v0"
46 |
47 | if args.conv_mode is not None and conv_mode != args.conv_mode:
48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
49 | else:
50 | args.conv_mode = conv_mode
51 |
52 | conv = conv_templates[args.conv_mode].copy()
53 | if "mpt" in model_name.lower():
54 | roles = ('user', 'assistant')
55 | else:
56 | roles = conv.roles
57 |
58 | image = load_image(args.image_file)
59 | image_size = image.size
60 | # Similar operation in model_worker.py
61 | image_tensor = process_images([image], image_processor, model.config)
62 | if type(image_tensor) is list:
63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
64 | else:
65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16)
66 |
67 | while True:
68 | try:
69 | inp = input(f"{roles[0]}: ")
70 | except EOFError:
71 | inp = ""
72 | if not inp:
73 | print("exit...")
74 | break
75 |
76 | print(f"{roles[1]}: ", end="")
77 |
78 | if image is not None:
79 | # first message
80 | if model.config.mm_use_im_start_end:
81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
82 | else:
83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
84 | image = None
85 |
86 | conv.append_message(conv.roles[0], inp)
87 | conv.append_message(conv.roles[1], None)
88 | prompt = conv.get_prompt()
89 |
90 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
92 | keywords = [stop_str]
93 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
94 |
95 | with torch.inference_mode():
96 | output_ids = model.generate(
97 | input_ids,
98 | images=image_tensor,
99 | image_sizes=[image_size],
100 | do_sample=True if args.temperature > 0 else False,
101 | temperature=args.temperature,
102 | max_new_tokens=args.max_new_tokens,
103 | streamer=streamer,
104 | use_cache=True)
105 |
106 | outputs = tokenizer.decode(output_ids[0]).strip()
107 | conv.messages[-1][-1] = outputs
108 |
109 | if args.debug:
110 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
111 |
112 |
113 | if __name__ == "__main__":
114 | parser = argparse.ArgumentParser()
115 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
116 | parser.add_argument("--model-base", type=str, default=None)
117 | parser.add_argument("--image-file", type=str, required=True)
118 | parser.add_argument("--device", type=str, default="cuda")
119 | parser.add_argument("--conv-mode", type=str, default=None)
120 | parser.add_argument("--temperature", type=float, default=0.2)
121 | parser.add_argument("--max-new-tokens", type=int, default=512)
122 | parser.add_argument("--load-8bit", action="store_true")
123 | parser.add_argument("--load-4bit", action="store_true")
124 | parser.add_argument("--debug", action="store_true")
125 | args = parser.parse_args()
126 | main(args)
127 |
--------------------------------------------------------------------------------
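The prompt the CLI feeds to `tokenizer_image_token` comes from the conversation-template round trip above: append a user turn that embeds the image token, append an empty assistant turn, then render. A minimal sketch using one of the modes the CLI can infer:

```python
from llava.constants import DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates

conv = conv_templates["llava_v1"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe the image.")
conv.append_message(conv.roles[1], None)  # empty slot the model's reply will fill
prompt = conv.get_prompt()
print(prompt)
```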
/llava/serve/controller.py:
--------------------------------------------------------------------------------
1 | """
2 | A controller manages distributed workers.
3 | It sends worker addresses to clients.
4 | """
5 | import argparse
6 | import asyncio
7 | import dataclasses
8 | from enum import Enum, auto
9 | import json
10 | import logging
11 | import time
12 | from typing import List, Union
13 | import threading
14 |
15 | from fastapi import FastAPI, Request
16 | from fastapi.responses import StreamingResponse
17 | import numpy as np
18 | import requests
19 | import uvicorn
20 |
21 | from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
22 | from llava.utils import build_logger, server_error_msg
23 |
24 |
25 | logger = build_logger("controller", "controller.log")
26 |
27 |
28 | class DispatchMethod(Enum):
29 | LOTTERY = auto()
30 | SHORTEST_QUEUE = auto()
31 |
32 | @classmethod
33 | def from_str(cls, name):
34 | if name == "lottery":
35 | return cls.LOTTERY
36 | elif name == "shortest_queue":
37 | return cls.SHORTEST_QUEUE
38 | else:
39 | raise ValueError(f"Invalid dispatch method: {name}")
40 |
41 |
42 | @dataclasses.dataclass
43 | class WorkerInfo:
44 | model_names: List[str]
45 | speed: int
46 | queue_length: int
47 | check_heart_beat: bool
48 | last_heart_beat: float
49 |
50 |
51 | def heart_beat_controller(controller):
52 | while True:
53 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
54 | controller.remove_stale_workers_by_expiration()
55 |
56 |
57 | class Controller:
58 | def __init__(self, dispatch_method: str):
59 | # Dict[str -> WorkerInfo]
60 | self.worker_info = {}
61 | self.dispatch_method = DispatchMethod.from_str(dispatch_method)
62 |
63 | self.heart_beat_thread = threading.Thread(
64 | target=heart_beat_controller, args=(self,), daemon=True)
65 | self.heart_beat_thread.start()
66 |
67 | logger.info("Init controller")
68 |
69 | def register_worker(self, worker_name: str, check_heart_beat: bool,
70 | worker_status: dict):
71 | if worker_name not in self.worker_info:
72 | logger.info(f"Register a new worker: {worker_name}")
73 | else:
74 | logger.info(f"Register an existing worker: {worker_name}")
75 |
76 | if not worker_status:
77 | worker_status = self.get_worker_status(worker_name)
78 | if not worker_status:
79 | return False
80 |
81 | self.worker_info[worker_name] = WorkerInfo(
82 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
83 | check_heart_beat, time.time())
84 |
85 | logger.info(f"Register done: {worker_name}, {worker_status}")
86 | return True
87 |
88 | def get_worker_status(self, worker_name: str):
89 | try:
90 | r = requests.post(worker_name + "/worker_get_status", timeout=5)
91 | except requests.exceptions.RequestException as e:
92 | logger.error(f"Get status fails: {worker_name}, {e}")
93 | return None
94 |
95 | if r.status_code != 200:
96 | logger.error(f"Get status fails: {worker_name}, {r}")
97 | return None
98 |
99 | return r.json()
100 |
101 | def remove_worker(self, worker_name: str):
102 | del self.worker_info[worker_name]
103 |
104 | def refresh_all_workers(self):
105 | old_info = dict(self.worker_info)
106 | self.worker_info = {}
107 |
108 | for w_name, w_info in old_info.items():
109 | if not self.register_worker(w_name, w_info.check_heart_beat, None):
110 | logger.info(f"Remove stale worker: {w_name}")
111 |
112 | def list_models(self):
113 | model_names = set()
114 |
115 | for w_name, w_info in self.worker_info.items():
116 | model_names.update(w_info.model_names)
117 |
118 | return list(model_names)
119 |
120 | def get_worker_address(self, model_name: str):
121 | if self.dispatch_method == DispatchMethod.LOTTERY:
122 | worker_names = []
123 | worker_speeds = []
124 | for w_name, w_info in self.worker_info.items():
125 | if model_name in w_info.model_names:
126 | worker_names.append(w_name)
127 | worker_speeds.append(w_info.speed)
128 | worker_speeds = np.array(worker_speeds, dtype=np.float32)
129 | norm = np.sum(worker_speeds)
130 | if norm < 1e-4:
131 | return ""
132 | worker_speeds = worker_speeds / norm
133 | if True: # Directly return address
134 | pt = np.random.choice(np.arange(len(worker_names)),
135 | p=worker_speeds)
136 | worker_name = worker_names[pt]
137 | return worker_name
138 |
139 | # Check status before returning
140 | while True:
141 | pt = np.random.choice(np.arange(len(worker_names)),
142 | p=worker_speeds)
143 | worker_name = worker_names[pt]
144 |
145 | if self.get_worker_status(worker_name):
146 | break
147 | else:
148 | self.remove_worker(worker_name)
149 | worker_speeds[pt] = 0
150 | norm = np.sum(worker_speeds)
151 | if norm < 1e-4:
152 | return ""
153 | worker_speeds = worker_speeds / norm
154 | continue
155 | return worker_name
156 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
157 | worker_names = []
158 | worker_qlen = []
159 | for w_name, w_info in self.worker_info.items():
160 | if model_name in w_info.model_names:
161 | worker_names.append(w_name)
162 | worker_qlen.append(w_info.queue_length / w_info.speed)
163 | if len(worker_names) == 0:
164 | return ""
165 | min_index = np.argmin(worker_qlen)
166 | w_name = worker_names[min_index]
167 | self.worker_info[w_name].queue_length += 1
168 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
169 | return w_name
170 | else:
171 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
172 |
173 | def receive_heart_beat(self, worker_name: str, queue_length: int):
174 | if worker_name not in self.worker_info:
175 | logger.info(f"Receive unknown heart beat. {worker_name}")
176 | return False
177 |
178 | self.worker_info[worker_name].queue_length = queue_length
179 | self.worker_info[worker_name].last_heart_beat = time.time()
180 | logger.info(f"Receive heart beat. {worker_name}")
181 | return True
182 |
183 | def remove_stale_workers_by_expiration(self):
184 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
185 | to_delete = []
186 | for worker_name, w_info in self.worker_info.items():
187 | if w_info.check_heart_beat and w_info.last_heart_beat < expire:
188 | to_delete.append(worker_name)
189 |
190 | for worker_name in to_delete:
191 | self.remove_worker(worker_name)
192 |
193 | def worker_api_generate_stream(self, params):
194 | worker_addr = self.get_worker_address(params["model"])
195 | if not worker_addr:
196 | logger.info(f"no worker: {params['model']}")
197 | ret = {
198 | "text": server_error_msg,
199 | "error_code": 2,
200 | }
201 | yield json.dumps(ret).encode() + b"\0"
202 | return
203 | try:
204 | response = requests.post(worker_addr + "/worker_generate_stream",
205 | json=params, stream=True, timeout=5)
206 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
207 | if chunk:
208 | yield chunk + b"\0"
209 | except requests.exceptions.RequestException as e:
210 | logger.info(f"worker timeout: {worker_addr}")
211 | ret = {
212 | "text": server_error_msg,
213 | "error_code": 3,
214 | }
215 | yield json.dumps(ret).encode() + b"\0"
216 |
217 | # Let the controller act as a worker to achieve hierarchical
218 | # management. This can be used to connect isolated sub networks.
219 |
220 | def worker_api_get_status(self):
221 | model_names = set()
222 | speed = 0
223 | queue_length = 0
224 |
225 | for w_name in self.worker_info:
226 | worker_status = self.get_worker_status(w_name)
227 | if worker_status is not None:
228 | model_names.update(worker_status["model_names"])
229 | speed += worker_status["speed"]
230 | queue_length += worker_status["queue_length"]
231 |
232 | return {
233 | "model_names": list(model_names),
234 | "speed": speed,
235 | "queue_length": queue_length,
236 | }
237 |
238 |
239 | app = FastAPI()
240 |
241 |
242 | @app.post("/register_worker")
243 | async def register_worker(request: Request):
244 | data = await request.json()
245 | controller.register_worker(
246 | data["worker_name"], data["check_heart_beat"],
247 | data.get("worker_status", None))
248 |
249 |
250 | @app.post("/refresh_all_workers")
251 | async def refresh_all_workers():
252 | controller.refresh_all_workers()
253 |
254 |
255 | @app.post("/list_models")
256 | async def list_models():
257 | models = controller.list_models()
258 | return {"models": models}
259 |
260 |
261 | @app.post("/get_worker_address")
262 | async def get_worker_address(request: Request):
263 | data = await request.json()
264 | addr = controller.get_worker_address(data["model"])
265 | return {"address": addr}
266 |
267 |
268 | @app.post("/receive_heart_beat")
269 | async def receive_heart_beat(request: Request):
270 | data = await request.json()
271 | exist = controller.receive_heart_beat(
272 | data["worker_name"], data["queue_length"])
273 | return {"exist": exist}
274 |
275 |
276 | @app.post("/worker_generate_stream")
277 | async def worker_api_generate_stream(request: Request):
278 | params = await request.json()
279 | generator = controller.worker_api_generate_stream(params)
280 | return StreamingResponse(generator)
281 |
282 |
283 | @app.post("/worker_get_status")
284 | async def worker_api_get_status(request: Request):
285 | return controller.worker_api_get_status()
286 |
287 |
288 | if __name__ == "__main__":
289 | parser = argparse.ArgumentParser()
290 | parser.add_argument("--host", type=str, default="localhost")
291 | parser.add_argument("--port", type=int, default=21001)
292 | parser.add_argument("--dispatch-method", type=str, choices=[
293 | "lottery", "shortest_queue"], default="shortest_queue")
294 | args = parser.parse_args()
295 | logger.info(f"args: {args}")
296 |
297 | controller = Controller(args.dispatch_method)
298 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
299 |
--------------------------------------------------------------------------------
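A client-side sketch against the HTTP endpoints this controller exposes (host and port match the argparse defaults above; it assumes a controller and at least one worker are already running):

```python
import requests

base = "http://localhost:21001"

# Ask the controller which models its registered workers serve ...
models = requests.post(base + "/list_models").json()["models"]

# ... then resolve a worker address for one of them.
if models:
    addr = requests.post(base + "/get_worker_address", json={"model": models[0]}).json()["address"]
    print(models[0], "->", addr or "<no worker available>")
```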
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-fastvlm/592b4add3c1c8a518e77d95dc6248e76c1dd591f/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/llava/serve/model_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | A model worker executes the model.
3 | """
4 | import argparse
5 | import asyncio
6 | import json
7 | import time
8 | import threading
9 | import uuid
10 |
11 | from fastapi import FastAPI, Request, BackgroundTasks
12 | from fastapi.responses import StreamingResponse
13 | import requests
14 | import torch
15 | import uvicorn
16 | from functools import partial
17 |
18 | from llava.constants import WORKER_HEART_BEAT_INTERVAL
19 | from llava.utils import (build_logger, server_error_msg,
20 | pretty_print_semaphore)
21 | from llava.model.builder import load_pretrained_model
22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
23 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24 | from transformers import TextIteratorStreamer
25 | from threading import Thread
26 |
27 |
28 | GB = 1 << 30
29 |
30 | worker_id = str(uuid.uuid4())[:6]
31 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
32 | global_counter = 0
33 |
34 | model_semaphore = None
35 |
36 |
37 | def heart_beat_worker(controller):
38 |
39 | while True:
40 | time.sleep(WORKER_HEART_BEAT_INTERVAL)
41 | controller.send_heart_beat()
42 |
43 |
44 | class ModelWorker:
45 | def __init__(self, controller_addr, worker_addr,
46 | worker_id, no_register,
47 | model_path, model_base, model_name,
48 | load_8bit, load_4bit, device, use_flash_attn=False):
49 | self.controller_addr = controller_addr
50 | self.worker_addr = worker_addr
51 | self.worker_id = worker_id
52 | if model_path.endswith("/"):
53 | model_path = model_path[:-1]
54 | if model_name is None:
55 | model_paths = model_path.split("/")
56 | if model_paths[-1].startswith('checkpoint-'):
57 | self.model_name = model_paths[-2] + "_" + model_paths[-1]
58 | else:
59 | self.model_name = model_paths[-1]
60 | else:
61 | self.model_name = model_name
62 |
63 | self.device = device
64 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
65 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
66 | model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device, use_flash_attn=use_flash_attn)
67 | self.is_multimodal = 'llava' in self.model_name.lower()
68 |
69 | if not no_register:
70 | self.register_to_controller()
71 | self.heart_beat_thread = threading.Thread(
72 | target=heart_beat_worker, args=(self,), daemon=True)
73 | self.heart_beat_thread.start()
74 |
75 | def register_to_controller(self):
76 | logger.info("Register to controller")
77 |
78 | url = self.controller_addr + "/register_worker"
79 | data = {
80 | "worker_name": self.worker_addr,
81 | "check_heart_beat": True,
82 | "worker_status": self.get_status()
83 | }
84 | r = requests.post(url, json=data)
85 | assert r.status_code == 200
86 |
87 | def send_heart_beat(self):
88 | logger.info(f"Send heart beat. Models: {[self.model_name]}. "
89 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
90 | f"global_counter: {global_counter}")
91 |
92 | url = self.controller_addr + "/receive_heart_beat"
93 |
94 | while True:
95 | try:
96 | ret = requests.post(url, json={
97 | "worker_name": self.worker_addr,
98 | "queue_length": self.get_queue_length()}, timeout=5)
99 | exist = ret.json()["exist"]
100 | break
101 | except requests.exceptions.RequestException as e:
102 | logger.error(f"heart beat error: {e}")
103 | time.sleep(5)
104 |
105 | if not exist:
106 | self.register_to_controller()
107 |
108 | def get_queue_length(self):
109 | if model_semaphore is None:
110 | return 0
111 | else:
112 | return args.limit_model_concurrency - model_semaphore._value + (len(
113 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
114 |
115 | def get_status(self):
116 | return {
117 | "model_names": [self.model_name],
118 | "speed": 1,
119 | "queue_length": self.get_queue_length(),
120 | }
121 |
122 | @torch.inference_mode()
123 | def generate_stream(self, params):
124 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
125 |
126 | prompt = params["prompt"]
127 | ori_prompt = prompt
128 | images = params.get("images", None)
129 | num_image_tokens = 0
130 | if images is not None and len(images) > 0 and self.is_multimodal:
131 | if len(images) > 0:
132 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
133 | raise ValueError("Number of images does not match number of tokens in prompt")
134 |
135 | images = [load_image_from_base64(image) for image in images]
136 | image_sizes = [image.size for image in images]
137 | images = process_images(images, image_processor, model.config)
138 |
139 | if type(images) is list:
140 | images = [image.to(self.model.device, dtype=torch.float16) for image in images]
141 | else:
142 | images = images.to(self.model.device, dtype=torch.float16)
143 |
144 | replace_token = DEFAULT_IMAGE_TOKEN
145 | if getattr(self.model.config, 'mm_use_im_start_end', False):
146 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
147 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
148 |
149 | num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
150 | else:
151 | images = None
152 | image_sizes = None
153 | image_args = {"images": images, "image_sizes": image_sizes}
154 | else:
155 | images = None
156 | image_args = {}
157 |
158 | temperature = float(params.get("temperature", 1.0))
159 | top_p = float(params.get("top_p", 1.0))
160 | max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
161 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
162 | stop_str = params.get("stop", None)
163 | do_sample = True if temperature > 0.001 else False
164 |
165 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
166 | keywords = [stop_str]
167 | # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
168 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
169 |
170 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
171 |
172 | if max_new_tokens < 1:
173 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
174 | return
175 |
176 | thread = Thread(target=model.generate, kwargs=dict(
177 | inputs=input_ids,
178 | do_sample=do_sample,
179 | temperature=temperature,
180 | top_p=top_p,
181 | max_new_tokens=max_new_tokens,
182 | streamer=streamer,
183 | use_cache=True,
184 | **image_args
185 | ))
186 | thread.start()
187 |
188 | generated_text = ori_prompt
189 | for new_text in streamer:
190 | generated_text += new_text
191 | if generated_text.endswith(stop_str):
192 | generated_text = generated_text[:-len(stop_str)]
193 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
194 |
195 | def generate_stream_gate(self, params):
196 | try:
197 | for x in self.generate_stream(params):
198 | yield x
199 | except ValueError as e:
200 | print("Caught ValueError:", e)
201 | ret = {
202 | "text": server_error_msg,
203 | "error_code": 1,
204 | }
205 | yield json.dumps(ret).encode() + b"\0"
206 | except torch.cuda.CudaError as e:
207 | print("Caught torch.cuda.CudaError:", e)
208 | ret = {
209 | "text": server_error_msg,
210 | "error_code": 1,
211 | }
212 | yield json.dumps(ret).encode() + b"\0"
213 | except Exception as e:
214 | print("Caught Unknown Error", e)
215 | ret = {
216 | "text": server_error_msg,
217 | "error_code": 1,
218 | }
219 | yield json.dumps(ret).encode() + b"\0"
220 |
221 |
222 | app = FastAPI()
223 |
224 |
225 | def release_model_semaphore(fn=None):
226 | model_semaphore.release()
227 | if fn is not None:
228 | fn()
229 |
230 |
231 | @app.post("/worker_generate_stream")
232 | async def generate_stream(request: Request):
233 | global model_semaphore, global_counter
234 | global_counter += 1
235 | params = await request.json()
236 |
237 | if model_semaphore is None:
238 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
239 | await model_semaphore.acquire()
240 | worker.send_heart_beat()
241 | generator = worker.generate_stream_gate(params)
242 | background_tasks = BackgroundTasks()
243 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
244 | return StreamingResponse(generator, background=background_tasks)
245 |
246 |
247 | @app.post("/worker_get_status")
248 | async def get_status(request: Request):
249 | return worker.get_status()
250 |
251 |
252 | if __name__ == "__main__":
253 | parser = argparse.ArgumentParser()
254 | parser.add_argument("--host", type=str, default="localhost")
255 | parser.add_argument("--port", type=int, default=21002)
256 | parser.add_argument("--worker-address", type=str,
257 | default="http://localhost:21002")
258 | parser.add_argument("--controller-address", type=str,
259 | default="http://localhost:21001")
260 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
261 | parser.add_argument("--model-base", type=str, default=None)
262 | parser.add_argument("--model-name", type=str)
263 | parser.add_argument("--device", type=str, default="cuda")
264 | parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
265 | parser.add_argument("--limit-model-concurrency", type=int, default=5)
266 | parser.add_argument("--stream-interval", type=int, default=1)
267 | parser.add_argument("--no-register", action="store_true")
268 | parser.add_argument("--load-8bit", action="store_true")
269 | parser.add_argument("--load-4bit", action="store_true")
270 | parser.add_argument("--use-flash-attn", action="store_true")
271 | args = parser.parse_args()
272 | logger.info(f"args: {args}")
273 |
274 | if args.multi_modal:
275 | logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
276 |
277 | worker = ModelWorker(args.controller_address,
278 | args.worker_address,
279 | worker_id,
280 | args.no_register,
281 | args.model_path,
282 | args.model_base,
283 | args.model_name,
284 | args.load_8bit,
285 | args.load_4bit,
286 | args.device,
287 | use_flash_attn=args.use_flash_attn)
288 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
289 |
--------------------------------------------------------------------------------
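The worker streams `b"\0"`-delimited JSON chunks, each carrying the full text generated so far plus an `error_code`. A rough client sketch against a worker on the default port (text-only for brevity; a multimodal request would also pass base64-encoded `images` and embed the image token, and the prompt would normally come from a conversation template as in `cli.py`):

```python
import json

import requests

payload = {
    "prompt": "USER: Hello ASSISTANT:",  # placeholder prompt format
    "max_new_tokens": 64,
    "temperature": 0.2,
    "stop": "</s>",
}
resp = requests.post("http://localhost:21002/worker_generate_stream", json=payload, stream=True)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if chunk:
        data = json.loads(chunk.decode("utf-8"))
        print(data["text"], end="\r")  # each chunk repeats the text so far
print()
```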
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/llava/serve/sglang_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | A model worker executes the model.
3 | """
4 | import argparse
5 | import asyncio
6 | from concurrent.futures import ThreadPoolExecutor
7 | import json
8 | import time
9 | import threading
10 | import uuid
11 |
12 | from fastapi import FastAPI, Request, BackgroundTasks
13 | from fastapi.responses import StreamingResponse
14 | import requests
15 | import re
16 | import uvicorn
17 | from functools import partial
18 |
19 | from llava.constants import WORKER_HEART_BEAT_INTERVAL
20 | from llava.utils import (build_logger, server_error_msg,
21 | pretty_print_semaphore)
22 | from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, expand2square
23 | from llava.constants import DEFAULT_IMAGE_TOKEN
24 |
25 | import sglang as sgl
26 | from sglang.backend.runtime_endpoint import RuntimeEndpoint
27 |
28 |
29 | GB = 1 << 30
30 |
31 | worker_id = str(uuid.uuid4())[:6]
32 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
33 | global_counter = 0
34 |
35 | model_semaphore = None
36 |
37 |
38 | def heart_beat_worker(controller):
39 | while True:
40 | time.sleep(WORKER_HEART_BEAT_INTERVAL)
41 | controller.send_heart_beat()
42 |
43 |
44 | @sgl.function
45 | def pipeline(s, prompt, max_tokens):
46 | for p in prompt:
47 | if type(p) is str:
48 | s += p
49 | else:
50 | s += sgl.image(p)
51 | s += sgl.gen("response", max_tokens=max_tokens)
52 |
53 |
54 | class ModelWorker:
55 | def __init__(self, controller_addr, worker_addr, sgl_endpoint,
56 | worker_id, no_register, model_name):
57 | self.controller_addr = controller_addr
58 | self.worker_addr = worker_addr
59 | self.worker_id = worker_id
60 |
61 | # Select backend
62 | backend = RuntimeEndpoint(sgl_endpoint)
63 | sgl.set_default_backend(backend)
64 | model_path = backend.model_info["model_path"]
65 |
66 | if model_path.endswith("/"):
67 | model_path = model_path[:-1]
68 | if model_name is None:
69 | model_paths = model_path.split("/")
70 | if model_paths[-1].startswith('checkpoint-'):
71 | self.model_name = model_paths[-2] + "_" + model_paths[-1]
72 | else:
73 | self.model_name = model_paths[-1]
74 | else:
75 | self.model_name = model_name
76 |
77 | logger.info(f"Loading the SGLANG model {self.model_name} on worker {worker_id} ...")
78 |
79 | if not no_register:
80 | self.register_to_controller()
81 | self.heart_beat_thread = threading.Thread(
82 | target=heart_beat_worker, args=(self,), daemon=True)
83 | self.heart_beat_thread.start()
84 |
85 | def register_to_controller(self):
86 | logger.info("Register to controller")
87 |
88 | url = self.controller_addr + "/register_worker"
89 | data = {
90 | "worker_name": self.worker_addr,
91 | "check_heart_beat": True,
92 | "worker_status": self.get_status()
93 | }
94 | r = requests.post(url, json=data)
95 | assert r.status_code == 200
96 |
97 | def send_heart_beat(self):
98 | logger.info(f"Send heart beat. Models: {[self.model_name]}. "
99 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
100 | f"global_counter: {global_counter}")
101 |
102 | url = self.controller_addr + "/receive_heart_beat"
103 |
104 | while True:
105 | try:
106 | ret = requests.post(url, json={
107 | "worker_name": self.worker_addr,
108 | "queue_length": self.get_queue_length()}, timeout=5)
109 | exist = ret.json()["exist"]
110 | break
111 | except requests.exceptions.RequestException as e:
112 | logger.error(f"heart beat error: {e}")
113 | time.sleep(5)
114 |
115 | if not exist:
116 | self.register_to_controller()
117 |
118 | def get_queue_length(self):
119 | if model_semaphore is None:
120 | return 0
121 | else:
122 | return args.limit_model_concurrency - model_semaphore._value + (len(
123 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
124 |
125 | def get_status(self):
126 | return {
127 | "model_names": [self.model_name],
128 | "speed": 1,
129 | "queue_length": self.get_queue_length(),
130 | }
131 |
132 | async def generate_stream(self, params):
133 | ori_prompt = prompt = params["prompt"]
134 | images = params.get("images", None)
135 | if images is not None and len(images) > 0:
136 | if len(images) > 0:
137 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
138 | raise ValueError("Number of images does not match number of tokens in prompt")
139 |
140 | images = [load_image_from_base64(image) for image in images]
141 |
142 | # FIXME: for image-start/end token
143 | # replace_token = DEFAULT_IMAGE_TOKEN
144 | # if getattr(self.model.config, 'mm_use_im_start_end', False):
145 | # replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
146 | # prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
147 | prompt = prompt.replace(' ' + DEFAULT_IMAGE_TOKEN + '\n', DEFAULT_IMAGE_TOKEN)
148 | prompt_split = prompt.split(DEFAULT_IMAGE_TOKEN)
149 | prompt = []
150 | for i in range(len(prompt_split)):
151 | prompt.append(prompt_split[i])
152 | if i < len(images):
153 | prompt.append(images[i])
154 | else:
155 | prompt = [prompt]
156 |
157 | temperature = float(params.get("temperature", 1.0))
158 | top_p = float(params.get("top_p", 1.0))
159 | # max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
160 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
161 | stop_str = params.get("stop", None)
162 | stop_str = [stop_str] if stop_str is not None else None
163 |
164 | print({'prompt': prompt, 'max_new_tokens': max_new_tokens, 'temperature': temperature, 'top_p': top_p})
165 | state = pipeline.run(prompt, max_new_tokens, temperature=temperature, top_p=top_p, stream=True)
166 |
167 | generated_text = ori_prompt
168 | async for text_outputs in state.text_async_iter(var_name="response"):
169 | generated_text += text_outputs
170 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"
171 |
172 | async def generate_stream_gate(self, params):
173 | try:
174 | async for x in self.generate_stream(params):
175 | yield x
176 | except ValueError as e:
177 | print("Caught ValueError:", e)
178 | ret = {
179 | "text": server_error_msg,
180 | "error_code": 1,
181 | }
182 | yield json.dumps(ret).encode() + b"\0"
183 | except Exception as e:
184 | print("Caught Unknown Error", e)
185 | ret = {
186 | "text": server_error_msg,
187 | "error_code": 1,
188 | }
189 | yield json.dumps(ret).encode() + b"\0"
190 |
191 |
192 | app = FastAPI()
193 |
194 |
195 | def release_model_semaphore(fn=None):
196 | model_semaphore.release()
197 | if fn is not None:
198 | fn()
199 |
200 |
201 | @app.post("/worker_generate_stream")
202 | async def generate_stream(request: Request):
203 | global model_semaphore, global_counter
204 | global_counter += 1
205 | params = await request.json()
206 |
207 | if model_semaphore is None:
208 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
209 | await model_semaphore.acquire()
210 | worker.send_heart_beat()
211 | generator = worker.generate_stream_gate(params)
212 | background_tasks = BackgroundTasks()
213 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
214 | return StreamingResponse(generator, background=background_tasks)
215 |
216 |
217 | @app.post("/worker_get_status")
218 | async def get_status(request: Request):
219 | return worker.get_status()
220 |
221 |
222 | if __name__ == "__main__":
223 | parser = argparse.ArgumentParser()
224 | parser.add_argument("--host", type=str, default="localhost")
225 | parser.add_argument("--port", type=int, default=21002)
226 | parser.add_argument("--worker-address", type=str,
227 | default="http://localhost:21002")
228 | parser.add_argument("--controller-address", type=str,
229 | default="http://localhost:21001")
230 | parser.add_argument("--model-name", type=str)
231 | parser.add_argument("--sgl-endpoint", type=str)
232 | parser.add_argument("--limit-model-concurrency", type=int, default=5)
233 | parser.add_argument("--stream-interval", type=int, default=1)
234 | parser.add_argument("--no-register", action="store_true")
235 | args = parser.parse_args()
236 | logger.info(f"args: {args}")
237 |
238 | worker = ModelWorker(args.controller_address,
239 | args.worker_address,
240 | args.sgl_endpoint,
241 | worker_id,
242 | args.no_register,
243 | args.model_name)
244 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
245 |
--------------------------------------------------------------------------------
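A synchronous sketch of the same `@sgl.function` pipeline this worker defines, run against an SGLang runtime endpoint (the endpoint URL is a placeholder, and indexing the finished state by variable name follows sglang's frontend examples; the worker above uses the streaming/async variant instead):

```python
import sglang as sgl
from sglang.backend.runtime_endpoint import RuntimeEndpoint


@sgl.function
def pipeline(s, prompt, max_tokens):
    # Same structure as the worker's pipeline: interleave text and images, then generate.
    for p in prompt:
        s += p if isinstance(p, str) else sgl.image(p)
    s += sgl.gen("response", max_tokens=max_tokens)


sgl.set_default_backend(RuntimeEndpoint("http://localhost:30000"))  # placeholder endpoint

state = pipeline.run(["USER: Hello ASSISTANT:"], 64, temperature=0.2, top_p=1.0)
print(state["response"])
```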
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default="Tell me a story with more than 1000 words.")
59 | args = parser.parse_args()
60 |
61 | main()
62 |
--------------------------------------------------------------------------------
/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
97 | # requires the attention mask to be the same as the key_padding_mask
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 |             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
110 |             "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 |
--------------------------------------------------------------------------------
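Note: a minimal usage sketch for the patch above (not part of the repo), assuming a CUDA machine with `flash-attn` installed, a transformers version whose `LlamaAttention` still matches the patched signature, and a hypothetical checkpoint path. The patch rebinds `LlamaAttention.forward` on the class, so it only needs to run before the first forward pass.

```python
import torch
from transformers import LlamaForCausalLM

from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Apply the monkey patch, then build the model as usual.
replace_llama_attn_with_flash_attn()

model = LlamaForCausalLM.from_pretrained(
    "/path/to/llama-checkpoint",   # hypothetical path
    torch_dtype=torch.float16,     # flash-attn kernels expect fp16/bf16 inputs
).cuda()
```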
/llava/train/llama_xformers_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
3 | """
4 |
5 | import logging
6 | import math
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | import transformers.models.llama.modeling_llama
11 | from torch import nn
12 |
13 | try:
14 | import xformers.ops
15 | except ImportError:
16 | logging.error("xformers not found! Please install it before trying to use it.")
17 |
18 |
19 | def replace_llama_attn_with_xformers_attn():
20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21 |
22 |
23 | def xformers_forward(
24 | self,
25 | hidden_states: torch.Tensor,
26 | attention_mask: Optional[torch.Tensor] = None,
27 | position_ids: Optional[torch.LongTensor] = None,
28 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
29 | output_attentions: bool = False,
30 | use_cache: bool = False,
31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32 | # pylint: disable=duplicate-code
33 | bsz, q_len, _ = hidden_states.size()
34 |
35 | query_states = (
36 | self.q_proj(hidden_states)
37 | .view(bsz, q_len, self.num_heads, self.head_dim)
38 | .transpose(1, 2)
39 | )
40 | key_states = (
41 | self.k_proj(hidden_states)
42 | .view(bsz, q_len, self.num_heads, self.head_dim)
43 | .transpose(1, 2)
44 | )
45 | value_states = (
46 | self.v_proj(hidden_states)
47 | .view(bsz, q_len, self.num_heads, self.head_dim)
48 | .transpose(1, 2)
49 | )
50 |
51 | kv_seq_len = key_states.shape[-2]
52 | if past_key_value is not None:
53 | kv_seq_len += past_key_value[0].shape[-2]
54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55 | (
56 | query_states,
57 | key_states,
58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59 | query_states, key_states, cos, sin, position_ids
60 | )
61 | # [bsz, nh, t, hd]
62 |
63 | if past_key_value is not None:
64 | # reuse k, v, self_attention
65 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
66 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
67 |
68 | past_key_value = (key_states, value_states) if use_cache else None
69 |
70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix
71 | if not output_attentions:
72 | query_states = query_states.transpose(1, 2)
73 | key_states = key_states.transpose(1, 2)
74 | value_states = value_states.transpose(1, 2)
75 |
76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
80 | attn_output = xformers.ops.memory_efficient_attention(
81 | query_states, key_states, value_states, attn_bias=None
82 | )
83 | else:
84 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
85 | attn_output = xformers.ops.memory_efficient_attention(
86 | query_states,
87 | key_states,
88 | value_states,
89 | attn_bias=xformers.ops.LowerTriangularMask(),
90 | )
91 | attn_weights = None
92 | else:
93 | attn_weights = torch.matmul(
94 | query_states, key_states.transpose(2, 3)
95 | ) / math.sqrt(self.head_dim)
96 |
97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98 | raise ValueError(
99 |                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
100 | f" {attn_weights.size()}"
101 | )
102 |
103 | if attention_mask is not None:
104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105 | raise ValueError(
106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107 | )
108 | attn_weights = attn_weights + attention_mask
109 | attn_weights = torch.max(
110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111 | )
112 |
113 | # upcast attention to fp32
114 | attn_weights = nn.functional.softmax(
115 | attn_weights, dim=-1, dtype=torch.float32
116 | ).to(query_states.dtype)
117 | attn_output = torch.matmul(attn_weights, value_states)
118 |
119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120 | raise ValueError(
121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122 | f" {attn_output.size()}"
123 | )
124 |
125 | attn_output = attn_output.transpose(1, 2)
126 |
127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128 | attn_output = self.o_proj(attn_output)
129 | return attn_output, attn_weights, past_key_value
130 |
--------------------------------------------------------------------------------
/llava/train/llava_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torch.utils.data import Sampler
6 |
7 | import transformers
8 | from transformers import Trainer
9 | from transformers.trainer import (
10 | is_sagemaker_mp_enabled,
11 | get_parameter_names,
12 | has_length,
13 | # ALL_LAYERNORM_LAYERS,
14 | logger,
15 | )
16 | from typing import List, Optional
17 |
18 |
19 | ALL_LAYERNORM_LAYERS = [nn.LayerNorm, nn.BatchNorm2d]
20 |
21 |
22 | def maybe_zero_3(param, ignore_status=False, name=None):
23 | from deepspeed import zero
24 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
25 | if hasattr(param, "ds_id"):
26 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
27 | if not ignore_status:
28 | print(name, 'no ignore status')
29 | with zero.GatheredParameters([param]):
30 | param = param.data.detach().cpu().clone()
31 | else:
32 | param = param.detach().cpu().clone()
33 | return param
34 |
35 |
36 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
37 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
38 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
39 | return to_return
40 |
41 |
42 | def split_to_even_chunks(indices, lengths, num_chunks):
43 | """
44 |     Split a list of indices into `num_chunks` chunks of roughly equal total length.
45 | """
46 |
47 | if len(indices) % num_chunks != 0:
48 | return [indices[i::num_chunks] for i in range(num_chunks)]
49 |
50 | num_indices_per_chunk = len(indices) // num_chunks
51 |
52 | chunks = [[] for _ in range(num_chunks)]
53 | chunks_lengths = [0 for _ in range(num_chunks)]
54 | for index in indices:
55 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
56 | chunks[shortest_chunk].append(index)
57 | chunks_lengths[shortest_chunk] += lengths[index]
58 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
59 | chunks_lengths[shortest_chunk] = float("inf")
60 |
61 | return chunks
62 |
63 |
64 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
65 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
66 | assert all(l != 0 for l in lengths), "Should not have zero length."
67 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
68 | # all samples are in the same modality
69 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
70 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
71 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
72 |
73 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
74 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
75 | megabatch_size = world_size * batch_size
76 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
77 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
78 |
79 | last_mm = mm_megabatches[-1]
80 | last_lang = lang_megabatches[-1]
81 | additional_batch = last_mm + last_lang
82 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
83 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
84 | megabatches = [megabatches[i] for i in megabatch_indices]
85 |
86 | if len(additional_batch) > 0:
87 | megabatches.append(sorted(additional_batch))
88 |
89 | return [i for megabatch in megabatches for i in megabatch]
90 |
91 |
92 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
93 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
94 | indices = torch.randperm(len(lengths), generator=generator)
95 | megabatch_size = world_size * batch_size
96 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
97 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
98 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
99 |
100 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
101 |
102 |
103 | class LengthGroupedSampler(Sampler):
104 | r"""
105 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
106 | keeping a bit of randomness.
107 | """
108 |
109 | def __init__(
110 | self,
111 | batch_size: int,
112 | world_size: int,
113 | lengths: Optional[List[int]] = None,
114 | generator=None,
115 | group_by_modality: bool = False,
116 | ):
117 | if lengths is None:
118 | raise ValueError("Lengths must be provided.")
119 |
120 | self.batch_size = batch_size
121 | self.world_size = world_size
122 | self.lengths = lengths
123 | self.generator = generator
124 | self.group_by_modality = group_by_modality
125 |
126 | def __len__(self):
127 | return len(self.lengths)
128 |
129 | def __iter__(self):
130 | if self.group_by_modality:
131 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
132 | else:
133 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
134 | return iter(indices)
135 |
136 |
137 | class LLaVATrainer(Trainer):
138 |
139 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
140 | if self.train_dataset is None or not has_length(self.train_dataset):
141 | return None
142 |
143 | if self.args.group_by_modality_length:
144 | lengths = self.train_dataset.modality_lengths
145 | return LengthGroupedSampler(
146 | self.args.train_batch_size,
147 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
148 | lengths=lengths,
149 | group_by_modality=True,
150 | )
151 | else:
152 | return super()._get_train_sampler()
153 |
154 | def create_optimizer(self):
155 | """
156 | Setup the optimizer.
157 |
158 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
159 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
160 | """
161 | if is_sagemaker_mp_enabled():
162 | return super().create_optimizer()
163 |
164 | opt_model = self.model
165 |
166 | if self.optimizer is None:
167 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
168 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
169 |
170 | lr_mapper = {}
171 | if self.args.mm_projector_lr is not None:
172 | lr_mapper["mm_projector"] = self.args.mm_projector_lr
173 | if self.args.mm_vision_tower_lr is not None:
174 | lr_mapper["vision_tower"] = self.args.mm_vision_tower_lr
175 |
176 | if len(lr_mapper) > 0:
177 | special_lr_parameters = [name for name, _ in opt_model.named_parameters() if
178 | any(module_keyword in name for module_keyword in lr_mapper)]
179 | optimizer_grouped_parameters = [
180 | {
181 | "params": [p for n, p in opt_model.named_parameters() if
182 | (n in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
183 | "weight_decay": self.args.weight_decay,
184 | },
185 | {
186 | "params": [p for n, p in opt_model.named_parameters() if
187 | (n not in decay_parameters and n not in special_lr_parameters and p.requires_grad)],
188 | "weight_decay": 0.0,
189 | },
190 | ]
191 | for module_keyword, lr in lr_mapper.items():
192 | module_parameters = [name for name, _ in opt_model.named_parameters() if module_keyword in name]
193 | optimizer_grouped_parameters.extend(
194 | [
195 | {
196 | "params": [p for n, p in opt_model.named_parameters() if
197 | (n in decay_parameters and n in module_parameters and p.requires_grad)],
198 | "weight_decay": self.args.weight_decay,
199 | "lr": lr,
200 | },
201 | {
202 | "params": [p for n, p in opt_model.named_parameters() if
203 | (n not in decay_parameters and n in module_parameters and p.requires_grad)],
204 | "weight_decay": 0.0,
205 | "lr": lr,
206 | },
207 | ]
208 | )
209 | else:
210 | optimizer_grouped_parameters = [
211 | {
212 | "params": [
213 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
214 | ],
215 | "weight_decay": self.args.weight_decay,
216 | },
217 | {
218 | "params": [
219 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
220 | ],
221 | "weight_decay": 0.0,
222 | },
223 | ]
224 |
225 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
226 |
227 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
228 | if optimizer_cls.__name__ == "Adam8bit":
229 | import bitsandbytes
230 |
231 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
232 |
233 | skipped = 0
234 | for module in opt_model.modules():
235 | if isinstance(module, nn.Embedding):
236 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
237 | logger.info(f"skipped {module}: {skipped/2**20}M params")
238 | manager.register_module_override(module, "weight", {"optim_bits": 32})
239 | logger.debug(f"bitsandbytes: will optimize {module} in fp32")
240 | logger.info(f"skipped: {skipped/2**20}M params")
241 |
242 | return self.optimizer
243 |
244 | def _save_checkpoint(self, model, trial, metrics=None):
245 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
246 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
247 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
248 |
249 | run_dir = self._get_output_dir(trial=trial)
250 | output_dir = os.path.join(run_dir, checkpoint_folder)
251 |
252 | # Only save Adapter
253 | keys_to_match = ['mm_projector', 'vision_resampler']
254 | if getattr(self.args, "use_im_start_end", False):
255 | keys_to_match.extend(['embed_tokens', 'embed_in'])
256 |
257 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
258 |
259 | if self.args.local_rank == 0 or self.args.local_rank == -1:
260 | self.model.config.save_pretrained(output_dir)
261 |                 torch.save(weight_to_save, os.path.join(output_dir, 'mm_projector.bin'))
262 | else:
263 | # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144
264 | model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None)
265 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics)
266 |
267 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
268 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
269 | pass
270 | else:
271 | # Workaround for the issue: https://github.com/haotian-liu/LLaVA/issues/1144
272 | self.model.generation_config = transformers.GenerationConfig(do_sample=True, temperature=None, top_p=None)
273 | super(LLaVATrainer, self)._save(output_dir, state_dict)
274 |
--------------------------------------------------------------------------------
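Note: a toy illustration (not part of the repo) of the length-grouping logic in `llava_trainer.py`, assuming the `llava` package is importable. Each megabatch of `world_size * batch_size` indices is sorted by length and split into per-rank chunks of roughly equal total length.

```python
import torch

from llava.train.llava_trainer import get_length_grouped_indices, split_to_even_chunks

lengths = [5, 3, 9, 2, 8, 7, 4, 6]   # per-sample token lengths for 8 fake samples
g = torch.Generator().manual_seed(0)

# Indices grouped so that each megabatch of 4 (= world_size * batch_size)
# contains samples of similar length, balanced across the 2 ranks.
indices = get_length_grouped_indices(lengths, batch_size=2, world_size=2, generator=g)
print(indices)

# The helper used internally: split one sorted megabatch into 2 balanced chunks.
print(split_to_even_chunks([2, 4, 0, 5], lengths, num_chunks=2))
```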
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | from llava.train.train_qwen import train
2 |
3 | if __name__ == "__main__":
4 | train(attn_implementation="flash_attention_2")
5 |
--------------------------------------------------------------------------------
/llava/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2 |
3 | # The patch rebinds LlamaAttention.forward on the transformers class, so it must be applied before training starts.
4 | from llava.train.train import train
5 | from llava.train.llama_xformers_attn_monkey_patch import (
6 | replace_llama_attn_with_xformers_attn,
7 | )
8 |
9 | replace_llama_attn_with_xformers_attn()
10 |
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 |
65 | def __init__(self, logger, log_level=logging.INFO):
66 | self.terminal = sys.stdout
67 | self.logger = logger
68 | self.log_level = log_level
69 | self.linebuf = ''
70 |
71 | def __getattr__(self, attr):
72 | return getattr(self.terminal, attr)
73 |
74 | def write(self, buf):
75 | temp_linebuf = self.linebuf + buf
76 | self.linebuf = ''
77 | for line in temp_linebuf.splitlines(True):
78 | # From the io.TextIOWrapper docs:
79 | # On output, if newline is None, any '\n' characters written
80 | # are translated to the system default line separator.
81 | # By default sys.stdout.write() expects '\n' newlines and then
82 | # translates them so this is still cross platform.
83 | if line[-1] == '\n':
84 | self.logger.log(self.log_level, line.rstrip())
85 | else:
86 | self.linebuf += line
87 |
88 | def flush(self):
89 | if self.linebuf != '':
90 | self.logger.log(self.log_level, self.linebuf.rstrip())
91 | self.linebuf = ''
92 |
93 |
94 | def disable_torch_init():
95 | """
96 | Disable the redundant torch default initialization to accelerate model creation.
97 | """
98 | import torch
99 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
100 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
101 |
102 |
103 | def violates_moderation(text):
104 | """
105 |     Check whether the text is flagged by the OpenAI moderation API.
106 | """
107 | url = "https://api.openai.com/v1/moderations"
108 | headers = {"Content-Type": "application/json",
109 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
110 |     import json
111 |     text = text.replace("\n", "")
112 |     data = json.dumps({"input": text}).encode("utf-8")
113 | try:
114 | ret = requests.post(url, headers=headers, data=data, timeout=5)
115 | flagged = ret.json()["results"][0]["flagged"]
116 | except requests.exceptions.RequestException as e:
117 | flagged = False
118 | except KeyError as e:
119 | flagged = False
120 |
121 | return flagged
122 |
123 |
124 | def pretty_print_semaphore(semaphore):
125 | if semaphore is None:
126 | return "None"
127 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
128 |
--------------------------------------------------------------------------------
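Note: a small usage sketch for `build_logger` (not part of the repo), assuming `LOGDIR` from `llava.constants` points somewhere writable. After the call, plain `print` output is captured too, because `sys.stdout`/`sys.stderr` are replaced with `StreamToLogger` instances.

```python
from llava.utils import build_logger

logger = build_logger("demo", "demo.log")   # rotating log file created under LOGDIR
logger.info("written through the logger")
print("written via print(), redirected into the 'stdout' logger")
```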
/model_export/README.md:
--------------------------------------------------------------------------------
1 | # Model Export for inference on Apple Silicon
2 | Disclaimer: this is not an official recommendation, just research and exploration.
3 |
4 | ## Export Vision Encoder
5 | We found that the LLaVA trainer does not save all of the state needed for automatic inference,
6 | which third-party libraries like `mlx-vlm` rely on. The script below saves the missing metadata
7 | to the model checkpoint directory and exports the vision encoder using `coremltools`.
8 | Run it as follows to export the vision encoder and patch the checkpoint.
9 | ```bash
10 | python export_vision_encoder.py --model-path /path/to/fastvlm-checkpoint
11 | ```
12 |
13 | ## Export VLM
14 |
15 | ### Install mlx-vlm
16 | We provide a patch for `mlx-vlm` to support FastVLM inference.
17 | ```bash
18 | git clone https://github.com/Blaizzy/mlx-vlm.git
19 | cd mlx-vlm
20 | git checkout 1884b551bc741f26b2d54d68fa89d4e934b9a3de
21 | git apply ../fastvlm_mlx-vlm.patch
22 | pip install -e .
23 | ```
24 |
25 | Export the model using the following command.
26 | ```bash
27 | python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \
28 | --mlx-path /path/to/exported-fastvlm \
29 | --only-llm
30 | ```
31 | To quantize the LLM, pass the additional options shown below.
32 | `--q-bits` specifies the bits per weight; the command below exports the LLM with 8-bit quantization.
33 | ```bash
34 | python -m mlx_vlm.convert --hf-path /path/to/fastvlm-checkpoint \
35 | --mlx-path /path/to/exported-fastvlm \
36 | --only-llm \
37 | -q \
38 | --q-bits 8 # For 4-bit quantization, specify 4
39 | ```
40 |
41 | ### Generate
42 | The exported model can be used for inference in a Python environment with the command below.
43 | ```bash
44 | python -m mlx_vlm.generate --model /path/to/exported-fastvlm \
45 | --image /path/to/image.png \
46 | --prompt "Describe the image." \
47 | --max-tokens 256 \
48 | --temp 0.0
49 | ```
50 |
51 | ## Troubleshooting
52 | We noticed that the `config.json` of a LLaVA model sometimes sets `tie_word_embeddings` incorrectly.
53 | This causes the following error during conversion: `ValueError: Received parameters not in model: language_model.lm_head.weight.`
54 | If you encounter this error, flip the value of `tie_word_embeddings` in `config.json` and rerun the conversion.
55 |
--------------------------------------------------------------------------------
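Note: for the `tie_word_embeddings` issue described in the Troubleshooting section above, one possible workaround (a sketch with an assumed checkpoint path) is to edit the flag in the checkpoint's `config.json` and rerun the conversion.

```python
import json

config_path = "/path/to/fastvlm-checkpoint/config.json"  # assumed path

with open(config_path) as f:
    cfg = json.load(f)

# Flip the flag to whichever value matches the checkpoint's actual weights.
cfg["tie_word_embeddings"] = False

with open(config_path, "w") as f:
    json.dump(cfg, f, indent=2)
```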
/model_export/export_vision_encoder.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE file.
3 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
4 | #
5 | import os
6 | import json
7 | import copy
8 | import argparse
9 |
10 | import torch
11 | import numpy as np
12 | import coremltools
13 |
14 | from llava.model.builder import load_pretrained_model
15 | from llava.utils import disable_torch_init
16 | from llava.mm_utils import get_model_name_from_path
17 |
18 |
19 | def export(args):
20 | # Load model
21 | disable_torch_init()
22 | model_path = os.path.expanduser(args.model_path)
23 | model_name = get_model_name_from_path(model_path)
24 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path,
25 | args.model_base,
26 | model_name,
27 | device="mps")
28 |
29 | # Save extra metadata that is not saved during LLaVA training
30 | # required by HF for auto-loading model and for mlx-vlm preprocessing
31 |
32 | # Save image processing config
33 | setattr(image_processor, "processor_class", "LlavaProcessor")
34 | output_path = os.path.join(model_path, "preprocessor_config.json")
35 | image_processor.to_json_file(output_path)
36 |
37 | # Create processor config
38 | processor_config = dict()
39 |     processor_config["image_token"] = "<image>"
40 | processor_config["num_additional_image_tokens"] = 0
41 | processor_config["processor_class"] = "LlavaProcessor"
42 | processor_config["patch_size"] = 64
43 | output_path = os.path.join(model_path, "processor_config.json")
44 | json.dump(processor_config, open(output_path, "w"), indent=2)
45 |
46 |     # Modify tokenizer to include <image> special token.
47 | tokenizer_config_path = os.path.join(model_path, "tokenizer_config.json")
48 | tokenizer_config = json.load(open(tokenizer_config_path, 'r'))
49 | token_ids = list()
50 | image_token_is_present = False
51 | for k, v in tokenizer_config['added_tokens_decoder'].items():
52 | token_ids.append(int(k))
53 |         if v["content"] == "<image>":
54 | image_token_is_present = True
55 | token_ids.pop()
56 |
57 | # Append only if token is not present
58 | if not image_token_is_present:
59 | tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}'] = copy.deepcopy(
60 | tokenizer_config['added_tokens_decoder'][f'{token_ids[0]}'])
61 |         tokenizer_config['added_tokens_decoder'][f'{max(token_ids) + 1}']["content"] = "<image>"
62 | json.dump(tokenizer_config, open(tokenizer_config_path, 'w'), indent=2)
63 |
64 |     # Modify config to contain token id for <image>
65 | config_path = os.path.join(model_path, "config.json")
66 | model_config = json.load(open(config_path, 'r'))
67 | model_config["image_token_index"] = max(token_ids) + 1
68 | json.dump(model_config, open(config_path, 'w'), indent=2)
69 |
70 | # Export the vision encoder to CoreML
71 | image_res = image_processor.to_dict()['size']['shortest_edge']
72 | inputs = torch.rand(1, 3, image_res, image_res)
73 | inputs_tensor = [
74 | coremltools.TensorType(
75 | name="images",
76 | shape=inputs.shape,
77 | )
78 | ]
79 | vision_model = model.get_vision_tower()
80 | vision_model = vision_model.float()
81 | traced_model = torch.jit.trace(vision_model, torch.Tensor(inputs))
82 | pt_name = "fastvithd.pt"
83 | traced_model.save(pt_name)
84 |
85 | # Export
86 | ml_model = coremltools.convert(
87 | model=pt_name,
88 | outputs=[coremltools.TensorType(name="image_features", dtype=np.float32)],
89 | inputs=inputs_tensor,
90 | convert_to="mlprogram",
91 | debug=False,
92 | compute_units=coremltools.ComputeUnit.CPU_AND_GPU,
93 | minimum_deployment_target=coremltools.target.iOS16,
94 | compute_precision=coremltools.precision.FLOAT32
95 | )
96 | ml_model_path = os.path.join(model_path, "fastvithd.mlpackage")
97 | ml_model.save(ml_model_path)
98 |
99 | # Remove traced model
100 | os.remove(pt_name)
101 |
102 |
103 | if __name__ == "__main__":
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument("--model-path", type=str, required=True)
106 | parser.add_argument("--model-base", type=str, default=None)
107 | parser.add_argument("--conv-mode", type=str, default="qwen_2")
108 |
109 | args = parser.parse_args()
110 |
111 | export(args)
112 |
--------------------------------------------------------------------------------
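Note: a quick sanity check (not part of the repo) after the export above: load the saved `.mlpackage` with `coremltools` and run a random image through it. The input resolution is checkpoint-specific and must match the image processor's `shortest_edge`.

```python
import numpy as np
import coremltools as ct

ml_model = ct.models.MLModel("/path/to/fastvlm-checkpoint/fastvithd.mlpackage")  # assumed path

image_res = 1024  # assumed; use image_processor.to_dict()['size']['shortest_edge'] in practice
dummy = np.random.rand(1, 3, image_res, image_res).astype(np.float32)

out = ml_model.predict({"images": dummy})
print(out["image_features"].shape)
```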
/predict.py:
--------------------------------------------------------------------------------
1 | #
2 | # Modified from LLaVA/predict.py
3 | # Please see ACKNOWLEDGEMENTS for details about LICENSE
4 | #
5 | import os
6 | import argparse
7 |
8 | import torch
9 | from PIL import Image
10 |
11 | from llava.utils import disable_torch_init
12 | from llava.conversation import conv_templates
13 | from llava.model.builder import load_pretrained_model
14 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
15 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
16 |
17 |
18 | def predict(args):
19 |     # Temporarily move the generation config out of the model folder
20 |     # so that generation parameters are taken from args (it is restored after inference)
21 | model_path = os.path.expanduser(args.model_path)
22 | generation_config = None
23 | if os.path.exists(os.path.join(model_path, 'generation_config.json')):
24 | generation_config = os.path.join(model_path, '.generation_config.json')
25 | os.rename(os.path.join(model_path, 'generation_config.json'),
26 | generation_config)
27 |
28 | # Load model
29 | disable_torch_init()
30 | model_name = get_model_name_from_path(model_path)
31 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, device="mps")
32 |
33 | # Construct prompt
34 | qs = args.prompt
35 | if model.config.mm_use_im_start_end:
36 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
37 | else:
38 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
39 | conv = conv_templates[args.conv_mode].copy()
40 | conv.append_message(conv.roles[0], qs)
41 | conv.append_message(conv.roles[1], None)
42 | prompt = conv.get_prompt()
43 |
44 | # Set the pad token id for generation
45 | model.generation_config.pad_token_id = tokenizer.pad_token_id
46 |
47 | # Tokenize prompt
48 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(torch.device("mps"))
49 |
50 | # Load and preprocess image
51 | image = Image.open(args.image_file).convert('RGB')
52 | image_tensor = process_images([image], image_processor, model.config)[0]
53 |
54 | # Run inference
55 | with torch.inference_mode():
56 | output_ids = model.generate(
57 | input_ids,
58 | images=image_tensor.unsqueeze(0).half(),
59 | image_sizes=[image.size],
60 | do_sample=True if args.temperature > 0 else False,
61 | temperature=args.temperature,
62 | top_p=args.top_p,
63 | num_beams=args.num_beams,
64 | max_new_tokens=256,
65 | use_cache=True)
66 |
67 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
68 | print(outputs)
69 |
70 | # Restore generation config
71 | if generation_config is not None:
72 | os.rename(generation_config, os.path.join(model_path, 'generation_config.json'))
73 |
74 |
75 | if __name__ == "__main__":
76 | parser = argparse.ArgumentParser()
77 | parser.add_argument("--model-path", type=str, default="./llava-v1.5-13b")
78 | parser.add_argument("--model-base", type=str, default=None)
79 | parser.add_argument("--image-file", type=str, default=None, help="location of image file")
80 | parser.add_argument("--prompt", type=str, default="Describe the image.", help="Prompt for VLM.")
81 | parser.add_argument("--conv-mode", type=str, default="qwen_2")
82 | parser.add_argument("--temperature", type=float, default=0.2)
83 | parser.add_argument("--top_p", type=float, default=None)
84 | parser.add_argument("--num_beams", type=int, default=1)
85 | args = parser.parse_args()
86 |
87 | predict(args)
88 |
--------------------------------------------------------------------------------
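Note: `predict()` can also be called programmatically by constructing the same namespace that `argparse` would produce (a sketch with assumed paths; the script loads the model on the `mps` device, so it expects an Apple Silicon machine).

```python
from argparse import Namespace

from predict import predict

predict(Namespace(
    model_path="/path/to/fastvlm-checkpoint",  # assumed checkpoint location
    model_base=None,
    image_file="/path/to/image.png",           # any RGB image
    prompt="Describe the image.",
    conv_mode="qwen_2",
    temperature=0.2,
    top_p=None,
    num_beams=1,
))
```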
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava"
7 | version = "1.2.2.post1"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.6.0", "torchvision==0.21.0",
17 | "transformers==4.48.3", "tokenizers==0.21.0", "sentencepiece==0.1.99", "shortuuid",
18 | "accelerate==1.6.0", "peft>=0.10.0,<0.14.0", "bitsandbytes",
19 | "pydantic", "markdown2[all]", "numpy==1.26.4", "scikit-learn==1.2.2",
20 | "gradio==5.11.0", "requests", "uvicorn", "fastapi",
21 | "einops==0.6.1", "einops-exts==0.0.4", "timm==1.0.15",
22 | "coremltools==8.2"
23 | ]
24 |
25 | [project.optional-dependencies]
26 | train = ["deepspeed==0.13.1", "ninja", "wandb"]
27 | build = ["build", "twine"]
28 |
29 | [tool.setuptools.packages.find]
30 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
31 |
32 | [tool.wheel]
33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
34 |
--------------------------------------------------------------------------------